Affine couplings

Author

Kiize

Published

January 30, 2026

Given a field configuration \(\phi\), an affine coupling layer \(g_{i}\) returns a transformed field \(\phi' = g_{i}(\phi)\).

We start by splitting the field \(\phi\) into two halves \(\phi_{a}, \phi_{b}\). The affine coupling \(g_{i}\) then acts as follows: \[ \begin{cases} \phi_{a} \to \phi_{a} & (\text{one half is invariant}) \\ \phi_{b} \to z_{b} = \phi_{b} \odot e^{ s_{i}(\phi_{a}) } + t_{i}(\phi_{a}), \end{cases} \] where \(\odot\) is element-wise multiplication and \(s_{i}, t_{i}\) are functions of the invariant half \(\phi_{a}\) only. In practice they are neural networks whose weights are the trainable parameters of the layer. After every layer we swap the roles of \(\phi_{a}\) and \(\phi_{b}\): the half that stayed invariant is transformed by the next layer, and vice versa.
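A minimal sketch of the forward map, its inverse and the role swap, on a plain vector split into its first and second halves. The functions `s_fn` and `t_fn` are hypothetical stand-ins for the learned networks \(s_{i}, t_{i}\) (fixed element-wise functions here, so both halves must have the same length), and for brevity both layers reuse them, whereas a real flow gives each layer its own networks.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the learned networks s_i and t_i. They must depend
# only on the invariant half phi_a; here they act element-wise for simplicity.
def s_fn(phi_a):
    return jnp.tanh(phi_a)

def t_fn(phi_a):
    return 0.5 * phi_a

def affine_coupling(phi_a, phi_b):
    """Forward map: phi_a is invariant, phi_b -> phi_b * exp(s(phi_a)) + t(phi_a)."""
    z_b = phi_b * jnp.exp(s_fn(phi_a)) + t_fn(phi_a)
    return phi_a, z_b

def affine_coupling_inverse(phi_a, z_b):
    """Inverse map: phi_a is unchanged, so s and t can be recomputed exactly."""
    phi_b = (z_b - t_fn(phi_a)) * jnp.exp(-s_fn(phi_a))
    return phi_a, phi_b

phi = jnp.array([0.3, -1.2, 0.7, 2.0])
phi_a, phi_b = phi[:2], phi[2:]

# Layer 1: phi_a invariant, phi_b transformed
a1, b1 = affine_coupling(phi_a, phi_b)

# Layer 2: roles swapped, the half just transformed is now the invariant one
b2, a2 = affine_coupling(b1, a1)

# Undo layer 1 to check invertibility
_, phi_b_rec = affine_coupling_inverse(a1, b1)
```

Because the invariant half passes through untouched, the inverse only needs to recompute \(s_{i}\) and \(t_{i}\) from it; the networks themselves never have to be inverted.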

Jacobian

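We have seen in Flow Based MCMC that we need the Jacobian determinant \[ \left|\det \frac{ \partial f(\phi) }{ \partial \phi } \right|. \] Recalling that by definition \[ f(\phi) = g_{1} \circ g_{2} \circ \dots \circ g_{n}(\phi), \] the chain rule makes the Jacobian of \(f\) the product of the Jacobians of the single layers (each evaluated at its own input), and the determinant of a product is the product of the determinants. Our task therefore reduces to computing, for each layer, \[ \left|\det \frac{ \partial g_{i}(\phi) }{ \partial \phi } \right|. \] Ordering the components as \((\phi_{a}, \phi_{b})\), the Jacobian of a single coupling is block lower-triangular: \[ \frac{ \partial g_{i}(\phi) }{ \partial \phi } = \begin{pmatrix} I & 0 \\ \dfrac{ \partial z_{b} }{ \partial \phi_{a} } & \operatorname{diag}\left( e^{ s_{i}(\phi_{a}) } \right) \end{pmatrix}. \] The upper-left block is the identity because half of the coupling is just \(\phi_{a} \to \phi_{a}\), the upper-right block vanishes because \(\phi_{a}\) does not depend on \(\phi_{b}\), and the lower-right block is diagonal because the multiplication is element-wise: \[ \frac{ \partial (z_{b})_{j} }{ \partial (\phi_{b})_{k} } = \delta_{jk}\, e^{ (s_{i}(\phi_{a}))_{j} }. \] The determinant of a triangular matrix is the product of its diagonal entries, so the complicated block \(\partial z_{b} / \partial \phi_{a}\) never enters. Putting it all together, if \(\phi\) can be rearranged into a \(D\)-dimensional vector, we have \[ \left|\det \frac{ \partial g_{i}(\phi) }{ \partial \phi } \right| = \prod_{j=1}^{D/2} e^{ (s_{i}(\phi_{a}))_{j} } = \exp\left( \sum_{j=1}^{D/2} (s_{i}(\phi_{a}))_{j} \right), \] so the log-determinant of each layer is simply the sum of the entries of \(s_{i}\), and the log-determinant of \(f\) is the sum of these over the layers.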
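The triangular-Jacobian argument can be checked numerically. The sketch below uses hypothetical toy couplings (`make_layer`, with \(s = \tanh(W_{s}\phi_{a})\) and \(t = W_{t}\phi_{a}\), not the networks used later) on a \(D = 6\) vector. It builds two layers with swapped roles, compares the analytic log-determinant \(\sum_{j} s_{j}\) with the dense Jacobian obtained from `jax.jacfwd`, and checks that log-determinants add under composition.

```python
import jax
import jax.numpy as jnp

D = 6
HALF = D // 2

def make_layer(w_s, w_t, swap):
    """Hypothetical coupling on a D-vector with s = tanh(w_s @ phi_a), t = w_t @ phi_a.
    swap=False keeps the first half invariant, swap=True keeps the second half."""
    def layer(phi):
        if swap:
            phi_b, phi_a = phi[:HALF], phi[HALF:]
        else:
            phi_a, phi_b = phi[:HALF], phi[HALF:]
        s = jnp.tanh(w_s @ phi_a)
        t = w_t @ phi_a
        z_b = phi_b * jnp.exp(s) + t
        out = jnp.concatenate([z_b, phi_a]) if swap else jnp.concatenate([phi_a, z_b])
        # Analytic result from the triangular structure: log|det J| = sum_j s_j
        return out, jnp.sum(s)
    return layer

def autodiff_log_det(fn, x):
    """log|det J| from the dense Jacobian; costs O(D^3), only useful as a check."""
    jac = jax.jacfwd(lambda y: fn(y)[0])(x)
    _, log_abs_det = jnp.linalg.slogdet(jac)
    return log_abs_det

keys = jax.random.split(jax.random.PRNGKey(0), 5)
g1 = make_layer(0.5 * jax.random.normal(keys[0], (HALF, HALF)),
                jax.random.normal(keys[1], (HALF, HALF)), swap=False)
g2 = make_layer(0.5 * jax.random.normal(keys[2], (HALF, HALF)),
                jax.random.normal(keys[3], (HALF, HALF)), swap=True)

def flow(phi):
    """f = g2 o g1: the total log-determinant is the sum over layers."""
    y, ld1 = g1(phi)
    y, ld2 = g2(y)
    return y, ld1 + ld2

phi = jax.random.normal(keys[4], (D,))
_, ld_single = g1(phi)
_, ld_flow = flow(phi)
```

The dense Jacobian is only a sanity check: the whole point of the coupling construction is that `jnp.sum(s)` gives the same number in \(O(D)\) instead of \(O(D^{3})\).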

import jax.numpy as jnp


def coupling_layer_cnn(params, phi, mask, st_net):
    """
    Apply one affine coupling layer whose s and t are produced by a CNN.

    Parameters
    ----------
    params: pytree
        weights of the st_net CNN for this specific layer.
    phi: array (batch, L, L)
        input field configuration.
    mask: array (batch, L, L)
        binary mask splitting the field: 1 on the invariant half phi_a,
        0 on the half phi_b that gets transformed.
    st_net: STNet
        Flax module whose apply returns (s, t), each of shape (batch, L, L, 1).

    Returns
    -------
    phi_new: array (batch, L, L)
        transformed field configuration.
    log_det: array (batch,)
        log |det J| of the transformation, one value per configuration.
    """

    # Flax convolutions expect a trailing channel axis (NHWC layout),
    # so add one to phi and mask: (batch, L, L) -> (batch, L, L, 1).
    phi_input = phi[..., None]
    mask_input = mask[..., None]

    # Keep only the invariant half phi_a; the other sites are zeroed out
    phi_masked = phi_input * mask_input

    # The CNN only sees phi_a, so s and t are functions of phi_a alone
    s, t = st_net.apply({'params': params}, phi_masked)
    
    # Zero s and t on the invariant sites so the transformation acts on phi_b only
    s = s * (1.0 - mask_input)
    t = t * (1.0 - mask_input)

    # phi_a -> phi_a on masked sites, phi_b -> phi_b * exp(s) + t elsewhere
    phi_new = phi_input * mask_input + (phi_input * jnp.exp(s) + t) * (1.0 - mask_input)

    # log|det J| is the sum of s over the transformed sites (s is already zero on phi_a)
    log_det = jnp.sum(s, axis=(1, 2, 3))

    # we remove the channel dimension before returning the new field
    return jnp.squeeze(phi_new, axis=-1), log_det