對抗生成網路概述

林嶔 (Lin, Chin)

Lesson 15

原理簡介(1)

F15_1

– 從數學上來說,就是形成一個預測函數,而該函數的目標是做隨機亂數與新的物件的映射:

F15_2

原理簡介(2)

F15_3

– 但其實這兩個網路是不可能合在一起one-stage訓練的,你看得出為什麼嗎?

實現一個手寫數字產生器(1)

library(mxnet)

# Custom MXNet data iterator for GAN training.
#
# Wraps mx.io.CSVIter over a CSV in which, per column, the first value is the
# digit label and the remaining 784 values are the pixels of a 28x28 image.
# Each value() call returns a list of four NDArrays:
#   noise   - (1, 1, 10, batch)  standard-normal input for the generator
#   img     - (28, 28, 1, batch) real images rescaled to [0, 1]
#   label.0 - (1, 1, 1, batch)   all-zero ("fake") targets for the discriminator
#   label.1 - (1, 1, 1, batch)   all-one ("real") targets for the discriminator
my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    # Store the configuration so the declared fields are actually
                                    # populated (the original left them unset).
                                    .self$data.csv <- data.csv
                                    .self$data.shape <- data.shape
                                    .self$batch.size <- batch.size
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    # Drop the label row; the remaining 784 rows are pixel values.
                                    val.x <- val[-1, ]
                                    batch_size <- ncol(val.x)
                                    # Rescale to [0, 1] so real images match the generator's
                                    # sigmoid output range.
                                    val.x <- val.x / 255
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    # 10-dimensional Gaussian noise vector per sample.
                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    val.x <- mx.nd.array(val.x)
                                    val.y.0 <- array(rep(0, batch_size), dim = c(1, 1, 1, batch_size))
                                    val.y.0 <- mx.nd.array(val.y.0)
                                    val.y.1 <- array(rep(1, batch_size), dim = c(1, 1, 1, batch_size))
                                    val.y.1 <- mx.nd.array(val.y.1)
                                    list(noise = rand, img = val.x, label.0 = val.y.0, label.1 = val.y.1)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize = function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)

實現一個手寫數字產生器(2)

– 首先定義Generator:

# Generator: maps (1, 1, 10, batch) noise to a (28, 28, 1, batch) image.
# Deconvolution output side length: (in - 1) * stride - 2 * pad + kernel.
gen_data <- mx.symbol.Variable('data')

# 1x1 -> 4x4
gen_deconv1 <- mx.symbol.Deconvolution(data = gen_data, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')

# 4x4 -> 7x7
gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')

# 7x7 -> 14x14
gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')

# 14x14 -> 28x28; sigmoid keeps pixel values in [0, 1], matching the
# rescaled real images produced by the iterator.
gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')

– 接著定義Discriminator:

# Discriminator: maps a (28, 28, 1, batch) image to a (1, 1, 1, batch)
# probability that the image is real.
# Convolution output side length: floor((in + 2*pad - kernel) / stride) + 1.
dis_img <- mx.symbol.Variable('img')
dis_label <- mx.symbol.Variable('label')

# 28x28 -> 26x26 (conv) -> 13x13 (avg pool)
dis_conv1 <- mx.symbol.Convolution(data = dis_img, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.25, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

# 13x13 -> 6x6
dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.25, name = "dis_relu2")

# 6x6 -> 4x4
dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.25, name = "dis_relu3")

# 4x4 -> 1x1
dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.25, name = "dis_relu4")

# 1x1 head: single-channel logit, sigmoid -> probability of "real".
dis_conv5 <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_conv5')
dis_pred <- mx.symbol.sigmoid(data = dis_conv5, name = 'dis_pred')

– 我們再來定義Loss function,只有Discriminator有Loss function:

# Binary cross-entropy loss for the discriminator; eps guards against log(0).
# Only the discriminator has an explicit loss symbol: the generator is trained
# indirectly through the gradient of this loss w.r.t. the 'img' input.
eps <- 1e-8
ce_loss_pos <-  mx.symbol.broadcast_mul(mx.symbol.log(dis_pred + eps), dis_label)
ce_loss_neg <-  mx.symbol.broadcast_mul(mx.symbol.log(1 - dis_pred + eps), 1 - dis_label)
ce_loss_mean <- 0 - mx.symbol.mean(ce_loss_pos + ce_loss_neg)
ce_loss <- mx.symbol.MakeLoss(ce_loss_mean, name = 'ce_loss')

實現一個手寫數字產生器(3)

# Optimizers: Adam with beta1 = 0.5 for both networks (the DCGAN-style setting).
gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
# Bind executors with fixed batch size 32; grad.req = "write" overwrites
# gradients on every backward pass (no accumulation across batches).
gen_executor <- mx.simple.bind(symbol = gen_pred,
                               data = c(1, 1, 10, 32),
                               ctx = mx.cpu(), grad.req = "write")

dis_executor <- mx.simple.bind(symbol = ce_loss,
                               img = c(28, 28, 1, 32), label = c(1, 1, 1, 32),
                               ctx = mx.cpu(), grad.req = "write")
# Initial parameters
# NOTE(review): mxnet::: accesses private package internals; these helpers are
# not part of the exported API and may change between mxnet versions.

mx.set.seed(0)

gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
                                        input.shape = list(data = c(1, 1, 10, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

dis_arg <- mxnet:::mx.model.init.params(symbol = ce_loss,
                                        input.shape = list(img = c(28, 28, 1, 32), label = c(1, 1, 1, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

# Copy the initialized weights/aux states into the bound executors, then build
# stateful updaters that apply Adam steps to each executor's weights.

mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)

實現一個手寫數字產生器(4)

# --- Single-batch walkthrough of one GAN update (same steps as the loop below) ---

# Generate data

my_iter$reset()
my_iter$iter.next()
## [1] TRUE
my_values <- my_iter$value()

# Generator (forward): turn a batch of noise into fake images.
    
mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']]), match.name = TRUE)
mx.exec.forward(gen_executor, is.train = TRUE)
gen_pred_output <- gen_executor$ref.outputs[[1]]

# Discriminator (fake): train D to call generated images 0 ("fake").
    
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.0']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)

# Discriminator (real): train D to call real images 1 ("real").
    
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], label = my_values[['label.1']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)

# Generator (backward): feed fakes with label 1, backprop through D, and use
# the gradient w.r.t. 'img' as the output gradient for the generator.
# D's own weight gradients from this pass are discarded (overwritten next batch).

mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.1']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
img_grads <- dis_executor$ref.grad.arrays[['img']]
mx.exec.backward(gen_executor, out_grads = img_grads)
gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
# NOTE(review): imager appears unused below — rasterImage/as.raster are base
# graphics; confirm before removing the dependency.
library(imager)

# Plot the first 9 generated digits in a 3x3 grid.
par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:9) {
  img <- as.array(gen_pred_output)[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  # t() transposes so the (width, height) array displays upright — TODO confirm orientation.
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

實現一個手寫數字產生器(5)

# Training loop: for each batch, (1) update the discriminator on fake images
# (target 0), (2) update it on real images (target 1), then (3) update the
# generator by backpropagating the discriminator's input ('img') gradient
# when fakes are presented with target 1 (non-saturating generator loss).
# Generated samples are plotted every 100 batches; per-epoch sample sheets
# and model checkpoints are written to result/ and model/.
set.seed(0)
n.epoch <- 20
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in seq_len(n.epoch)) {
  
  current_batch <- 0
  my_iter$reset()
  
  while (my_iter$iter.next()) {
    
    my_values <- my_iter$value()
    
    # Generator (forward): turn noise into a batch of fake images.
    
    mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']]), match.name = TRUE)
    mx.exec.forward(gen_executor, is.train = TRUE)
    gen_pred_output <- gen_executor$ref.outputs[[1]]
    
    # Discriminator (fake): train D to label generated images as 0.
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.0']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Discriminator (real): train D to label real images as 1.
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], label = my_values[['label.1']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Generator (backward): present fakes with label 1, take the gradient
    # w.r.t. the 'img' input, and push it back through the generator.
    # D's weight gradients from this pass are intentionally not applied.
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.1']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    img_grads <- dis_executor$ref.grad.arrays[['img']]
    mx.exec.backward(gen_executor, out_grads = img_grads)
    gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
    
    logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    if (current_batch %% 100 == 0) {
      
      # Show current images
      
      par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
      for (i in seq_len(9)) {
        img <- as.array(gen_pred_output)[,,,i]
        plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
        rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
      }
      
      # Show loss
      
      message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  # Save a 3x3 sheet of samples from the last batch of the epoch.
  pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
  par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
  for (i in seq_len(9)) {
    img <- as.array(gen_pred_output)[,,,i]
    plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
    rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
  }
  dev.off()
  
  # Checkpoint both networks. Exclude bound data inputs from arg.params by
  # NAME rather than position: the original `[-1]` only dropped the first
  # input, which left 'label' inside the discriminator's saved parameters
  # even though dis_pred has no such argument.
  gen_model <- list()
  gen_model$symbol <- gen_pred
  gen_model$arg.params <- gen_executor$ref.arg.arrays[!names(gen_executor$ref.arg.arrays) %in% 'data']
  gen_model$aux.params <- gen_executor$ref.aux.arrays
  class(gen_model) <- "MXFeedForwardModel"
  
  dis_model <- list()
  dis_model$symbol <- dis_pred
  dis_model$arg.params <- dis_executor$ref.arg.arrays[!names(dis_executor$ref.arg.arrays) %in% c('img', 'label')]
  dis_model$aux.params <- dis_executor$ref.aux.arrays
  class(dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = gen_model, prefix = 'model/gen_v1', iteration = j)
  mx.model.save(model = dis_model, prefix = 'model/dis_v1', iteration = j)
  
}