Dual Discriminator Generative Adversarial Nets 리뷰/구현
D2GAN 리뷰와 구현
안녕하세요 오늘은 Dual Discriminator Generative Adversarial Nets(이하 D2GAN)라는 논문에 대해 간단히 정리해보고 이를 구현한 코드를 소개하려고 합니다.
우리가 흔히 생각하는 GAN은 Discriminator(이하 D)와 Generator(이하 G) 두개의 네트워크를 사용해 데이터를 만들어내게 됩니다. 그러나 이러한 방식의 구현은 몇 가지의 문제를 갖고 있는데요, 대표적으로 mode-collapse문제를 꼽을 수가 있습니다. 이 글 끝부분에 나온 그림을 보면 원형의 여러 mode를 가지는 데이터가 있다고 했을 때, 이들을 적절하게 따라하는 것이 아니라 하나의 mode에만 초점을 맞추고 데이터를 생성하는 모습을 볼 수 있습니다. 생각해보면 어찌됐든 한 mode에 초점을 맞추면 D가 판단에 어려움을 겪을 수 밖에 없다는 것을 떠올릴 수가 있게 됩니다.
여기서 소개하는 D2GAN의 경우에는 이러한 문제를 해결하기 위해 제안된 모델입니다. 기존의 GAN이 풀고자 했던 two player mini-max game을 three player로 확장했는데요, 그냥 세 개의 네트워크를 사용해 보다 균형있는 데이터를 만들어내고자 했다고 이해하시면 될 것 같습니다. 이러한 수식이 가능한 이유에 대한 수학적 증명도 논문에서는 보여주고 있지만 여기에서는 생략하도록 하겠습니다.
기존의 GAN
기존의 GAN에 대해서 간단하게 다시 설명해보죠.
G는 어떠한 저차원의 분포 - 보통 정규분포나 uniform distribution을 사용합니다 - 에서 임의로 뽑아낸 noise에서 데이터를 만들어내는 역할을 합니다. D는 G가 만들어낸 데이터와 진짜 데이터를 보고 “이것이 진짜다, 이것이 가짜다” 0과 1의 값으로 판단하는 역할을 수행합니다. 만약 G가 정말정말 잘 트레이닝되어서 진짜와 동일한 수준의 데이터를 만들어낸다면 D는 어떠한 데이터가 들어오든 “찍을” 수 밖에 없게 되죠. 즉 $P_G = P_{data}$라면 $D(x) = 0.5$가 될 것입니다.
이때 사용하는 two player game은 다음과 같은 식으로 표현할 수 있습니다.
\[\min_{G}\max_{D}{V(D,G)} = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log (1-D(G(z)))]\]D의 입장에서는 가짜를 가짜라고, 진짜를 진짜라고 판단해야하니 저 식을 최대화해야합니다. 반대로 G의 경우에는 D를 속여야하니, 가짜를 진짜처럼 인식하게 학습을 시켜야하는 것이죠. 물론 실제 구현에서는 $\log (1-D(G(z))$ 대신에 $\log (D(G(z)))$를 사용하긴 하지만요.
D2GAN
그러나 기존 GAN은 mode-collapse라는 문제를 겪게 됩니다. 이를 해결하기 위해서 여기선 새로운 구조를 제안합니다. 바로 Discriminator를 두 개 두는 것입니다. 이들 각각을 D1, D2라고 해봅시다. 기존 GAN과의 차이점은 이 두 D가 0~1사이의 확률을 결과값으로 출력하는 것이 아니라, 양의 실수 값들을 결과값으로 출력한다는 것입니다. 이때 당연히 D1, D2는 같은 역할을 수행하진 않습니다. 그렇게 하면 굳이 네트워크를 두 개가지고 있을 필요가 없죠.
D1은 진짜 데이터를 보고 진짜라고 판단하도록 학습시킵니다. 당연히 가짜 데이터를 보면 가짜라고 판단하게 되죠. D2는 반대로 가짜 데이터를 진짜라고 판단하도록 학습시킵니다. 진짜를 보면 가짜라고 판단을 하구요. 이를 기존 GAN에서처럼 three player game으로 나타내면 아래와 같아집니다.
\[\min_G \max_{D_1, D_2}{V(G, D_1, D_2)} = \alpha \times \mathbb{E}_{x\sim p_{data}}(\log(D_1(x))) + \mathbb{E}_{z \sim p_z}(-D_1(G(z))) + \mathbb{E}_{x \sim p_{data}}(-D_2(x))+ \beta \times \mathbb{E}_{z \sim p_z}(\log D_2(G(z)))\]$\alpha, \beta$는 0~1사이의 값인데요, 앞에서 말씀드렸듯이 D는 양의 실수 값 전체를 가질 수 있기 때문에 unbounded합니다. 따라서 이를 로그를 취해 어느정도 잡아주긴하지만, 보다 더 강하게 제약을 주어 학습을 안정화시켜줍니다.
Network의 구조는 전반적으로 DCGAN을 그대로 사용했습니다. 대신 32x32의 이미지에 대해 학습을 시켜줬다고 나와있어, G와 D들의 첫번째 레이어에서 커널의 크기를 반으로 줄여 32x32를 생성할 수 있게 하였습니다. 또한 D는 양의 실수 값을 출력하므로 기존의 sigmoid 대신 softplus라는 활성화 함수를 사용했습니다. 구현은 링크를 참조하시면 될 것 같습니다.