[D] Quality of VAE embeddings, depending on likelihood function?
tl;dr: Which log-likelihood function should I use to train a VQ-VAE when I only care about the embeddings?
Since no one responded on r/MLQuestions, this might be better suited here:
I have a question concerning the use of VAEs for encoding, as opposed to using them for data generation. In particular, I want to use a vector-quantising variational autoencoder to find a discrete representation of continuous data for some downstream task. I wonder if the choice of the decoder likelihood function would have a noticeable impact on the quality of the discrete representation.
The objective for a batch size of 1 has the following form (omitting some VQ-VAE specific terms):
max log p_dec(x | z_enc) - KL( q(z | x) || p(z) )
where x is a training sample, z_enc is a latent sample produced by the encoder, p_dec is the decoder's likelihood function, q is the encoder's approximate posterior over the latent variable z, and p(z) is the prior.
When I assume p_dec(x | z_enc) to be a multivariate normal distribution whose mean is given by some neural network and whose covariance is the identity matrix, I can replace the log-likelihood term with the negative squared error (up to a scale factor and an additive constant), as in ordinary least-squares regression. This is what I’ve seen being used in some implementations.
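To make that equivalence concrete, here is a minimal NumPy sketch. The function names and toy values are my own illustrative assumptions, not taken from any particular VQ-VAE implementation. With an identity covariance, the Gaussian log-likelihood is minus half the summed squared error plus a constant that does not depend on the decoder output:

```python
import numpy as np

def gaussian_loglik_identity(x, mean):
    """log N(x | mean, I) for a single d-dimensional sample."""
    d = x.shape[-1]
    sq_err = np.sum((x - mean) ** 2, axis=-1)
    return -0.5 * sq_err - 0.5 * d * np.log(2.0 * np.pi)

def neg_half_squared_error(x, mean):
    """The part of the log-likelihood that depends on the decoder output."""
    return -0.5 * np.sum((x - mean) ** 2, axis=-1)

# Toy data: one sample and two different (hypothetical) decoder means.
rng = np.random.default_rng(0)
x = rng.normal(size=4)
mean_a = rng.normal(size=4)
mean_b = rng.normal(size=4)

# The gap between the two quantities is the same for any decoder output,
# so maximising one is equivalent to maximising the other.
gap_a = gaussian_loglik_identity(x, mean_a) - neg_half_squared_error(x, mean_a)
gap_b = gaussian_loglik_identity(x, mean_b) - neg_half_squared_error(x, mean_b)
```

Averaging over dimensions instead of summing (i.e. the usual MSE) only rescales the reconstruction term relative to the other terms in the objective.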
But I could also let the decoder output an arbitrary covariance matrix. This would of course change the log-likelihood function.
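As a rough sketch of what that would look like, assume the decoder outputs, besides the mean, an unconstrained d x d array `chol_raw` (a hypothetical name) that is turned into a Cholesky factor L with a positive diagonal, so that the covariance L L^T is always valid. This parameterisation is one common choice, not the only one:

```python
import numpy as np

def gaussian_loglik_full(x, mean, chol_raw):
    """log N(x | mean, L L^T) for a single d-dimensional sample.

    chol_raw is an unconstrained (d, d) array (hypothetical decoder output).
    Its strict lower triangle is used as-is and its diagonal is exponentiated,
    which gives a lower-triangular L with a strictly positive diagonal.
    """
    d = x.shape[-1]
    L = np.tril(chol_raw, k=-1) + np.diag(np.exp(np.diag(chol_raw)))
    diff = x - mean
    # Mahalanobis term: diff^T (L L^T)^{-1} diff = ||L^{-1} diff||^2
    y = np.linalg.solve(L, diff)
    # log det(L L^T) = 2 * sum(log diag(L))
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (y @ y) - 0.5 * log_det - 0.5 * d * np.log(2.0 * np.pi)

# Toy values for illustration only.
rng = np.random.default_rng(0)
d = 3
x, mean = rng.normal(size=d), rng.normal(size=d)
chol_raw = 0.3 * rng.normal(size=(d, d))

ll_full = gaussian_loglik_full(x, mean, chol_raw)
# With chol_raw = 0 the factor L is the identity, recovering the fixed-covariance case.
ll_identity = gaussian_loglik_full(x, mean, np.zeros((d, d)))
```

Unlike the identity case, the log-determinant term now depends on the decoder output, so it can no longer be dropped from the loss as a constant.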
Do you think it makes sense to use a more involved log-likelihood function (i.e. an arbitrary positive-definite covariance matrix), so that the encoder is forced to find a representation that is better at explaining data coming from a complex distribution? Do you know of any non-domain-specific papers investigating the use of VQ-VAEs for encoding?