Affine couplings
Given a field configuration \(\phi\), an affine coupling layer \(g_{i}\) returns a transformed field \(\phi'\) in which only the sites selected by a mask are updated.
We start by dividing the field \(\phi\) into two halves \(\phi_{a}, \phi_{b}\); the affine coupling \(g_{i}\) then acts as follows: \[ \begin{cases} \phi_{a} \to \phi_{a} & (\text{one half is invariant}) \\ \phi_{b} \to z_{b} = \phi_{b} \odot e^{ s_{i}(\phi_{a}) } + t_{i}(\phi_{a}), \end{cases} \] where \(\odot\) is element-wise multiplication and \(s_{i}, t_{i}\) are the scale and shift of the affine coupling, computed from \(\phi_{a}\) alone (in the code below, by a CNN whose weights are the trainable parameters of the layer). After every layer we swap the roles of \(\phi_{a}\) and \(\phi_{b}\), that is, we swap which half stays invariant and which half is transformed.
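The map above can be sketched in a few lines of JAX. This is a minimal illustration, not the implementation used later: the checkerboard split and the function `toy_s_t` are assumptions made here, standing in for the real \(s_{i}, t_{i}\). Because \(\phi_{a}\) passes through unchanged, the inverse can recompute \(s_{i}(\phi_{a})\) and \(t_{i}(\phi_{a})\) and undo the affine map exactly.

```python
import jax
import jax.numpy as jnp

def checkerboard(L, parity=0):
    """Binary (L, L) mask: 1 on the invariant sites (phi_a), 0 on the updated sites (phi_b)."""
    idx = jnp.arange(L)
    return ((idx[:, None] + idx[None, :] + parity) % 2).astype(jnp.float32)

def toy_s_t(phi_a):
    """Hypothetical stand-in for s_i and t_i: any function of phi_a alone would do."""
    neighbours = jnp.roll(phi_a, 1, axis=-1) + jnp.roll(phi_a, -1, axis=-1)
    return 0.5 * jnp.tanh(neighbours), 0.1 * neighbours

def coupling_forward(phi, mask):
    s, t = toy_s_t(phi * mask)  # s and t only see phi_a
    return phi * mask + (phi * jnp.exp(s) + t) * (1.0 - mask)

def coupling_inverse(z, mask):
    s, t = toy_s_t(z * mask)  # z_a equals phi_a, so s and t are recomputed exactly
    return z * mask + (z - t) * jnp.exp(-s) * (1.0 - mask)

mask = checkerboard(4)
phi = jax.random.normal(jax.random.PRNGKey(0), (4, 4))

# Two layers: the second one swaps which half stays invariant.
z = coupling_forward(coupling_forward(phi, mask), 1.0 - mask)

# Invert the layers in reverse order to recover the input.
phi_back = coupling_inverse(coupling_inverse(z, 1.0 - mask), mask)
```

Swapping the mask between layers is what lets every site be updated after two layers, while each single layer stays trivially invertible.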
Jacobian
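We have said in Flow Based MCMC that we need to compute the Jacobian determinant \[ \left|\det \frac{ \partial f(\phi) }{ \partial \phi } \right|. \] Recalling that by definition \[ f(\phi) = (g_{1} \circ g_{2} \circ \dots \circ g_{n})(\phi), \] the chain rule tells us that the final Jacobian determinant is just the product of the Jacobian determinants of the single layers, so our task is to compute \[ \left| \det \frac{ \partial g_{i}(\phi) }{ \partial \phi } \right|. \] We first notice that half of the coupling is just the identity \((\phi_{a}\to \phi_{a})\) and does not depend on \(\phi_{b}\). Ordering the variables as \((\phi_{a}, \phi_{b})\), the Jacobian is therefore block lower-triangular, and its determinant is the product of its diagonal entries; the block \(\partial z_{b} / \partial \phi_{a}\), which contains the derivatives of \(s_{i}\) and \(t_{i}\), does not contribute. We just need the derivative of the other mapping with respect to \(\phi_{b}\) which, because we are performing an element-wise multiplication, is diagonal with entries \[ \frac{ \partial (z_{b})_{j} }{ \partial (\phi_{b})_{j} } = e^{ (s_{i}(\phi_{a}))_{j} }. \] Putting all together, if \(\phi\) can be rearranged in a \(D\)-dimensional vector, we have \[ \left| \det \frac{ \partial g_{i}(\phi) }{ \partial \phi } \right| = \prod_{j=1}^{D/2} e^{ (s_{i}(\phi_{a}))_{j} } = \exp\left( \sum_{j=1}^{D/2} (s_{i}(\phi_{a}))_{j} \right), \] so the log-determinant of a layer is simply the sum of the entries of \(s_{i}(\phi_{a})\).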
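The formula can be checked numerically against automatic differentiation. The sketch below is an illustration under stated assumptions: the field is a flat vector with \(D = 6\), the first half plays the role of \(\phi_{a}\), and `s_fn`, `t_fn` with random matrices `W_s`, `W_t` are hypothetical stand-ins for \(s_{i}, t_{i}\).

```python
import jax
import jax.numpy as jnp

D = 6
k_s, k_t, k_phi = jax.random.split(jax.random.PRNGKey(0), 3)
W_s = jax.random.normal(k_s, (D // 2, D // 2))  # hypothetical weights for s
W_t = jax.random.normal(k_t, (D // 2, D // 2))  # hypothetical weights for t

def s_fn(phi_a):
    return jnp.tanh(W_s @ phi_a)

def t_fn(phi_a):
    return W_t @ phi_a

def g(phi):
    """Affine coupling on a flat vector: first half invariant, second half transformed."""
    phi_a, phi_b = phi[: D // 2], phi[D // 2 :]
    z_b = phi_b * jnp.exp(s_fn(phi_a)) + t_fn(phi_a)
    return jnp.concatenate([phi_a, z_b])

phi = jax.random.normal(k_phi, (D,))

# Full (D, D) Jacobian from forward-mode autodiff, and its log|det|.
jac = jax.jacfwd(g)(phi)
sign, log_det_autodiff = jnp.linalg.slogdet(jac)

# Closed-form result: the sum of s over the transformed half.
log_det_formula = jnp.sum(s_fn(phi[: D // 2]))
```

The upper-right block `jac[:D // 2, D // 2:]` vanishes, which is the triangular structure used in the derivation; the two log-determinants should agree up to floating-point error.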
import jax.numpy as jnp


def coupling_layer_cnn(params, phi, mask, st_net):
    """
    Apply one affine coupling layer to a batch of field configurations.

    Parameters
    ----------
    params:
        weights of the st_net CNN for the specific layer.
    phi: array (batch, L, L)
        input field configuration
    mask: array (batch, L, L)
        binary mask to split the field phi (1 = invariant site, 0 = transformed site)
    st_net: STNet
        instance of the STNet class

    Returns
    -------
    phi_new: array (batch, L, L)
        transformed field configuration
    log_det: array (batch,)
        log determinant of the Jacobian of the transformation
    """
    # CNNs in flax want another dimension for channels, so we add a trailing axis to phi and mask.
    phi_input = phi[..., None]
    mask_input = mask[..., None]
    # Masked field: only the invariant half phi_a is visible to the network
    phi_masked = phi_input * mask_input
    # The CNN, evaluated with this layer's params on the masked field, returns s and t
    s, t = st_net.apply({'params': params}, phi_masked)
    # We want to apply the transformation only to the other half of the field
    s = s * (1.0 - mask_input)
    t = t * (1.0 - mask_input)
    # Transformation: identity on the masked sites, affine map on the others
    phi_new = phi_input * mask_input + (phi_input * jnp.exp(s) + t) * (1.0 - mask_input)
    # Log of the determinant of the Jacobian is the sum over s (s is zero on the invariant sites)
    log_det = jnp.sum(s, axis=(1, 2, 3))
    # We remove the channel dimension before returning the new field
    return jnp.squeeze(phi_new, axis=-1), log_det
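
Since the Jacobian determinant of the full flow is the product of the single-layer determinants, the log-determinants returned by each layer simply add up. The following sketch shows one way to stack layers; it is an assumption-laden illustration, not the original pipeline. `ToySTNet` is a hypothetical stand-in that only mimics the `.apply(variables, x)` call of `STNet`, `coupling_layer` is a condensed copy of `coupling_layer_cnn`, and `apply_flow` with its checkerboard mask is introduced here for the example.

```python
import jax
import jax.numpy as jnp

class ToySTNet:
    """Hypothetical stand-in for STNet: same .apply(variables, x) call, one scalar weight."""
    def apply(self, variables, x):
        w = variables['params']['w']
        neighbours = jnp.roll(x, 1, axis=1) + jnp.roll(x, 1, axis=2)
        return jnp.tanh(w * neighbours), w * neighbours

def coupling_layer(params, phi, mask, st_net):
    # Same transformation as coupling_layer_cnn, condensed.
    x, m = phi[..., None], mask[..., None]
    s, t = st_net.apply({'params': params}, x * m)
    s, t = s * (1.0 - m), t * (1.0 - m)
    phi_new = x * m + (x * jnp.exp(s) + t) * (1.0 - m)
    return jnp.squeeze(phi_new, axis=-1), jnp.sum(s, axis=(1, 2, 3))

def apply_flow(all_params, phi, st_net):
    """Apply the layers in list order, swapping the mask each time and summing log-dets."""
    batch, L, _ = phi.shape
    idx = jnp.arange(L)
    board = ((idx[:, None] + idx[None, :]) % 2).astype(phi.dtype)
    log_det = jnp.zeros(batch)
    for i, params in enumerate(all_params):
        mask = board if i % 2 == 0 else 1.0 - board  # swap the invariant half
        mask = jnp.broadcast_to(mask, phi.shape)
        phi, layer_log_det = coupling_layer(params, phi, mask, st_net)
        log_det = log_det + layer_log_det
    return phi, log_det

# Four layers with arbitrary illustrative weights.
all_params = [{'w': 0.3}, {'w': -0.2}, {'w': 0.5}, {'w': 0.1}]
phi = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4))
z, log_det = apply_flow(all_params, phi, ToySTNet())
```

Here `z` has shape `(batch, L, L)` and `log_det` has shape `(batch,)`, one accumulated log-determinant per configuration.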