2021. 3. 6. 22:35 · Deep Learning
GitHub code source: github/AntixK/PyTorch-VAE
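```python
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    # Encode to the posterior parameters, sample z, then decode.
    mu, log_var = self.encode(input)
    z = self.reparameterize(mu, log_var)
    # Return everything loss_function needs: reconstruction, input, mu, log_var.
    return [self.decode(z), input, mu, log_var]
```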
This is the overall structure of the beta-VAE:
encode -> reparam -> decode.
The encoder outputs mu (128-dim) and log_var (128-dim).
z is also 128-dim (see latent_dim in the YAML file).
The encoder input has shape [N, C, H, W].
```yaml
model_params:
  name: 'BetaVAE'
  in_channels: 3
  latent_dim: 128
  loss_type: 'B'
  gamma: 10.0
  max_capacity: 25
  Capacity_max_iter: 10000

exp_params:
  dataset: celeba
  data_path: "../../shared/Data/"
  img_size: 64
  batch_size: 144 # Better to have a square number
  LR: 0.0005
  weight_decay: 0.0
  scheduler_gamma: 0.95

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  max_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "BetaVAE_B"
  manual_seed: 1265
```
```python
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    self.num_iter += 1
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]
    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

    recons_loss = F.mse_loss(recons, input)

    # Closed-form KL per sample (sum over latent dims), then mean over the batch.
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

    if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
        loss = recons_loss + self.beta * kld_weight * kld_loss
    elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
        self.C_max = self.C_max.to(input.device)
        C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
        loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
    else:
        raise ValueError('Undefined loss type.')

    return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}
```
Only the loss_type 'H' branch really needs attention. The 'B' branch is fairly similar (probably no need to study it in detail?).
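For reference, the capacity term C in the 'B' branch is a linear ramp in the iteration count that saturates at C_max. A minimal sketch in plain Python, assuming C_max and C_stop_iter come from max_capacity: 25 and Capacity_max_iter: 10000 in the YAML, and with kld_weight set to 1.0 purely for illustration (the helper names `capacity` and `loss_b` are hypothetical, not part of the repository):

```python
def capacity(num_iter: int, c_max: float = 25.0, c_stop_iter: int = 10000) -> float:
    """Linear ramp from 0 to c_max, mirroring
    torch.clamp(C_max / C_stop_iter * num_iter, 0, C_max)."""
    return min(max(c_max / c_stop_iter * num_iter, 0.0), c_max)


def loss_b(recons_loss: float, kld_loss: float, num_iter: int,
           gamma: float = 10.0, kld_weight: float = 1.0) -> float:
    """Scalar version of the 'B' branch: penalize the distance of the KL from C."""
    return recons_loss + gamma * kld_weight * abs(kld_loss - capacity(num_iter))


for it in (0, 5000, 10000, 20000):
    print(it, capacity(it))
```

The penalty is zero when the KL term equals the current capacity, so the KL is pulled toward a target that rises over training rather than toward zero.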
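$M_N = \frac{M}{N}$, where $M$ is the dimension of the latent space (128) and $N$ is the number of pixels in the input image ($64\times64$); see the $\beta$-VAE paper excerpt below. Note, however, that the code comment describes M_N as accounting for minibatch samples from the dataset, so in the repository it may instead be the ratio of batch size to dataset size.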
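To make the size of this weight concrete, here is a small arithmetic sketch of two possible readings of M_N. The first uses latent_dim and img_size from the YAML; the second follows the code comment about minibatches, and its dataset size `num_train_imgs` is a hypothetical placeholder, not a value taken from the source.

```python
# Reading 1: M / N with M = latent dimension, N = number of input pixels.
latent_dim = 128            # M, from latent_dim in the YAML
num_pixels = 64 * 64        # N, from img_size in the YAML (channels not counted)
kld_weight_paper = latent_dim / num_pixels

# Reading 2: batch size over dataset size, as the code comment suggests.
batch_size = 144            # from batch_size in the YAML
num_train_imgs = 100_000    # hypothetical placeholder; substitute the real training-set size
kld_weight_minibatch = batch_size / num_train_imgs

print(kld_weight_paper)      # 0.03125
print(kld_weight_minibatch)
```

Either way the KL term is scaled down considerably relative to the reconstruction term, which F.mse_loss averages over all elements by default.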
```python
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    r"""
    Computes the VAE loss function.
    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
    :param args:
    :param kwargs:
    :return:
    """
    # Raw docstring (r""") so that \f in \frac is not read as a form-feed escape.
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]

    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
    recons_loss = F.mse_loss(recons, input)

    # Closed-form KL per sample (sum over latent dims), then mean over the batch.
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

    loss = recons_loss + kld_weight * kld_loss
    # Note: the source reports the KLD with a negated sign here.
    return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
```
The above is the loss function of the plain VAE.
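As a check that the formula in the docstring and the expression inside kld_loss agree, substitute $\log\frac{1}{\sigma} = -\frac{1}{2}\log\sigma^2$ and write $\log\sigma^2$ for log_var (so that $\sigma^2$ is log_var.exp()):

$$
KL(N(\mu, \sigma^2)\,\|\,N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} = -\frac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)
$$

The code sums this quantity over the latent dimensions (dim = 1) and then averages over the batch (dim = 0).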
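Let $\overrightarrow{z}=(z_1, z_2, \cdots, z_M)$. The code evaluates the KL term coordinate by coordinate against $z_i \sim N(0,1)$ and sums over $i$ (dim = 1). Because the posterior is a diagonal Gaussian, that sum is exactly the KL divergence to the joint prior $\overrightarrow{z} \sim N(0,I)$, so the two descriptions are the same thing.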
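A quick numeric check of how the per-coordinate $N(0,1)$ terms relate to the joint $N(0,I)$ prior. This is a sketch in plain Python under the assumption that the posterior is a diagonal Gaussian; the function names and the sample values of mu and log_var are made up for illustration.

```python
import math


def kl_per_dim(mu: float, log_var: float) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)) for one coordinate, same expression as in kld_loss."""
    return -0.5 * (1.0 + log_var - mu ** 2 - math.exp(log_var))


def kl_diag_vs_standard(mu: list, log_var: list) -> float:
    """KL(N(mu, diag(sigma^2)) || N(0, I)) from the multivariate formula
    0.5 * (trace(S) + mu.mu - k - log det S) with S = diag(sigma^2)."""
    k = len(mu)
    trace = sum(math.exp(lv) for lv in log_var)
    log_det = sum(log_var)
    sq_norm = sum(m * m for m in mu)
    return 0.5 * (trace + sq_norm - k - log_det)


mu = [0.3, -1.2, 0.0]
log_var = [0.1, -0.5, 0.0]

summed = sum(kl_per_dim(m, lv) for m, lv in zip(mu, log_var))
joint = kl_diag_vs_standard(mu, log_var)
print(summed, joint)  # the two values agree up to floating-point error
```

Summing independent one-dimensional KL terms and using the joint formula give the same number, which is why the code can work coordinate by coordinate.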
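Loss in the paper:
$$
L(\theta, \phi, x, z, \beta) = E_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)]-\beta D_{KL}(q_{\phi}(z|x)\,\|\,p(z))
$$
and
$$E_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] = -C\,E_{q_{\phi}(z|x)}[\|\hat{x}(z) - x\|^2] + \text{const}
$$
for some $C > 0$, where $\hat{x}(z)$ is the decoder output. This holds when the decoder is a Gaussian with fixed variance, which is why the code can use an MSE reconstruction loss.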
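The link between the log-likelihood term and a squared error can be checked numerically. The sketch below assumes a Gaussian decoder with a fixed isotropic variance $\sigma^2$ (an assumption made here for illustration, giving $C = \frac{1}{2\sigma^2}$); the function names and data values are hypothetical.

```python
import math


def gaussian_log_likelihood(x: list, x_hat: list, sigma: float = 1.0) -> float:
    """log N(x; x_hat, sigma^2 I) for flat lists x and x_hat."""
    n = len(x)
    sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return -0.5 * n * math.log(2 * math.pi * sigma ** 2) - sq_err / (2 * sigma ** 2)


def squared_error(x: list, x_hat: list) -> float:
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))


x = [0.2, 0.8, 0.5]
x_hat_a = [0.1, 0.9, 0.5]   # a close reconstruction
x_hat_b = [0.6, 0.3, 0.0]   # a poor reconstruction
sigma = 0.5
C = 1.0 / (2 * sigma ** 2)

# The constant term cancels in the difference, leaving only -C times the squared-error gap.
ll_gap = gaussian_log_likelihood(x, x_hat_a, sigma) - gaussian_log_likelihood(x, x_hat_b, sigma)
se_gap = squared_error(x, x_hat_a) - squared_error(x, x_hat_b)
print(ll_gap, -C * se_gap)
```

Maximizing the expected log-likelihood is therefore the same as minimizing a scaled squared reconstruction error, up to a constant that does not depend on the model parameters.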
- Note that if $X \sim N(\mu, \sigma^2)$, then the entropy of $X$ is $H(X)=E[-\log p(X)]=\log \sqrt{2\pi e \sigma^2}$, where $p$ is the density of $X$.
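The closed form for the Gaussian entropy can be sanity-checked by numerical integration. This is a plain-Python sketch using a midpoint rule; the integration width and step count are arbitrary choices made for illustration.

```python
import math


def gaussian_pdf(x: float, mu: float, sigma: float) -> float:
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def entropy_numeric(mu: float, sigma: float,
                    half_width: float = 12.0, steps: int = 20_000) -> float:
    """Midpoint-rule estimate of -integral p(x) log p(x) dx over mu +/- half_width * sigma."""
    lo = mu - half_width * sigma
    dx = 2 * half_width * sigma / steps
    total = 0.0
    for i in range(steps):
        p = gaussian_pdf(lo + (i + 0.5) * dx, mu, sigma)
        if p > 0.0:  # guard against log(0) in the far tails
            total -= p * math.log(p) * dx
    return total


def entropy_closed_form(sigma: float) -> float:
    """log sqrt(2 * pi * e * sigma^2); independent of mu."""
    return 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)


print(entropy_numeric(1.0, 2.0), entropy_closed_form(2.0))
```

The entropy depends only on $\sigma$, not on $\mu$, so shifting the mean leaves the estimate unchanged.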
```python
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    """
    Will a single z be enough to compute the expectation
    for the loss??
    :param mu: (Tensor) Mean of the latent Gaussian
    :param logvar: (Tensor) Log variance of the latent Gaussian
    :return: (Tensor) Sample z = mu + std * eps
    """
    std = torch.exp(0.5 * logvar)   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)     # eps ~ N(0, I), independent of mu and logvar
    return eps * std + mu
```
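The reparameterization step exists so that the decoder's gradient can flow back into the encoder: the sample z is written as a deterministic function of mu, log_var, and an external noise term eps, instead of being drawn directly from the posterior.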
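To see why gradients pass through this step, note that once eps is fixed, z is an ordinary differentiable function of mu and log_var. The plain-Python sketch below (scalar case, finite differences instead of autograd, all values made up for illustration) compares numeric derivatives with the analytic ones, $\partial z/\partial \mu = 1$ and $\partial z/\partial \log\sigma^2 = \frac{1}{2}\sigma\,\epsilon$.

```python
import math
import random


def reparameterize(mu: float, log_var: float, eps: float) -> float:
    """z = mu + sigma * eps with sigma = exp(0.5 * log_var).
    eps is passed in explicitly, so z is deterministic given its arguments."""
    return mu + math.exp(0.5 * log_var) * eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)      # the only source of randomness
mu, log_var, h = 0.4, -0.6, 1e-6

# Central finite differences with eps held fixed.
dz_dmu = (reparameterize(mu + h, log_var, eps) - reparameterize(mu - h, log_var, eps)) / (2 * h)
dz_dlogvar = (reparameterize(mu, log_var + h, eps) - reparameterize(mu, log_var - h, eps)) / (2 * h)

# Analytic derivatives of z = mu + exp(0.5 * log_var) * eps.
expected_dmu = 1.0
expected_dlogvar = 0.5 * math.exp(0.5 * log_var) * eps

print(dz_dmu, expected_dmu)
print(dz_dlogvar, expected_dlogvar)
```

In the PyTorch code the same derivatives are produced automatically by autograd, since torch.randn_like only supplies eps and does not depend on the encoder outputs.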
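The following is excerpted from the $\beta$-VAE paper.
$\beta$-VAE vs VAE vs InfoGAN
Azimuth: the horizontal rotation angle of the face.
In the VAE results, the facial expression also changes (smiling -> neutral) as the face angle rotates, so the factors are not well disentangled.
The paper states that $\beta=1$ recovers the standard VAE. It also proposes a disentanglement metric for checking how well the factors are disentangled, but I did not look at it; it does not seem very important (?). In practice $\beta>1$ is the usual setting.
According to the paper, the value of $\beta$ sets the trade-off between reconstruction fidelity (information preservation) and the quality of disentanglement within the learnt latent representations (latent channel capacity restriction).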
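Open questions
- $N(0,I)$ tells us the $z_i$ are independent, so why does the code use a separate $N(0,1)$ for each coordinate? Would that not lose independence? (Independent $N(0,1)$ coordinates are by definition jointly $N(0,I)$, so summing the per-coordinate KL terms keeps the independence assumption.)
- To be honest, I have not read the paper closely.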