티스토리 뷰

Deep Learning

Building Basic GAN, Week3 - Lecture

seoyoung02 2024. 3. 24. 22:48

※ Coursera의 Build Basic Generative Adversarial Networks (GANs) 강의를 듣고 작성한 글입니다.

 

Week3에서는 GAN의 문제점과 새로운 loss에 대해 알아봅니다.

 

Mode Collapse

Mode

분포에서 가장 높은 부분을 mode라고 합니다. 분포 내에서 여러개의 mode를 가질 수 있고, 일반적인 데이터셋은 대부분 여러개의 mode를 가집니다.

MNIST로 GAN을 학습한다고 가정하겠습니다. 0~9의 각 숫자마다 mode가 존재하게 됩니다. 이때 generator에서 여러 숫자를 생성해 냈다고 해봅시다. 그리고 discriminator에서 판단한 결과 1과 7을 제외한 모든 숫자들이 fake로 판단되었습니다. 그러면 generator는 다른 숫자를 잘 생성하려고 하기보다는 real로 판단한 1과 7을 생성하면 된다고 생각할 수 있습니다. 그리고 discriminator는 다시 1과 7에 대해서 real/fake을 판단합니다. 그러면 discriminator가 local minima에 빠지게 되고, 이제는 1에 대해서만 real로 판단합니다. 그럼 다시 generator는 1과 유사하게 생성하게 됩니다. 이러한 과정을 거쳐 single mode만 남게 되는 문제가 발생하고, 이를 mode collapse라고 합니다.

BCE Loss의 문제

Generator는 이미지와 같은 복잡한 output을 가지는 반면, discriminator는 real/fake만 구분하는 단순한 output을 가집니다. 따라서 discriminator를 학습하기 더 쉽습니다.

Discriminator는 생성된 이미지의 분포와 실제 데이터의 분포가 멀어지게 합니다.(서로 일치해야 실제에 가까운 이미지를 생성한다고 할 수 있습니다.) 분포가 멀어질 수록 오차는 증가하고 오차의 기울기는 감소하여 generator에 전달할 정보가 감소하는 vanishing gradient 문제가 존재합니다.

Earch Mover's Distance

위의 문제를 해결하기 위해 사용되는 방법입니다. Earth mover's distance는 두 분포를 동일하게 만들기 위해서 얼만큼 옮겨야하는지를 계산하는 것이므로 선형적인 형태를 가져서 gradient가 일정하다고 합니다. 

일종의 transportation 문제라고 합니다. 분포를 동일한 크기로 나누었다고 가정하고, 각 칸을 다른 분포의 위치로 이동하는 최단거리로 생각하는 것 같습니다. 이것을 어떻게 연속적으로 풀어내는지는 아직 잘 모르겠네요.. 

Wasserstein Loss

BCE Loss와 Wasserstein Loss

Wasserstein loss에 대해 한 마디로 간결하게 정리하고 지나갑니다. Earth mover's distance를 근사한 것. (자세한 설명은 이 링크를 참고하면 좋을 것 같습니다. 조만간 정리해볼게요.)

BCE Loss를 간단히 나타내면 위의 그림과 같습니다. 여기서 $log(d(x))$를 $c(x)$로, $1-log(d(g(z)))$를 $c(g(z))$로 바꾸면 wasserstein loss가 됩니다. $log(d(x))$는 실제 이미지를 넣었을 때(정답이 1일 때) gt와의 cross entropy라면, $c(x)$는 실제 이미지를 넣었을 때 gt 분포와의 거리입니다. $c(g(z))$는 GAN에서 생성된 이미지와 gt 분포와의 거리겠죠.

이 방식은 real/fake를 판단하는 것이 아니라 거리를 계산해서 discriminate loss보다는 critic loss라고 부른다고 합니다. 0과 1사이의 값을 가지는 것이 아닌 제한이 없는 값을 가지게 됩니다. 따라서 mode collaps와 vanishing gradient가 발생하지 않습니다.

W-loss의 조건

Critic의 네트워크는 특정 조건을 만족해야 합니다. 네트워크가 1-Lipschitz(1-L) continuous해야 한다고 해요. 이 말은 네트워크의 모든 점의 gradient가 1 이하여야 한다는 것입니다. 위의 이미지에서 초록색 영역 안에 그래프가 존재해야 하는 것입니다. 수식으로는 $||\nabla f(x)||_{2} < 1$

이 조건을 만족하게 하는 방법은 weight clipping과 gradient penalty가 있습니다.

Weight clipping은 이름 그대로 weight를 제한하는 것 입니다. Gradient descent로 weight를 업데이트하고 나서 정해진 구간을 넘어가는 값을 강제로 구간에 맞춰주는 것 입니다. 이 방식은 weight가 가지는 값의 다양성이 낮아져 성능 향상이 어려울 수 있다는 단점이 있습니다.

Gradient penalty는 loss function에 regularization을 추가하는 것 입니다.

Gradient penalty's regularization term

$\hat{x}$는 real 이미지와 생성된 이미지를 임의의 비율로 합쳐놓은 것 입니다. 이 위치에서 critic의 기울기가 1이 넘지 않아야 한다는 제한 사항을 넣은 것 입니다. 제곱의 형태로 넣음으로써 1을 더 많이 초과할 수록 더 큰 값을 가져 regularizaiton 효과를 줍니다. 최종적으로 아래와 같은 식이 됩니다.

$$\min_{g} \max_{c} \mathbb{E}(c(x))-\mathbb{E}(c(g(z)))+\lambda \mathbb{E}(||\nabla c(\hat{x})||_{2}-1)^{2}$$

댓글
최근에 올라온 글
Total
Today
Yesterday
최근에 달린 댓글
링크
공지사항
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함