Categorical Reparameterization with Gumbel-Softmax

vae

(Leo) #1

Essence

Contribution

  • Latent variable 이 categorical 분포를 따를 때, Gumbel-softmax 로 근사해서 backpropagate 가능한 방법을 제시 (★★)
  • Discrete random variable 을 가지는 stochastic neural network 에 reparametrization trick 적용 가능한 continuous relaxation 방법을 제시
  • 위 방법을 통해 Kingma(2014)에서 제시한 semi-supervised model 에서 categorical latent variable 에 대한 비효율적인 marginalization 계산을 회피하고 학습이 가능함을 실험적으로 보임 (★)
  • DeepMind 의 Maddison 논문과 같은 내용 (Eric Jang 블로그 참조)

Key Ideas

  • 연속확률분포(continuous random variable)인 Gumbel 분포softmax 함수를 이용해 범주형 변수(categorical variable)를 근사하는 방법을 제시
  • 이를 이용해 이산확률분포(discrete random variable)를 따르는 잠재변수(latent variable)에 대한 함수의 parameter gradient 를 계산하는 reparameterization trick 을 소개

Gumbel-Softmax Distribution

i 번째 coordinate 만 1 이고 나머지는 0 인 k-차원 one_hot 벡터를 \mathbf{e}_{i} 로 표기하겠습니다. 이 때 확률함수 z 는 확률분포 \pi = (\pi_{1},\ldots,\pi_{k}) 에 의해 \pi_{i} 의 확률로 one_hot 벡터 \mathbf{e}_{i} 로 매핑되는 categorical 변수라 정의하겠습니다:

z = \begin{cases} \mathbf{e}_{1} & \text{with probability } \pi_{1} \\ \vdots & \quad \vdots \\ \mathbf{e}_{i} & \text{with probability } \pi_{i} \\ \vdots & \quad \vdots \\ \mathbf{e}_{k} & \text{with probability } \pi_{k} \end{cases}

즉, \mathbb{E}_{\pi}[z]=\sum_{i=1}^{k}\pi_{i}\mathbf{e}_{i} = (\pi_{1},\ldots, \pi_{k}) 가 됩니다. 이 때 확률변수 z 를 sampling 할 수 있는 방법이 다음과 같은 Gumbel-max 분포를 사용하는 것입니다:

\hat{z} = \text{one_hot}\left( \underset{i}{\text{arg max}}[g_{i} + \log \pi_{i} ] \right),\quad g_{i}=- \log(-\log u_{i}),\quad u_{i}\sim\text{Uniform}(0,1)

Gumbel 이라는 이름의 어원은 위의 식에서 noise 를 더해주는 g_{i} 가 따르는 분포의 이름이 Gumbel 분포이기 때문입니다. Gumbel 분포는 아래와 같이 실수 전체집합에 정의되어 있고 Euler–Mascheroni 상수를 평균으로 가지고, positive-skew 된 모양을 가집니다. 그래서 \pi_{i} 의 값이 높을수록 보다 많이 sampling 되지만 negative g_{i} 값도 sampling 될 가능성도 대략 37% 정도 되어서 stochastic 하게 확률분포 \pi 를 근사하게 됩니다.

Gumbel 분포는 이러한 작동 원리로 discrete random variable 을 가지는 stochastic computation graph 에서 reparametrization trick 을 적용할 때 Gaussian 분포를 대체하기 위해 사용될 겁니다.

그러나 \hat{z} 을 직접 사용하기엔 argmax 의 특성 때문에 boundary 를 제외하고 \pi 에 대한 미분이 모두 0 이 되므로 적절하지 않습니다. 그래서 논문에서는 \hat{z} 대신에 아래와 같이 argmaxsoftmax 함수로 relaxation 한 sampling 방법을 제시합니다 (Maddison 의 논문에선 이 확률변수를 Concrete random variable 이라고 부릅니다):

\zeta_{i} =\text{softmax}[(g + \log \pi)/\tau]_{i}= \frac{\exp[ ( g_{i} + \log\pi_{i} )/\tau]}{\sum_{j=1}^{k}\exp[ ( g_{j} + \log\pi_{j} )/\tau]} \quad\cdots\quad (\star)

(\star) 을 좀 더 자세히 설명하면, softmax 함수는 argmax 를 모델링하는 함수입니다. 즉, 식 (\star) 는 위의 \hat{z} 에서 one_hot 함수 안의 argmax 부분을 approximate 합니다. 그리고 temperature \tau 가 0 으로 갈수록 \zeta = (\zeta_{1},\ldots, \zeta_{k}) 의 확률분포는 Gumbel 분포 g_{i} 의 effect 를 무시하고 확률분포 \pi 로 approximate 하게 됩니다. 반대로 \tau 가 무한대로 가는 경우 \pi 의 effect 를 무시하고 uniform 분포로 approximate 하게 됩니다:

\mathbb{P}( \lim_{\tau \to 0} \zeta_{i} = 1) = \frac{\pi_{i}}{\sum_{j=1}^{k}\pi_{j}}, \quad \mathbb{P}( \lim_{\tau \to \infty} \zeta_{i} = 1) = \frac{1}{k}

아래 그림을 보시면 \tau 의 크기에 따른 \zeta 의 확률분포의 변화를 좀 더 명확하게 이해할 수 있습니다:

Temperature \tau >0 이면 식 (\star) 에서 \zeta_{i}\pi 에 대해 미분가능하므로 gradient \nabla_{\pi}\zeta 계산이 가능해집니다. 그러므로 discrete sampling 대신 Gumbel-softmax trick 을 사용해서 gradient method 를 적용할 수 있습니다. 한 편 \tau 는 hyperparameter 로써 학습 전에 정해주거나 scheduling 을 통해 regularizer 로 사용할 수 있습니다. 일반적으로 \tau 가 낮으면 \zetaone_hot 에 가까워지지만 gradient 의 variance 는 증가하게 됩니다.

Reparametrization Trick via Gumbel-softmax

위에서 소개한 Gumbel-softmax 를 사용해서 discrete VAE 에서 사용하는 방법을 다음과 같습니다:

  • \mathbf{x}: real data
  • \mathbf{z}: one_hot 벡터 \{\mathbf{e}_{1},\ldots, \mathbf{e}_{k}\} 로 mapping 되는 discrete latent variable
  • P_{\alpha}(\mathbf{z}): \alpha 를 parameter 로 가지는 \mathbf{z} 위의 prior
  • Q_{\phi}(\mathbf{z} | \mathbf{x}): \phi 를 parameter 로 가지는 encoder 또는 inference network
  • p_{\theta}(\mathbf{x}|\mathbf{z}): \theta 를 parameter 로 가지는 decoder
  • p_{\theta}(\mathbf{x}, \mathbf{z})=p_{\theta}(\mathbf{x}|\mathbf{z}) P_{\alpha}(\mathbf{z}): generative model
  • \log p_{\theta}(\mathbf{x})= \log \int p_{\theta}(\mathbf{x},\mathbf{z})d\mathbf{z}: marginal likelihood

기존의 VAE 처럼 \log p_{\theta}(\mathbf{x}) 의 lower-bound 인 ELBO \mathcal{L}_{\theta, \phi}(\mathbf{x}) 는 아래와 같이 계산됩니다

\log p_{\theta}(\mathbf{x}) \geq \mathcal{L}_{\theta, \phi}( \mathbf{x}) = \mathbb{E}_{Z \sim Q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x},Z)}{Q_{\phi}(Z|\mathbf{x})}\right] = \mathbb{E}_{Z \sim Q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x}|Z)P_{\alpha}(Z)}{Q_{\phi}(Z|\mathbf{x})}\right]

여기서 P_{\alpha}, Q_{\phi} 는 discrete 확률변수인 Z 위의 (unnormalized) categorical 확률분포입니다. Kingma 의 VAE 에선 \nabla_{\phi} \mathcal{L}_{\theta,\phi} 을 계산할 때 Z 대신 \widetilde{Z} = g_{\phi}(\epsilon, \mathbf{x})change of variable 을 적용해서 stochasticity 를 Z 가 아닌 \epsilon 으로 대체했습니다. Discrete VAE 의 경우 \nabla_{\phi} 을 계산하는 것이 어렵기 때문에 위에서 소개한 Gumbel-softmax trick 을 적용해야 합니다:

P_{\alpha}(\mathbf{z})\overset{\text{relax}}{\rightsquigarrow}p_{\alpha, \tau}(\mathbf{\zeta}), \quad Q_{\phi}(\mathbf{z}|\mathbf{x}) \overset{\text{relax}}{\rightsquigarrow} q_{\phi,\tau}(\zeta|\mathbf{x})

\tau 는 Gumbel-softmax trick 으로 relaxation 할 때 사용되는 temperature 입니다. 기대값으로 정의된 ELBO 역시 아래와 같이 relaxation 하게 됩니다.

\mathcal{L}_{\theta, \phi}(\mathbf{x}) \overset{\text{relax}}{\rightsquigarrow} \mathcal{L}_{\theta, \phi,\tau}(\mathbf{x}) = \mathbb{E}_{\zeta \sim q_{\phi, \tau}(\zeta|\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x}|\zeta)p_{\alpha, \tau}(\zeta)}{q_{\phi, \tau}(\zeta|\mathbf{x})}\right]

그러므로 \mathcal{L}_{\theta, \phi} 의 gradient 는 \mathcal{L}_{\theta,\phi,\tau} 의 gradient 로 approximate 할 수 있고, 다시 \mathcal{L}_{\theta,\phi,\tau} 의 gradient 는 reparametrization trick 에 의해 Stochastic Gradient Variational Bayes (SGVB) estimator \widetilde{\mathcal{L}}^{A}_{\theta,\phi,\tau}(\mathbf{x}) 의 gradient 로 연산할 수 있습니다:

\mathcal{L}_{\theta,\phi,\tau}(\mathbf{x}) \approx \widetilde{\mathcal{L}}^{A}_{\theta,\phi,\tau}(\mathbf{x}) = \frac{1}{N}\sum_{n=1}^{N} \log \left(\frac{p_{\theta}(\mathbf{x}|\zeta^{(n)})p_{\alpha, \tau}(\zeta^{(n)})}{q_{\phi, \tau}(\zeta^{(n)}|\mathbf{x})}\right)
\nabla_{\theta}\widetilde{\mathcal{L}}^{A}_{\theta,\phi,\tau}(\mathbf{x}) = \frac{1}{N}\sum_{n=1}^{N} \nabla_{\theta}\log p_{\theta}(\mathbf{x}|\zeta^{(n)}), \quad \nabla_{\phi}\widetilde{\mathcal{L}}^{A}_{\theta,\phi,\tau}(\mathbf{x}) = - \frac{1}{N}\sum_{n=1}^{N} \nabla_{\phi}\log q_{\phi, \tau}(\zeta^{(n)}|\mathbf{x})

위 과정을 computation graph 로 그리면 다음과 같이 됩니다. 빨간색 화살표들은 back-propagation 흐름을 나타내고 g 는 Gumbel 분포에 의한 stochastic node 를 말합니다:

Straight-Through Gumbel-softmax estimator

Discrete VAE 내지 RL 에서 discrete action space 을 상정하면 Z 가 오로지 discrete value 로 sampling 되어야 하는 조건이 필요할수도 있습니다. 이 경우 Gumbel-softmax trick 으로 Z 의 확률분포를 \zeta 로 relaxation 해도 대입할 땐 argmax 로 다시 discrete 하게 변환해야 합니다. 이 경우 gradient 를 계산할 땐 argmax 를 건너뛰고 \nabla \zeta 를 back-prop 할 때 전달해주는 방법이 ST Gumbel estimator 입니다. Y. Bengio 의 2013년도 논문 이나 Hierarchical Multiscale RNN 과 유사한 방법인데, ST Gumbel estimator 의 variance 가 더 안정적이라고 주장합니다.

Results

C. Maddison 과 E. Jang 의 각각의 논문에서 Gumbel-softmax trick 을 사용한 비교실험들을 보여줍니다. Maddison 의 논문에선 주로 VIMCO(Variational Inference for Monte-Carlo Objectives) 와 비교하고 Jang 의 논문에선 다른 방법들과 비교하고 있습니다: