GAN cơ bản



Hello xin chào ae, hôm nay mình sẽ chia sẻ một chút hiểu biết của mình về mạng GAN, một kiểu mô hình Neural network mà mình rất thích



1. Giới thiệu về GAN
Đầu tiên thì giới thiệu qua một chút về mạng GAN là cái gì đã. GAN là viết tắt của generative adversarial networks. Đây là một mạng có khả năng sinh ra một dữ liệu mới dựa trên những giữ liệu đã có sẵn. Ví dụ như nó có thể sinh ra một người hoàn toàn không có thật ở thế giới mà mắt người thường nhìn qua sẽ ko thể nào phân biệt được (ko tính mất thể loại vinasoi nhé). Trong deeplearning thì GAN nó có vai trò quan trọng trong việc sinh ra dữ liệu mới phục vụ cho việc training model đạt được hiệu suất tốt hơn,...
                                                    nguồn https://arxiv.org/abs/1511.06434

2. GAN làm việc như thế nào ?
GAN nó gồm 2 mạng neural network là Generator (mình gọi tắt là G) và Discriminator (mình gọi tắt là D). Thằng Generator có nhiệm vụ là sinh ra dữ liệu mới (fake) còn thằng Discriminator có nhiệm vụ là phân biệt data đầu vào của nó là data thật hay là data fake được tạo ra bởi Generator và hai mạng này sẽ được trainning đồng thời sao cho có loss của cả 2 là thấp nhất. Tất nhiên mỗi thằng có chức năng riêng nên hàm loss của mỗi thằng cũng khác nhau một chút, mình sẽ thảo luận nó sau. 
Có thể hiểu nôm na thằng GAN nó hoạt động như là câu truyện cảnh sát đi bắt kẻ làm tiền giả vậy. Trong câu truyện này thì thằng Generator nó đóng vai trò như là kẻ làm tiền giả, kẻ làm tiền giả tất nhiên sẽ phải cố gắng làm sao tạo ra được đồng tiền giả mà nó giống với tiền thật mà qua mắt được cảnh sát, thì thằng Generator cũng vậy nó phải học làm sao để tạo ra được data fake (ví dụ như một bức ảnh của một người ko có thật nhưng lại rất giống với con người) để qua mắt được thằng Discriminator. Còn thằng Discriminator nó giống như là cảnh sát vậy, cố gắng phân biệt đồng tiền là giả hay thật để còn tóm cổ kẻ làm tiền giả kia. Và do kẻ làm tiền giả được học nên nó ngày càng làm ra tiền giả tinh vi nên cảnh sát cũng phải học lên ko là bị qua mặt ngay. Và cứ như thế là cảnh sát và kẻ làm tiền giả sẽ cùng nhau học đến khi khả năng thành công của cả hai (kẻ làm tiền giả qua mắt được cảnh sát, cảnh sát bắt được kẻ làm tiền giả) là 50% thì khi đó là mạng đã cân bằng rồi, có nghĩa tên làm tiền giả kia đã fake được đồng tiền giống tiền thật mà cảnh sát cũng ko chắc cú được là thật hay giả nữa thì là mạng cân bằng rồi. Chém gió câu chuyện trên cho dễ tưởng tượng, bây giờ chúng ta sẽ đi vào xem cụ thể thằng G và D nó như thế nào.

3. Generator.
Generator thì nó đơn giản chỉ là một mạng neural network với các lớp densen hoặc là mạng neural network với các lớp convolutional neural network (DCGAN). Trong bài này mình nói về DCGAN nên sẽ là các lớp convolutional neural network nhé tại để sinh ra ảnh phức tạp thì CNN nó tiện hơn còn mạng neural networdk thì cũng tương tự thôi ko có gì đặc biệt cả.

                                                nguồn https://arxiv.org/abs/1511.06434

Như bạn có thể thấy thì Generator nó được tạo ra từ các layer CNN đầu tiên là mình sẽ lấy một vector noise z 100 chiều bằng cách random ngẫu nhiên theo phân phối chuẩn. sau đó mình sẽ cho qua layer Dense để đưa nó về số chiều mới là 4*4*1024 rồi sau đó mình sẽ reshape nó về dạng tensor 3D có (c, h, w) = (1024, 4, 4). Tiếp theo mình sẽ sử dụng transpose convolution để upsample nó lên h, w có kích thước lớn hơn và giảm channel nó xuống thành (512, 8, 8) , tiếp theo mình lại transpose convolution nó lên (256, 16, 16) tiếp tục transpose convolution thành (128, 32, 32) và cuối cùng là giảm transpose convolution thành (3, 64, 64) và đây chính là bức ảnh fake G(z) mà mình sẽ dùng để đánh lừa thành Discriminator. Ở đầu ra của các layer thì mình sẽ sử dụng bactchnorm để scale lại đầu ra  và dùng activation là relu hoặc leakyrelu đều ok trừ thằng G(z) cuối cùng mình sẽ để activation là tanh. Nhiệm vụ của Generator là sẽ học làm sao để tạo ra cái ảnh fake đầu ra G(z) giống ảnh thực nhất có thể để đánh lừa được thằng Discriminator. Còn phần transpose convolution thì bạn nào học qua về segmentation chắc đề biết rồi nên mình sẽ ko nói ở bài này hoặc các bạn có thể xem ở link này.

4. Discrimitor.
Discrimitor thì nó ngược lại với generator, nếu Generator dùng tích chập chuyển vị để tăng kích thước của feature map thì Discrimitor sẽ dùng CNN để downsample các feature map xuống kích thước nhỏ hơn

Đầu tiên là tích chập để giảm kích thước của feature map xuống từ 3x64x64 thành 128x32x32 ⟶ 256x16x16 ⟶ 512x8x8⟶ 1024x4x4 rồi sẽ flatten cái tensor cuối cùng thành 1 vector có chiều 1024.4.4 rồi đưa qua lớp dense để chuyển nó về 1 đầu ra là 1 tại vì Discriminator nó chỉ có chức năng phân biệt thật hoặc fake. Data thật thì đầu ra sẽ là 1 còn data fake đầu ra sẽ là 0. Ở đầu ra của các layer thì ta sẽ sử dụng thêm batchnorm để scale ouput và activation là relu hoạc leakyrelue đều được trừ layer cuối cùng thì mình sẽ để activation là sigmoid vì nó chỉ phân biệt thật và giả thôi mà.

5. Loss function.
Vì mục đích của ta là sẽ train mạng Generator để tạo ra một bức ảnh fake thật nhất qua mặt được Discriminator nên loss của mỗi thằng sẽ khác nhau. Loss của Generator và Discriminator là binary crossentropy:
Nhìn công thức của hàm los này có vẻ hơi khó hiểu một tí nhưng đơn giản có thể diễn giải nó như sau:
gọi x là data thực, z là noise random, G(z) là data fake (ảnh fake).
B1: Random một noise z sau đó đưa nó vào Generator để tạo ra một ảnh fake là G(z). 
B2: G(z) và x sẽ được đưa vào Discriminator được ouput trong khoảng (0, 1) (do có sigmoid).
B3: Nếu đầu vào của Discriminator là x thì nó sẽ có label là 1 còn đầu vào là G(z) có label là 0 và binary loss sẽ được sử dụng để tính.
sau đó nó sẽ tối ưu các tham số của thằng Discriminator.
⟶ Lúc này ta có thể coi như là ô cảnh sát đã được học qua một lần để phân biệt tiền thật và tiền giả rồi. Bây giờ là đến giờ học của anh kẻ trộm (Generator)
B1: Random một noise z sau đó đưa nó vào Generator để tạo ra một ảnh fake là G(z). 
B2: G(z) sẽ được đưa vào Discriminator được ouput trong khoảng (0, 1) (do có sigmoid).
B3: Label của đầu ra lúc này sẽ là 1 do mục đích của mình là muốn tối ưu tham số của Generator để đánh lừa được Discriminator mà. Và loss Binary lại đk áp dụng.
Cứ lặp lại như vậy cho đến khi nào xác suất phân biệt của discriminator là 0.5 là mạng dừng lại.

5. Code:
Nói đôi khi nó cũng trìu tượng hơi khó hiểu nên mk sẽ để code ở đây cho các bạn đọc để hiểu. Mình sẽ sửa đổi một chút model cho tập data cifar10 và dùng activation relu mình cũng đã comment trong code nhé. 



from google.colab import drive
drive.mount('/content/gdrive')

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
import matplotlib.pyplot as plt

root_data = './data'
trainset = datasets.CIFAR10(root=root_data, train=True, transform=transforms.ToTensor(), download=True)
testset = datasets.CIFAR10(root=root_data, train=False, transform= transforms.ToTensor(), download=True )
dataloader = DataLoader(trainset, batch_size = 2048, shuffle=True)
print(len(dataloader.dataset))

# build Generator
class DCGAN_generator(nn.Module):
    def __init__(self, ngpu = 1):
        super(DCGAN_generator, self).__init__()
        self.ngpu = ngpu


        nz = 100 # noise dimension
        ngf = 64 # number of features map on the first layer
        nc = 3 # number of channels

        self.main = nn.Sequential(
          # input is Z, going into a convolution
          nn.ConvTranspose2d(in_channels= nz, out_channels=ngf*4, kernel_size=4, stride=1, padding=0, bias=False),
          nn.BatchNorm2d(ngf * 4),
          nn.ReLU(True),
          # state size. (ngf*8) x 4 x 4
          nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ngf * 2),
          nn.ReLU(True),
          # state size. (ngf*4) x 8 x 8
          nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ngf),
          nn.ReLU(True),
          # state size. (ngf*2) x 16 x 16
          nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
          nn.Tanh()
          # state size. (nc) x 32 x 32
        )
    
    def forward(self, input):
        # 1x1x100 -> 6x6x256 -> 12x12x128 -> 24x24x64 -> 48x48x3
        output = self.main(input)
        return output

# build Discriminator
class DCGAN_discriminator(nn.Module):
    def __init__(self, ngpu = 1):
        super(DCGAN_discriminator, self).__init__()
        self.ngpu = ngpu

        ndf = 64
        nc = 3

        self.main = nn.Sequential(
          nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ndf),
          nn.LeakyReLU(0.2, inplace=True),

          nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ndf * 2),
          nn.LeakyReLU(0.2, inplace=True),
          # state size. (ndf*4) x 8 x 8
          nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ndf * 4),
          nn.LeakyReLU(0.2, inplace=True),
          # state size. (ndf*8) x 4 x 4
          nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
          nn.Sigmoid()
        )
    
    def forward(self, input):
        output = self.main(input)

        return output.view(-1, 1).squeeze(1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

netD = DCGAN_discriminator()
netG = DCGAN_generator()

netD.to(device)
netG.to(device)

optD = optim.Adam(netD.parameters(), betas=(0.5, 0.999), lr = 2e-4)
optG = optim.Adam(netG.parameters(), betas=(0.5, 0.999), lr = 2e-4)

criterion = nn.BCELoss()

def plot_images(image):
    bs, c , h, w= image.shape
    size_view = int(np.sqrt(bs))
    image = image.permute(0,2,3,1).contiguous()
    image = image.view(-1,32*size_view, 32, 3)
    image = image.permute(0,2,1,3).contiguous()
    image = image.view(-1, 32*size_view, 3)
    plt.imshow(image.transpose(0,1))
    plt.show()

def train(num_epoch):
  for epoch in range(num_epoch):
    for image, _ in dataloader:
      image = image.to(device)
      bs = image.shape[0]
      init_vector = torch.randn(bs,100,1,1).to(device)
      image_fake = netG(init_vector)
      true_label = torch.ones(bs).to(device)
      fake_label = torch.zeros(bs).to(device)

      optD.zero_grad()
      predict_true = netD(image)
      loss_true = criterion(predict_true, true_label)
      loss_true.backward(retain_graph=True)
      
      predict_fake = netD(image_fake)
      loss_fake = criterion(predict_fake, fake_label)
      loss_fake.backward(retain_graph=True)

      optD.step()

      optG.zero_grad()
      predict_fake_G = netD(image_fake)
      loss_g = criterion(predict_fake_G, true_label)
      loss_g.backward()
      optG.step()
    
    test = torch.randn(36,100,1,1).to(device)
    image_fake = netG(test)
    if epoch%2:
      torch.save(netG.state_dict(), f'{epoch}_netG.pth')
      torch.save(netD.state_dict(), f'{epoch}_netD.pth')
    print(epoch+1)
    plot_images(image_fake.detach().to('cpu'))

Tham khảo: 

Nhận xét

Bài đăng phổ biến