..

Mixture Density Network에 대한 간단한 정리와 구현

MDN

강병규

오늘은 일반적인 딥러닝 네트워크의 변형인 Mixture Density network에 대해서 알아보고 간단한 파이토치(Pytorch) 코드로 구현해보겠습니다. 어려운 내용이라 저도 잘못 이해하고 있는 부분이 있을 수 있으니 피드백 환영합니다. 코드는 이 링크에서 가져왔습니다.

들어가며

일반적인 딥러닝 네트워크를 생각해봅시다. 이 네트워크는 어떠한 입력이 주어졌을 때 이에 해당하는 출력을 만들어낼 겁니다. 이를 다시 표현하자면 $p(y \mid x,\theta)$의 조건부확률로 표현할 수 있죠. 이때 $\theta$는 모델의 파라미터이며, x는 네트워크의 입력, y는 출력입니다.

우리가 만약 분류, Classification을 수행한다면 이 네트워크는 아마 입력이 주어졌을 때 각 y, 곧 어떤 class에 대응할 확률을 소프트맥스(Softmax) 등을 사용해서 만들어 낼 것입니다. 곧 y는 label이라고 생각할 수 있습니다. 마찬가지로 회귀, Regression을 수행한다고 생각한다면 y는 이제 연속적인 값을 갖게 될 것입니다. 이때 y가 분류 문제저럼 이산적인 경우는 크게 문제가 되지 않지만, 연속적인 경우에서 $p(y \mid x)$를 이런 식으로만 표현할 경우에는 제약이 발생합니다.

간단한 선형회귀 문제를 풀어봅시다. 어떤 입력 x에 대해서 대응하는 y값들이 존재할 것이고, 이들은 아마 연속적인 실수 값들일 것입니다. Mean squared error로 선형회귀를 하게 되면 아마 우리는 주어진 점들을 가장 잘 설명하는 직선을 얻을 수 있을 것입니다.

image

데이터들이 갖는 원래의 함수를 $f(x)$라고 하고, 우리가 예측한 직선을 $\hat{f}(x)$라고 한다면 $f(x) = \hat{f}(x) + \epsilon$이라고 할 수 있습니다. 모든 점들에 완전히 적합하는 직선은 존재할 수가 없고 필연적으로 오차가 발생할 수 밖에 없습니다. 그리고 이러한 오차는 보통 정규분포를 갖는다고 가정합니다. 이를 다시 다른 식으로 표현하면 $y \mid x \sim N(w^Tx, \sigma^2)$라고 생각할 수 있습니다. $w$는 각각의 가중치를 의미하구요. 조금 더 복잡한 회귀 문제를 파이토치를 통해 풀어봅시다.

Example

\[y_{true}(x) = 7\sin{(0.75x)} + 0.5x + \epsilon\]

의 식을 갖는 함수를 만들어봅시다. 이때 입실론은 어떤 무작위의 노이즈입니다. 우선 쭉 사용할 패키지들을 불러옵시다.

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

그러면 샘플을 뽑아내는 함수를 정의해봅시다.

def generate_data(n_samples):
    epsilon = np.random.normal(size=(n_samples))
    x_data = np.random.uniform(-10.5, 10.5, n_samples)
    y_data = 7*np.sin(0.75*x_data) + 0.5*x_data + epsilon
    return x_data, y_data

n_samples = 1000
x_data, y_data = generate_data(n_samples)

입실론은 정규분포를 따르도록 뽑아냅니다. x는 -10.5 ~ 10.5 사이의 값을 갖도록 했고, y는 해당하는 x값에 대해서 위의 함수를 따르도록 구했습니다. 이들을 갖고 그림을 찍어보면…

plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.2)
plt.show()

image

요로코롬 생긴 곡선을 얻을 수 있습니다. 자 이제 간단한 모델을 만들어 이 함수를 근사해보도록 합시다.

n_input = 1
n_hidden = 20
n_output = 1

network = nn.Sequential(nn.Linear(n_input, n_hidden),
                        nn.Tanh(),
                        nn.Linear(n_hidden, n_output))
loss_fn = nn.MSELoss()

매우 간단한 모델입니다. x는 1차원의 값이므로 이를 확장할 수 있게 1->20의 선형 레이어를 갖고, 이후 활성화함수로 tanh를 거치도록 합니다. 이후 다시 선형 레이어를 넣어 마지막에 다시 한 개의 값을 뽑아내도록 했습니다. 이때 실수를 만들어내는 회귀 문제이므로 마지막 레이어에는 별도의 활성화함수를 추가할 필요가 없습니다. 이를 MSE를 최소화하도록 학습시켜 봅시다. 이때 주의해야하는 점은 numpy array를 pytorch가 사용할 수 있는 tensor로 바꿔줘야합니다. 또한 numpy의 기본 형태인 np.float64를 pytorch의 기본형인 np.float32로 바꿔줘야합니다.

x_tensor = torch.from_numpy(np.float32(x_data).reshape(n_samples, n_input))
y_tensor = torch.from_numpy(np.float32(y_data).reshape(n_samples, n_input))
x_variable = Variable(x_tensor)
y_variable = Variable(y_tensor, requires_grad=False)

이때 1000개의 샘플을 한번에 처리하기 위해서, 형태를 [1000, 1]로 바꿔줍니다. 자 이제 학습을 시켜 봅시다. 기본적인 동작은 이전의 네트워크 구현과 동일한 방식입니다. 네트워크에 x를 넣어 순전파를 시킨다음, 이에 대응하는 loss를 구하고, 이를 역전파시키고, 파라미터를 갱신하는 것이죠. 이를 코드로 표현하면 아래와 같습니다.

def train():
    for epoch in range(3000):
        y_pred = network(x_variable)
        loss = loss_fn(y_pred, y_variable)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 300 == 0:
            print(epoch, loss.data[0])

train()

그 다음에는 학습 데이터외의 x에 대해서 모델이 어떻게 예측하고 있는지를 알아봅시다. 똑같이 x를 샘플링한다음 이를 네트워크에 넣어 y값들을 얻어낸다음, 이를 다시 그래프로 표현하면 됩니다.

x_test_data = np.linspace(-10, 10, n_samples)

x_test_tensor = torch.from_numpy(np.float32(x_test_data).reshape(n_samples, n_input))
x_test_variable = Variable(x_test_tensor)
y_test_variable = network(x_test_variable)

y_test_data = y_test_variable.data.numpy()

plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.2)
plt.scatter(x_test_data, y_test_data, alpha=0.2)
plt.show()

image

보면 네트워크가 이 함수를 아주 잘 표현하고 있다는 것을 알 수 있습니다. 이는 Universal approximation theorem이라는 이론과 관련되어 있습니다. 이론 상으로는 은닉층 하나만으로도 Multi-Layer Perceptron은 임의의 연속함수를 근사할 수 있습니다. 물론 깊이가 1이라는거지, 너비는 무한정 늘어날 수도 있습니다… 여튼 이런 one-to-one인 경우와 many-to-one, 곧 하나 이상의 x와 하나의 출력 y를 가지는 경우에는 딥러닝이 이를 잘 표현할 수 있음을 확인할 수 있습니다.

문제는 여러 개의 출력을 가질 수 있는 경우입니다. 지금까지는 한 개의 mode를 갖는 정규분포에서 회귀를 했다고 생각할 수 있습니다. 하지만 만약 여러 개의 mode로 표현되는 정규분포에서 전과 똑같이 네트워크를 학습시킨다면 어떻게 될까요? 곧 multimodal regression을 기존의 네트워크를 가지고 해보겠다는 겁니다. 아까의 예제 코드를 그대로 활용해봅시다.

plt.figure(figsize=(8, 8))
plt.scatter(y_data, x_data, alpha=0.2)
plt.show()

image

그냥 간단하게 y와 x를 바꿨습니다. 중요한 점은 이제 어떤 입력 x가 주어졌을 떄 여러 개의 y가 가능하다는 겁니다. 곧 이전의 예시처럼 하나의 정규분포로 표현되는 상황을 넘어, 여러 개의 정규분포를 갖는 상황이 된거죠. 이런 상황에서 기존의 네트워크를 학습시키고 결과를 확인해봅시다.

x_variable.data = y_tensor
y_variable.data = x_tensor

train()

x_test_data = np.linspace(-15, 15, n_samples)
x_test_tensor = torch.from_numpy(np.float32(x_test_data).reshape(n_samples, n_input))
x_test_variable.data = x_test_tensor

y_test_variable = network(x_test_variable)

# move from torch back to numpy
y_test_data = y_test_variable.data.numpy()

# plot the original data and the test data
plt.figure(figsize=(8, 8))
plt.scatter(y_data, x_data, alpha=0.2)
plt.scatter(x_test_data, y_test_data, alpha=0.2)
plt.show()

image

보시면 아주 이상한 선이 그려짐을 확인할 수 있습니다. 이는 기본적으로 우리가 MSE를 최소화하도록 학습을 시키고 각 입력에 대해 하나의 출력만 가능했기 때문입니다. 이러한 문제를 해결하기 위해서 제안된 것이 Mixture Density Network, MDN입니다.

MDN

MDN은 Christopher Bishop이 제안한 구조입니다. 하나의 입력이 주어졌을 때 여러 개의 결과를 만들어낼 수 있는 방법이죠. 곧 같은 x에 대해서 다른 분포를 따르는 y에서 $p(y \mid x)$를 예측하는 것입니다. 정말로 여러 분포가 가능하지만, 여기서는 정규분포만을 가정하고 접근해봅시다. 이를 식으로 표현하면 $p(y \mid x) = \sum_{i=1}^n p(c = i \mid x)N(y; \mu^i, \sigma^i)$라고 할 수 있습니다. n개의 정규분포를 가정을 하고, 각 분포에서 y가 나올 확률을 이 분포에 속할 확률과 곱해서 결과를 예측을 하는 것입니다. 이렇게 $p(y\mid x)$를 만들어낸다음에는 샘플링을 통해 최종 예측을 해줍니다.

결국 이 네트워크에서 만들어내야하는 것은 각 정규분포 n개에서 세 가지 값입니다. $p(c = i \mid x)$, $\mu^i$, $\sigma^i$말이죠. 우선 $p(c = i \mid x)$의 경우에는 모두 다 더해서 1이 되어야한다는 제약이 있으므로 softmax를 사용해 이를 normalize해줍니다. $\mu^i$의 경우에는 특별한 제약이 없지만 $\sigma^i$의 경우에는 양수가 되야한다는 제약이 존재합니다. 학습 과정에서는 더 이상 MSE를 사용할 수가 없습니다. 따라서 교차엔트로피를 사용해 이를 최소화하도록 해줍니다. 교차엔트로피 식은 다음과 같습니다.

\[E = -\log{ \sum_{i=1}^m p(c = i \mid x)N(y; \mu^i, \sigma^i)}\]

전체적인 과정을 먼저 설명하고 구현으로 넘어갑시다. 우선 입력으로부터 20차원의 값을 만들어냅니다. 이렇게 만든 20차원의 값으로 필요한 파라미터들 - $p(c = i \mid x)$, $\mu^i$, $\sigma^i$를 예측할 겁니다. 이때 이렇게 만들어낸 p의 경우에는 확률의 정의를 만족하도록 소프트맥스에 넣어줄 겁니다. 자 네트워크를 정의합시다.

class MDN(nn.Module):
    def __init__(self, n_hidden, n_gaussians):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)  

    def forward(self, x):
        z_h = self.z_h(x)
        pi = nn.functional.softmax(self.z_pi(z_h), -1)
        sigma = torch.exp(self.z_sigma(z_h))
        mu = self.z_mu(z_h)
        return pi, sigma, mu

$z_h$를 먼저 만들어내고, 이 값을 가지고 $\mu, \sigma, p$를 만들어낼 겁니다. 위에 써놨듯, p는 확률의 정의를 만족해야 하므로 소프트맥스에 넣어줍니다. 또한 $\sigma$의 경우에는 양의 값을 가져야하므로 $e^x$를 거치게 해줍시다.

image

Loss의 경우에는 교차엔트로피를 사용할 겁니다. 그 전에 우선 $\mu^i$와 $\sigma^i$에서 주어진 y가 나올 확률을 뽑아내는 함수를 만들어봅시다.

\[N(\mu, \sigma)(x) = \frac{1}{\sigma \sqrt{2\pi}} \exp (-\frac{(x-\mu)^2}{2\sigma^2})\]

를 구현합니다.

oneDivSqrtTwoPI = 1.0 / np.sqrt(2.0*np.pi)
def gaussian_distribution(y, mu, sigma):
    result = (y.expand_as(mu) - mu) * torch.reciprocal(sigma)
    result = -0.5 * (result * result)
    return (torch.exp(result) * torch.reciprocal(sigma)) * oneDivSqrtTwoPI

reciprocal의 경우에는 입력의 역수를 구해주는 함수입니다. 자 이를 가지고 loss를 구현합시다.

def mdn_loss_fn(pi, sigma, mu, y):
    result = gaussian_distribution(y, mu, sigma) * pi
    result = torch.sum(result, dim=1)
    result = -torch.log(result)
    return torch.mean(result)

각 분포로부터 y가 나올 확률과 그 분포에 대응할 확률을 곱하고, 이들을 다 더한다음 로그와 평균을 취해주면 됩니다. 여기서는 20개의 hidden layer unit과 5개의 정규분포를 사용해 학습을 시켜보겠습니다. optimizer로는 Adam을 사용합니다.

network = MDN(n_hidden=20, n_gaussians=5)
optimizer = torch.optim.Adam(network.parameters())

아까처럼 미리 텐서를 만들어두고 학습을 진행합시다.

mdn_x_data = y_data
mdn_y_data = x_data

mdn_x_tensor = y_tensor
mdn_y_tensor = x_tensor

x_variable = Variable(mdn_x_tensor)
y_variable = Variable(mdn_y_tensor, requires_grad=False)

자 이제 학습을 진행할 시간입니다.

def train_mdn():
    for epoch in range(10000):
        pi_variable, sigma_variable, mu_variable = network(x_variable)
        loss = mdn_loss_fn(pi_variable, sigma_variable, mu_variable, y_variable)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(epoch, loss.data[0])

train_mdn()

학습이 끝난 다음에는 입력의 변화에 따라 각 분포의 평균과 분산이 어떻게 달라지는지 확인할 수 있습니다.

pi_variable, sigma_variable, mu_variable = network(x_test_variable)

pi_data = pi_variable.data.numpy()
sigma_data = sigma_variable.data.numpy()
mu_data = mu_variable.data.numpy()

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(8,8))
ax1.plot(x_test_data, pi_data)
ax1.set_title('$p(c = i | x)$')
ax2.plot(x_test_data, sigma_data)
ax2.set_title('$\sigma$')
ax3.plot(x_test_data, mu_data)
ax3.set_title('$\mu$')
plt.xlim([-15,15])
plt.show()

image

혹은 $\mu \pm \sigma$ 영역을 강조해서 영역을 그려볼 수도 있죠.

plt.figure(figsize=(8, 8), facecolor='white')
for mu_k, sigma_k in zip(mu_data.T, sigma_data.T):
    plt.plot(x_test_data, mu_k)
    plt.fill_between(x_test_data, mu_k-sigma_k, mu_k+sigma_k, alpha=0.1)
plt.scatter(mdn_x_data, mdn_y_data, marker='.', lw=0, alpha=0.2, c='black')
plt.xlim([-10,10])
plt.ylim([-10,10])
plt.show()

image

보면 한 x에 따라 여러개의 y가 가능할 수도 있음을 확인할 수도 있습니다. 이들을 선택할 확률은 $p(c=i \mid x)$에 의해서 결정되는 것입니다. 더 많은 정규분포를 사용해서 loss를 줄일 수도 있겠지만, 결과를 해석하기는 더 어려워집니다.

학습시킨 네트워크에서 결과를 얻고 싶다면 특정한 정규분포를 하나 고르고 그로부터 값을 뽑아내야 합니다. 이를 위해서는 Gumbel softmax sampling을 사용하면 된다는데, 이 사이트를 참조하세요

def gumbel_sample(x, axis=1):
    z = np.random.gumbel(loc=0, scale=1, size=x.shape)
    return (np.log(x) + z).argmax(axis=axis)

k = gumbel_sample(pi_data)

이제 우리는 각 x에 대해서 어떤 정규분포를 선택해야되는지를 알았으니 각각의 평균과 분산을 이용해 이를 샘플링하기만 하면 됩니다.

indices = (np.arange(n_samples), k)
rn = np.random.randn(n_samples)
sampled = rn * sigma_data[indices] + mu_data[indices]

rn의 경우에는 무작위의 노이즈이며 표준정규분포를 따르므로 이에 $\sigma$를 곱하고 $\mu$를 더해주기만하면 원래의 정규분포를 얻을 수 있습니다. 이렇게 해서 최종 결과물을 얻어보면

plt.figure(figsize=(8, 8))
plt.scatter(mdn_x_data, mdn_y_data, alpha=0.2)
plt.scatter(x_test_data, sampled, alpha=0.2, color='red')
plt.show()

image

와 같습니다.

정리

기존의 네트워크의 경우에는 하나의 입력에 대해 여러 개의 경우가 가능한 경우에 효과적으로 대응할 수 없었습니다. MDN은 이러한 문제를 해결하고자 등장했습니다. 여러 개의 정규분포(혹은 다른 분포)의 평균과 분산을 예측하고, 각 정규분포에 속할 확률을 통해 이를 효과적으로 근사할 수 있게 한 것이죠.

Reference

Bishop의 원래 논문

Deep learning, Ian Goodfellow

가져온 코드