Deep Learning Theory and Practice

林嶔 (Lin, Chin)

Lesson 11: Advanced Applications of Generative Adversarial Networks

Section 1: Fixing the Training Problems of GANs (1)

– Back in the early days of neural networks, training a very deep network required extensive trial-and-error tuning of hyperparameters, sometimes with the help of autoencoders, until Residual Learning put an end to all of that.

– Let us review our cross-entropy loss function and its derivative (the derivative is per sample, omitting the \(\frac{1}{n}\) average):

\[ \begin{align} CE(y, p) & = \frac{1}{n}\sum \limits_{i=1}^{n} -\left(y_{i} \cdot log(p_{i}) + (1-y_{i}) \cdot log(1-p_{i})\right) \\ \frac{\partial}{\partial p}CE(y, p) & = \frac{p-y}{p(1-p)} \end{align} \]

\[ \begin{align} S(x) & =\frac{1}{1+e^{-x}} \\ \frac{\partial}{\partial x}S(x) & = S(x)(1-S(x)) \end{align} \]
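– As a quick self-check, the following minimal sketch (our own addition, not part of the original lecture code) compares the analytic gradient \(\frac{p-y}{p(1-p)}\) with a finite-difference approximation for a single sample:

CE <- function(y, p) {-(y * log(p) + (1 - y) * log(1 - p))}

p <- 0.8; y <- 1; eps <- 1e-6
grad_numeric <- (CE(y, p + eps) - CE(y, p - eps)) / (2 * eps)
grad_analytic <- (p - y) / (p * (1 - p))
c(numeric = grad_numeric, analytic = grad_analytic) # both approximately -1.25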

Section 1: Fixing the Training Problems of GANs (2)

– But what about the Generator? It now needs large updates to try to fool the Discriminator again, yet at this point its gradient will be...

\[ \begin{align} \lim_{p \rightarrow 1} CE(0, p) & = - log(1-p) \\ \frac{\partial}{\partial p} \lim_{p \rightarrow 1} CE(0, p) & = \frac{p}{p(1-p)} \end{align} \]

\[ \begin{align} \frac{\partial}{\partial x} \lim_{S(x) \rightarrow 1} CE(0, p) & = \frac{S(x)^2(1-S(x))}{S(x)(1-S(x))} = S(x) \end{align} \]

– Moreover, for the Discriminator (and for ordinary networks) the gradient gradually shrinks as \(p \rightarrow 1\), which acts somewhat like learning-rate decay; for the Generator, however, the effective learning rate is increasing at this point, and a larger learning rate does not necessarily mean faster convergence!
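– The contrast is easy to see numerically; this small sketch (our own addition) tracks both gradients with respect to the pre-sigmoid input \(x\) as \(p = S(x)\) approaches 1:

S <- function(x) {1 / (1 + exp(-x))}

x <- seq(0, 10, by = 2)
p <- S(x)
grad_dis <- p - 1 # d/dx CE(1, S(x)): shrinks toward 0, like learning-rate decay
grad_gen <- p     # d/dx CE(0, S(x)): grows toward 1, like learning-rate increase
round(data.frame(p = p, grad_dis = grad_dis, grad_gen = grad_gen), 4)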

Section 1: Fixing the Training Problems of GANs (3)

– The Wasserstein GAN (WGAN), which we implement in the next section, addresses these problems with two key ideas:

  1. Remove the sigmoid function, because it makes the numerical approximation break down in extreme situations.

  2. Choose a loss function that stays smooth and describes the current state of the competition, no matter whether the Discriminator or the Generator has the upper hand.

(Figure F02)

Note: smaller values are better!

Section 1: Fixing the Training Problems of GANs (4)

(Figure F01)

\[ \begin{align} loss(y, x) & = (1-y)x - yx \end{align} \]
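– To see how this loss behaves, here is a minimal sketch (our own addition) using the label convention adopted in the code below, where \(y = 0\) marks real images and \(y = 1\) marks fakes:

w_loss_fn <- function(y, x) {(1 - y) * x - y * x}

w_loss_fn(y = 0, x = 2.5) #  2.5: real images are pushed toward LOW critic scores
w_loss_fn(y = 1, x = 2.5) # -2.5: fake images are pushed toward HIGH critic scores

– The Generator is later updated with \(y = 0\) on its own fakes, so it learns to drive their critic scores down toward the "real" end.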

Section 2: Implementing WGAN (1)

library(data.table)

# MNIST.csv: one row per image; the first column is the digit label (0-9),
# followed by 784 (28 x 28) pixel values
DAT = fread("data/MNIST.csv", data.table = FALSE)
DAT = data.matrix(DAT)

# Split the data into 60% training / 40% testing

set.seed(0)
Train.sample = sample(1:nrow(DAT), nrow(DAT)*0.6, replace = FALSE)

Train.X = DAT[Train.sample,-1]
Train.Y = DAT[Train.sample,1]
Test.X = DAT[-Train.sample,-1]
Test.Y = DAT[-Train.sample,1]

fwrite(x = data.table(cbind(Train.Y, Train.X)),
       file = 'data/train_data.csv',
       col.names = FALSE, row.names = FALSE)

fwrite(x = data.table(cbind(Test.Y, Test.X)),
       file = 'data/test_data.csv',
       col.names = FALSE, row.names = FALSE)

sub_Train.DAT <- data.table(cbind(Train.Y, Train.X))[1:500,]

fwrite(x = sub_Train.DAT,
       file = 'data/sub_train_data.csv',
       col.names = FALSE, row.names = FALSE)
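A quick sanity check (our own addition) to confirm the split and the raw pixel range before moving on:

dim(Train.X)   # one row per training image, 784 pixel columns
range(Train.X) # raw pixel values in 0..255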

Section 2: Implementing WGAN (2)

library(imager)
library(magrittr)
library(mxnet)

my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    val.x <- val[-1,]
                                    batch_size <- ncol(val.x)
                                    val.x <- val.x / 255 # Important: scale pixels to [0, 1] to match the generator's sigmoid output
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    val.x <- mx.nd.array(val.x)
                                    
                                    # One-hot encode the true digit of each real image
                                    digit.real <- mx.nd.array(val[1,])
                                    digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
                                    digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))

                                    # Sample random digits as the condition for the fake images
                                    digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
                                    digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
                                    digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))

                                    # Gaussian noise input for the generator
                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    
                                    # Critic labels: 0 = real, 1 = fake. 'label.gen' is used for the
                                    # generator update, which wants its fakes scored as real (0).
                                    label.real <- mx.nd.array(array(0, dim = c(1, 1, 1, batch_size)))
                                    label.fake <- mx.nd.array(array(1, dim = c(1, 1, 1, batch_size)))
                                    label.gen <- mx.nd.array(array(0, dim = c(1, 1, 1, batch_size)))
                                    
                                    list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)
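As a quick test (our own addition), we can pull one batch and inspect the shapes of the returned NDArrays:

my_iter$reset()
my_iter$iter.next()
batch <- my_iter$value()
sapply(batch, function(x) {paste(dim(x), collapse = ' x ')})
# img: 28 x 28 x 1 x 32; noise and digit.*: 1 x 1 x 10 x 32; label.*: 1 x 1 x 1 x 32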

Section 2: Implementing WGAN (3)

# Generator: upsamples a 1 x 1 x 10 noise vector, concatenated with a
# one-hot digit code, into a 28 x 28 x 1 image

gen_data <- mx.symbol.Variable('data')
gen_digit <- mx.symbol.Variable('digit')

gen_concat <- mx.symbol.concat(data = list(gen_data, gen_digit), num.args = 2, dim = 1, name = "gen_concat")

gen_deconv1 <- mx.symbol.Deconvolution(data = gen_concat, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')

gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')

gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')

gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')

# Discriminator (critic): scores a 28 x 28 image conditioned on its digit code

dis_img <- mx.symbol.Variable('img')
dis_digit <- mx.symbol.Variable("digit")
dis_label <- mx.symbol.Variable('label')

dis_concat <- mx.symbol.broadcast_mul(lhs = dis_img, rhs = dis_digit, name = 'dis_concat')

dis_conv1 <- mx.symbol.Convolution(data = dis_concat, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.2, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.2, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.2, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.2, name = "dis_relu4")

# Linear critic score (1 x 1 x 1 per image): no sigmoid, as WGAN requires
dis_pred <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_pred')

# Wasserstein-style loss from the formula above: loss(y, x) = (1 - y) * x - y * x
w_loss_pos <- mx.symbol.broadcast_mul(dis_pred, dis_label)
w_loss_neg <- mx.symbol.broadcast_mul(dis_pred, 1 - dis_label)
w_loss_mean <- mx.symbol.mean(w_loss_neg - w_loss_pos)
w_loss <- mx.symbol.MakeLoss(w_loss_mean, name = 'w_loss')
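Before binding the executors, it can help to verify the architecture with shape inference; this check (our own addition) confirms that the generator produces 28 x 28 images and that the critic reduces each image to a single score:

gen_shapes <- mx.symbol.infer.shape(gen_pred, data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32))
gen_shapes$out.shapes # expect 28 x 28 x 1 x 32

dis_shapes <- mx.symbol.infer.shape(w_loss, img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32))
dis_shapes$out.shapes # expect a single scalar loss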

Section 2: Implementing WGAN (4)

gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
gen_executor <- mx.simple.bind(symbol = gen_pred,
                               data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32),
                               ctx = mx.gpu(), grad.req = "write")

dis_executor <- mx.simple.bind(symbol = w_loss,
                               img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32),
                               ctx = mx.gpu(), grad.req = "write")
# Initial parameters

mx.set.seed(0)

gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
                                        input.shape = list(data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.gpu())

dis_arg <- mxnet:::mx.model.init.params(symbol = w_loss,
                                        input.shape = list(img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.gpu())

# Update parameters

mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)

Section 2: Implementing WGAN (5)

set.seed(0)
n.epoch <- 20
w_limit <- 0.1 # clipping bound for the critic weights
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in 1:n.epoch) {
  
  current_batch <- 0
  my_iter$reset()
  
  while (my_iter$iter.next()) {
    
    my_values <- my_iter$value()
    
    # Generator (forward)
    
    mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']], digit = my_values[['digit.fake']]), match.name = TRUE)
    mx.exec.forward(gen_executor, is.train = TRUE)
    gen_pred_output <- gen_executor$ref.outputs[[1]]
    
    # Discriminator (fake)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.fake']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Discriminator (real)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], digit = my_values[['digit.real']], label = my_values[['label.real']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Weight clipping (only for discriminator)
    
    dis_weight_names <- grep('weight', names(dis_executor$ref.arg.arrays), value = TRUE)
    
    for (k in dis_weight_names) {
      
      # Clip each weight tensor to [-w_limit, w_limit]; the as.array() /
      # mx.nd.array() round-trip copies the weights to the CPU so that they
      # share a context with the scalar clipping bounds
      current_dis_weight <- dis_executor$ref.arg.arrays[[k]] %>% as.array()
      current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
        mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
        mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
        list()
      names(current_dis_weight_list) <- k
      mx.exec.update.arg.arrays(dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
      
    }
    
    # Generator (backward): score the fakes with label.gen (= 0, i.e. "real"),
    # then propagate the gradient at the critic's 'img' input back through
    # the generator
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.gen']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    img_grads <- dis_executor$ref.grad.arrays[['img']]
    mx.exec.backward(gen_executor, out_grads = img_grads)
    gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
    
    logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    if (current_batch %% 100 == 0) {
      
      # Show current images
      
      current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
      
      par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
      
      for (i in 1:9) {
        img <- as.array(gen_pred_output)[,,,i]
        plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
        rasterImage(as.raster(img), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
        text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
      }
      
      # Show loss
      
      message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  # Save sample images for this epoch (assumes the 'result' directory exists)
  pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
  
  current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
  
  par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
  
  for (i in 1:9) {
    img <- as.array(gen_pred_output)[,,,i]
    plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
    rasterImage(as.raster(img), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
    text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
  }
  
  dev.off()
  
  # Package both networks as MXFeedForwardModel objects, dropping the input
  # arrays by name: the original [-c(1:2)] indexing would have left the
  # 'label' input inside the discriminator's arg.params
  
  gen_model <- list()
  gen_model$symbol <- gen_pred
  gen_model$arg.params <- gen_executor$ref.arg.arrays[!names(gen_executor$ref.arg.arrays) %in% c('data', 'digit')]
  gen_model$aux.params <- gen_executor$ref.aux.arrays
  class(gen_model) <- "MXFeedForwardModel"
  
  dis_model <- list()
  dis_model$symbol <- dis_pred
  dis_model$arg.params <- dis_executor$ref.arg.arrays[!names(dis_executor$ref.arg.arrays) %in% c('img', 'digit', 'label')]
  dis_model$aux.params <- dis_executor$ref.aux.arrays
  class(dis_model) <- "MXFeedForwardModel"
  
  # Checkpoint each epoch (assumes the 'model' directory exists)
  mx.model.save(model = gen_model, prefix = 'model/cwgen_v1', iteration = j)
  mx.model.save(model = dis_model, prefix = 'model/cwdis_v1', iteration = j)
  
}
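After training, the saved generator can be reused on its own. The following is a minimal sketch (our own addition) that loads the final checkpoint and synthesizes nine conditioned digits; the prefix and iteration must match the mx.model.save() calls above:

gen_model <- mx.model.load('model/cwgen_v1', iteration = n.epoch)

test_executor <- mx.simple.bind(symbol = gen_model$symbol,
                                data = c(1, 1, 10, 9), digit = c(1, 1, 10, 9),
                                ctx = mx.cpu(), grad.req = "null")
mx.exec.update.arg.arrays(test_executor, gen_model$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(test_executor, gen_model$aux.params, match.name = TRUE)

noise <- mx.nd.array(array(rnorm(9 * 10), dim = c(1, 1, 10, 9)))
digits <- sample(0:9, size = 9, replace = TRUE)
onehot <- mx.nd.one.hot(indices = mx.nd.array(digits), depth = 10)
onehot <- mx.nd.reshape(data = onehot, shape = c(1, 1, -1, 9))

mx.exec.update.arg.arrays(test_executor, list(data = noise, digit = onehot), match.name = TRUE)
mx.exec.forward(test_executor, is.train = FALSE) # is.train = FALSE: BatchNorm uses its moving statistics
imgs <- as.array(test_executor$ref.outputs[[1]]) # 28 x 28 x 1 x 9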