下載資料

– 與之前所有任務相同,這個任務的資料下載是需要經過申請的,請你找助教申請帳號。

讀取資料

library(data.table)

# Read the annotations, the test list and the submission template as plain
# data.frames (data.table = FALSE) so that base-R `[row, col]` indexing works
# the usual way in the rest of this script.
train_dat <- fread("train.csv", data.table = FALSE)
test_dat <- fread("test.csv", data.table = FALSE)
submit_dat <- fread("sample_submission.csv", data.table = FALSE)

熟悉資料型態

# First rows of the training annotations: one row per box, giving the class
# name, the box edges (values look normalised to [0, 1]) and the image id.
utils::head(train_dat, n = 6L)
##             obj_name  col_left col_right   row_bot     row_top img_id
## 1             Blasts 0.3877827 0.4932385 0.1347430 0.008709886  U0007
## 2            Myeloid 0.5005113 0.6077853 0.1269871 0.022282680  U0007
## 3           Monocyte 0.7023318 0.7841510 0.2181187 0.107597387  U0007
## 4           Monocyte 0.7641507 0.8568791 0.1754614 0.082390769  U0007
## 5           Monocyte 0.8896067 0.9732441 0.1386209 0.039733416  U0007
## 6 Unable to identify 0.8496063 0.9277890 0.1967900 0.103719446  U0007
library(OpenImageR)
library(imager)

# Draw text labels on the current plot, each on top of a filled rectangle.
#
# Arguments:
#   x, y        label anchor coordinates in user units. `y` defaults to `x`;
#               the shorter of the two is recycled to the longer length.
#   labels      character vector of labels.
#   col.text    text colour. col.bg / border.bg: box fill / border colours.
#   adj         justification as in text(): c(horizontal, vertical). A single
#               value sets the horizontal one and uses 0.5 vertically.
#   pos         1-4 (below / left / above / right) as in text(pos =); overrides
#               `adj`, and `offset` is then the gap in character widths.
#   padding     box padding in character widths, c(horizontal, vertical);
#               a single value is used for both.
#   cex, font   text size (relative to par("cex")) and font.
# Returns (invisibly) the box extents c(left, right, bottom, top), or a
# 4-column matrix of them when several labels are drawn.
boxtext <- function(x, y, labels = NA, col.text = NULL, col.bg = NA, 
                    border.bg = NA, adj = NULL, pos = NULL, offset = 0, 
                    padding = c(0.5, 0.5), cex = 1, font = graphics::par('font')){

  ## The character expansion factor to be used:
  theCex <- graphics::par('cex') * cex

  ## Is y provided:
  if (missing(y)) y <- x

  ## Recycle coords if necessary:
  lx <- length(x)
  ly <- length(y)
  if (lx > ly) {
    y <- rep(y, ceiling(lx / ly))[seq_len(lx)]
  } else if (ly > lx) {
    x <- rep(x, ceiling(ly / lx))[seq_len(ly)]
  }

  ## Width and height of the text, and width of one character (user units):
  textHeight <- graphics::strheight(labels, cex = theCex, font = font)
  textWidth <- graphics::strwidth(labels, cex = theCex, font = font)
  charWidth <- graphics::strwidth("e", cex = theCex, font = font)

  ## Is 'adj' of length 1 or 2? (The old test `length(adj == 1)` was always
  ## truthy, which silently discarded a user-supplied vertical adjustment.)
  if (is.null(adj)) {
    adj <- c(0.5, 0.5)
  } else if (length(adj) == 1) {
    adj <- c(adj[1], 0.5)
  }

  ## Is 'pos' specified? It overrides 'adj' and shifts the label by 'offset'.
  offsetVec <- c(0, 0)
  if (!is.null(pos)) {
    step <- offset * charWidth
    if (pos == 1) {
      adj <- c(0.5, 1)
      offsetVec <- c(0, -step)
    } else if (pos == 2) {
      adj <- c(1, 0.5)
      offsetVec <- c(-step, 0)
    } else if (pos == 3) {
      adj <- c(0.5, 0)
      offsetVec <- c(0, step)
    } else if (pos == 4) {
      adj <- c(0, 0.5)
      offsetVec <- c(step, 0)
    } else {
      stop('Invalid argument pos')
    }
  }

  ## Padding for boxes:
  if (length(padding) == 1) {
    padding <- c(padding[1], padding[1])
  }

  ## Midpoints for text:
  xMid <- x + (-adj[1] + 1/2) * textWidth + offsetVec[1]
  yMid <- y + (-adj[2] + 1/2) * textHeight + offsetVec[2]

  ## Box size = text size plus padding on both sides. The padding is added in
  ## the same direction as the text extent, so boxes stay larger than the text
  ## even on reversed axes (where strwidth()/strheight() may be negative).
  ## Vertical padding is measured in character widths, as in the horizontal case.
  padDirX <- ifelse(textWidth < 0, -1, 1)
  padDirY <- ifelse(textHeight < 0, -1, 1)
  rectWidth <- textWidth + padDirX * 2 * padding[1] * abs(charWidth)
  rectHeight <- textHeight + padDirY * 2 * padding[2] * abs(charWidth)
  graphics::rect(xleft = xMid - rectWidth/2, 
                 ybottom = yMid - rectHeight/2, 
                 xright = xMid + rectWidth/2, 
                 ytop = yMid + rectHeight/2,
                 col = col.bg, border = border.bg)

  ## Place the text:
  graphics::text(xMid, yMid, labels, col = col.text, cex = theCex, font = font, 
                 adj = c(0.5, 0.5))

  ## Return value:
  if (length(xMid) == 1) {
    invisible(c(xMid - rectWidth/2, xMid + rectWidth/2, yMid - rectHeight/2,
                yMid + rectHeight/2))
  } else {
    invisible(cbind(xMid - rectWidth/2, xMid + rectWidth/2, yMid - rectHeight/2,
                    yMid + rectHeight/2))
  }
}

# Draw an image with optional bounding boxes and class labels.
#
# img            numeric array (rows x cols x channels, presumably RGB) or matrix;
#                it is min-max rescaled to [0, 1] before drawing.
# box_info       optional data.frame with columns obj_name, col_left, col_right,
#                row_bot, row_top (normalised coordinates, row 0 = top of the
#                image) and optionally prob. If a required column is missing,
#                only the image is drawn.
# show_prob      append the rounded probability (%) to each label.
# col_bbox_list  one colour per entry of `selections`; unknown classes are drawn
#                in semi-transparent black.
# text_cex       label size (0 disables labels).
# plot_xlim, plot_ylim  visible window (y is reversed so row 0 is at the top).
# Returns NULL invisibly. par("mar") is restored on exit so later plots keep
# their margins.
Show_img <- function (img, box_info = NULL, show_prob = TRUE,
                      col_bbox_list = rainbow(8, alpha = 0.5),
                      selections = c("Erythroid", "Blasts", "Myeloid", "Lymphoid", "Plasma cells", "Monocyte", "Megakaryocyte", "Unable to identify"),
                      text_cex = 1.5, plot_xlim = c(0.04, 0.96), plot_ylim = c(0.96, 0.04)) {

  old_par <- par(mar = rep(0, 4))
  on.exit(par(old_par), add = TRUE)
  plot(NA, xlim = plot_xlim, ylim = plot_ylim, xaxt = "n", yaxt = "n", bty = "n")

  # Min-max rescale. A constant image would divide by zero (the resulting NaN
  # pixels make as.raster() fail), so it is only clipped into [0, 1] instead.
  img_range <- range(img)
  if (img_range[2] > img_range[1]) {
    img <- (img - img_range[1]) / (img_range[2] - img_range[1])
  } else {
    img <- pmin(pmax(img, 0), 1)
  }
  rasterImage(as.raster(img), 0, 1, 1, 0, interpolate = FALSE)

  required_cols <- c('obj_name', 'col_left', 'col_right', 'row_bot', 'row_top')
  if (is.null(box_info) || !all(required_cols %in% colnames(box_info)) ||
      nrow(box_info) == 0) {
    return(invisible(NULL))
  }

  if (!'prob' %in% colnames(box_info)) {box_info[, 'prob'] <- 1L}
  # Draw low-probability boxes first so the most confident ones end up on top.
  box_info <- box_info[order(box_info[, 'prob']), , drop = FALSE]

  for (i in seq_len(nrow(box_info))) {
    col_label <- col_bbox_list[which(selections %in% box_info[i, 'obj_name'])]
    if (length(col_label) != 1) {col_label <- '#00000080'}
    # Line width and label size scale with the box width (floored at 0.2).
    size <- max(box_info[i, 'col_right'] - box_info[i, 'col_left'], 0.2)
    rect(xleft = box_info[i, 'col_left'], xright = box_info[i, 'col_right'],
         ybottom = box_info[i, 'row_bot'], ytop = box_info[i, 'row_top'],
         col = '#FFFFFF00', border = col_label, lwd = 5 * sqrt(size))
    if (text_cex > 0) {
      if (show_prob) {
        current_label <- paste0(box_info[i, 'obj_name'], '(', round(box_info[i, 'prob'] * 100), '%)')
      } else {
        # Use the class name column by name rather than by position.
        current_label <- box_info[i, 'obj_name']
      }
      boxtext(x = box_info[i, 'col_left'], y = box_info[i, 'row_top'], labels = current_label,
              col.bg = col_label, col.text = 'white', adj = c(0, 0.6), font = 2, cex = text_cex * sqrt(size))
    }
  }

  invisible(NULL)
}
# Ground-truth boxes drawn on one training image.
selected_img <- "U0007"
sub_box_info <- train_dat[train_dat$img_id %in% selected_img, ]
img <- readImage(paste0("image/", selected_img, ".jpg"))
Show_img(img = img, box_info = sub_box_info)

# The same drawing applied to the rows of the sample submission for one image
# (presumably these share the box columns of train.csv; if not, only the image
# is shown).
selected_img <- "U0059"
sub_box_info <- submit_dat[submit_dat$img_id %in% selected_img, ]
img <- readImage(paste0("image/", selected_img, ".jpg"))
Show_img(img = img, box_info = sub_box_info)

Yolo v3

資料處理

  • 我們需要知道這個研究中,每個物件框對應於物件種類的關係:
# Integer id for every class name; the order also picks each class's colour
# from rainbow(8) in the plot below.
obj_ids <- setNames(1:8, c("Erythroid", "Blasts", "Myeloid", "Lymphoid", "Plasma cells", "Monocyte", "Megakaryocyte", "Unable to identify"))

# Working copy of the annotations with the class id and the box centre / size
# (same normalised units as the edges).
box_info <- train_dat
box_info[, "obj_id"] <- obj_ids[box_info[, "obj_name"]]
box_info$bbox_center_row <- (box_info$row_bot + box_info$row_top) / 2
box_info$bbox_center_col <- (box_info$col_right + box_info$col_left) / 2
box_info$bbox_width <- box_info$col_right - box_info$col_left
box_info$bbox_height <- box_info$row_bot - box_info$row_top

# Box width vs height in pixels (scaled by 2048 x 1920, presumably the image
# size), on log axes and coloured by class.
plot(x = box_info$bbox_width * 2048, y = box_info$bbox_height * 1920,
     col = rainbow(8, alpha = 0.5)[box_info$obj_id],
     xlab = "width", ylab = "height", log = "xy", pch = 19, cex = 0.7)

legend("bottomright", legend = names(obj_ids), col = rainbow(8, alpha = 0.5), pch = 19)

  • 你應該有注意到紫色的點(Megakaryocyte)離其他點特別遠,這是因為這種細胞特別大,讓我們找一張有這種細胞的圖來看看:
# An image containing a Megakaryocyte, drawn with its ground-truth boxes.
selected_img <- "U0072"
sub_box_info <- train_dat[train_dat$img_id %in% selected_img, ]
img <- readImage(paste0("image/", selected_img, ".jpg"))
Show_img(img = img, box_info = sub_box_info)

  • 所以,在選擇錨框(anchor box)的時候,我們可以強制將Megakaryocyte與其他種類分開,為它單獨設定一個錨框。

– 這樣,錨框的像素大小應該是[193 185]與[570 521],我們可以使用降維64倍的高階特徵圖針對第一個錨框,並且使用降維128倍的高階特徵圖針對第二個錨框:

# Each anchor box is the geometric mean (exp of mean log) of the normalised box
# sizes. Megakaryocyte (obj_id 7) gets its own anchor on the level-2 map; all
# other classes share one anchor on the level-1 map.
anchor_box.1 <- data.frame(
  width = exp(mean(log(box_info[box_info[, "obj_id"] != 7, "bbox_width"]))),
  height = exp(mean(log(box_info[box_info[, "obj_id"] != 7, "bbox_height"]))),
  rank = 1, lvl = 1, seq = 1
)
anchor_box.2 <- data.frame(
  width = exp(mean(log(box_info[box_info[, "obj_id"] == 7, "bbox_width"]))),
  height = exp(mean(log(box_info[box_info[, "obj_id"] == 7, "bbox_height"]))),
  rank = 2, lvl = 2, seq = 1
)

anchor_boxs <- rbind(anchor_box.1, anchor_box.2)

# Copy onto every box the anchor it will be encoded against: row 2 of
# anchor_boxs for Megakaryocyte, row 1 for every other class.
box_info[, c("anchor_width", "anchor_height", "rank", "lvl", "seq")] <-
  anchor_boxs[(box_info[, "obj_id"] %in% 7) + 1, c("width", "height", "rank", "lvl", "seq")]

– 在這樣的前提下,我們可以使用這裡的Encode函數,將原來的框資訊進行編碼,先讀取這些函數:

# Libraries

library(abind)

# Custom function

# Note: these functions make some effort to keep the coordinate system consistent.
# The main challenge is that the origin of the "plot" function is the bottom-left
# corner, whereas the origin of an image is the top-left corner.
# The Show_img function can be used to visually check the encoded/decoded bbox info.

# Intersection-over-union of two axis-aligned boxes.
#
# label, pred: one box each, given as c(col_left, col_right, row_bot, row_top)
#   in image coordinates (row_bot > row_top). Accepts a one-row data.frame or
#   matrix (columns in that order) or a plain numeric vector of length 4.
# Returns a number in [0, 1]; 0 when the boxes do not overlap (touching edges
# count as no overlap).
IoU_function <- function (label, pred) {

  label <- as.numeric(unlist(label, use.names = FALSE))
  pred <- as.numeric(unlist(pred, use.names = FALSE))

  overlap_width <- min(label[2], pred[2]) - max(label[1], pred[1])
  overlap_height <- min(label[3], pred[3]) - max(label[4], pred[4])

  if (overlap_width <= 0 || overlap_height <= 0) {
    return(0)
  }

  pred_size <- (pred[2] - pred[1]) * (pred[3] - pred[4])
  label_size <- (label[2] - label[1]) * (label[3] - label[4])
  overlap_size <- overlap_width * overlap_height

  overlap_size / (pred_size + label_size - overlap_size)
}

# Encode ground-truth boxes into YOLO target arrays, one array per grid level.
#
# box_info    data.frame with columns img_id, lvl (which grid level), seq (which
#             anchor slot), obj_id (1..n.obj), bbox_center_row, bbox_center_col,
#             bbox_width, bbox_height, anchor_width, anchor_height (normalised
#             coordinates).
# n.grid.row, n.grid.col  grid size of each level (same length).
# img_ids     images to encode, one sample each, in this order (default: all
#             img_id values found in box_info).
# eps         added inside the log of the size ratios.
# n.anchor, n.obj  anchor slots per cell and number of classes.
# Returns a list of arrays of dim c(rows, cols, n.anchor * (5 + n.obj), n_img);
# when img_ids is empty a single all-zero sample is returned. Per anchor slot the
# channels are: objectness, row offset in the cell, col offset in the cell,
# log(width / anchor width), log(height / anchor height), then a class one-hot.
Encode_fun <- function (box_info, n.grid.row = c(30, 15), n.grid.col = c(32, 16),
                        img_ids = NULL, eps = 1e-8, n.anchor = 1, n.obj = 8) {

  if (is.null(img_ids)) {
    img_ids <- unique(box_info$img_id)
  }

  num_pred <- 5 + n.obj
  n_sample <- max(length(img_ids), 1)

  out_array_list <- lapply(seq_along(n.grid.row), function(k) {
    array(0, dim = c(n.grid.row[k], n.grid.col[k], n.anchor * num_pred, n_sample))
  })

  for (j in seq_along(img_ids)) {

    sub_box_info <- box_info[box_info$img_id %in% img_ids[j], , drop = FALSE]

    for (k in seq_along(n.grid.row)) {

      lvl_box_info <- sub_box_info[sub_box_info$lvl %in% k, , drop = FALSE]
      if (nrow(lvl_box_info) == 0) next

      # Box geometry expressed in grid-cell units of this level.
      center_rows <- lvl_box_info$bbox_center_row * n.grid.row[k]
      center_cols <- lvl_box_info$bbox_center_col * n.grid.col[k]
      heights <- lvl_box_info$bbox_height * n.grid.row[k]
      widths <- lvl_box_info$bbox_width * n.grid.col[k]
      anchor_heights <- lvl_box_info$anchor_height * n.grid.row[k]
      anchor_widths <- lvl_box_info$anchor_width * n.grid.col[k]

      for (i in seq_len(nrow(lvl_box_info))) {

        # Cell holding the centre, clamped into the grid. A centre lying exactly
        # on a cell border belongs to the cell above/left of it (offset 1), as
        # before; a centre at 0 or beyond the grid now lands in the border cell
        # instead of being dropped (index 0) or failing (out of bounds).
        cell_row <- min(max(ceiling(center_rows[i]), 1), n.grid.row[k])
        cell_col <- min(max(ceiling(center_cols[i]), 1), n.grid.col[k])
        row_related_pos <- center_rows[i] - (cell_row - 1)
        col_related_pos <- center_cols[i] - (cell_col - 1)

        base <- (lvl_box_info$seq[i] - 1) * num_pred

        # Clear the class one-hot first: when several boxes share a cell/anchor
        # the last one wins, and it must not inherit the previous box's class.
        out_array_list[[k]][cell_row, cell_col, base + 5 + seq_len(n.obj), j] <- 0
        out_array_list[[k]][cell_row, cell_col, base + 1:5, j] <- c(
          1,
          row_related_pos,
          col_related_pos,
          log(widths[i] / anchor_widths[i] + eps),
          log(heights[i] / anchor_heights[i] + eps)
        )
        out_array_list[[k]][cell_row, cell_col, base + 5 + lvl_box_info$obj_id[i], j] <- 1
      }
    }
  }

  return(out_array_list)
}

# Decode YOLO output / target arrays back into a data.frame of boxes.
#
# encode_array_list  list (one entry per grid level) of arrays, or objects that
#                    as.array() converts to arrays, with dim
#                    c(rows, cols, n_anchor * (5 + n_class), n_img), laid out as
#                    produced by Encode_fun.
# anchor_boxs        data.frame whose first two columns are anchor width/height
#                    and which has columns lvl and seq. When NULL, the object
#                    named `anchor_boxs` visible where this function was defined
#                    is used. (The previous default `anchor_boxs = anchor_boxs`
#                    referred to itself and always failed.)
# cut_prob           keep cells whose objectness (or objectness * best class
#                    probability when multiply_prob = TRUE) is >= cut_prob.
# cut_overlap        IoU at or above which a lower-probability box is suppressed.
# obj_name, obj_col  class names / colours indexed by class channel.
# img_id_list        img_id to report for each sample (default: sample index).
# remove_all_overlap TRUE: suppress overlaps across all classes; FALSE: only
#                    within each class.
# Returns a data.frame with columns obj_name, col_left, col_right, row_bot,
# row_top, prob, img_id, col, sorted by decreasing prob within each image, or
# NULL when no cell passes cut_prob.
Decode_fun <- function (encode_array_list, anchor_boxs = NULL,
                        cut_prob = 0.5, cut_overlap = 0.3,
                        obj_name = names(obj_ids),
                        obj_col = rainbow(8, alpha = 0.5),
                        multiply_prob = FALSE,
                        img_id_list = NULL,
                        remove_all_overlap = TRUE) {

  if (is.null(anchor_boxs)) {
    anchor_boxs <- get("anchor_boxs", envir = parent.env(environment()))
  }

  # Greedy non-maximum suppression on boxes already sorted by decreasing prob:
  # a box is dropped when its IoU with any earlier, still-kept box reaches
  # cut_overlap. Columns 2:5 are col_left, col_right, row_bot, row_top.
  suppress_overlaps <- function(boxes) {
    n_box <- nrow(boxes)
    if (n_box <= 1) return(boxes)
    keep <- rep(TRUE, n_box)
    for (m in 2:n_box) {
      for (n in seq_len(m - 1)) {
        if (keep[n] &&
            IoU_function(label = boxes[m, 2:5], pred = boxes[n, 2:5]) >= cut_overlap) {
          keep[m] <- FALSE
          break
        }
      }
    }
    boxes[keep, , drop = FALSE]
  }

  num_list <- length(encode_array_list)
  num_feature <- length(obj_name) + 5
  # Convert every level once instead of once per image and anchor.
  encode_arrays <- lapply(encode_array_list, as.array)
  num_img <- dim(encode_arrays[[1]])[4]
  pos_start <- (0:(dim(encode_arrays[[1]])[3] / num_feature - 1)) * num_feature

  box_info <- NULL

  for (j in seq_len(num_img)) {

    new_img_id <- if (is.null(img_id_list)) j else img_id_list[j]
    candidates <- list()

    for (k in seq_len(num_list)) {

      level_array <- encode_arrays[[k]]
      n_row <- dim(level_array)[1]
      n_col <- dim(level_array)[2]

      for (i in seq_along(pos_start)) {

        # Channels of anchor slot i for image j; drop = FALSE + explicit dim keep
        # the (row, col, channel) layout even for 1-row or 1-column grids.
        sub_encode_array <- level_array[, , pos_start[i] + seq_len(num_feature), j, drop = FALSE]
        dim(sub_encode_array) <- c(n_row, n_col, num_feature)

        if (multiply_prob) {
          sub_encode_array[, , 1] <- sub_encode_array[, , 1] *
            apply(sub_encode_array[, , 6:num_feature, drop = FALSE], 1:2, max)
        }

        pos_over_cut <- which(sub_encode_array[, , 1] >= cut_prob)
        if (length(pos_over_cut) == 0) next

        # Linear (column-major) index -> cell row / column.
        cell_rows <- (pos_over_cut - 1) %% n_row + 1
        cell_cols <- (pos_over_cut - 1) %/% n_row + 1
        anchor_box <- anchor_boxs[anchor_boxs$lvl == k & anchor_boxs$seq == i, 1:2]

        for (l in seq_along(pos_over_cut)) {

          encode_vec <- sub_encode_array[cell_rows[l], cell_cols[l], ]
          # In-cell offsets are clipped into [0, 1].
          encode_vec[2:3] <- pmin(pmax(encode_vec[2:3], 0), 1)

          center_row <- (encode_vec[2] + (cell_rows[l] - 1)) / n_row
          center_col <- (encode_vec[3] + (cell_cols[l] - 1)) / n_col
          width <- exp(encode_vec[4]) * anchor_box[1, 1]
          height <- exp(encode_vec[5]) * anchor_box[1, 2]

          class_probs <- encode_vec[-c(1:5)]
          best_class <- which.max(class_probs)
          if (multiply_prob) {
            current_prob <- encode_vec[1]
          } else {
            current_prob <- encode_vec[1] * max(class_probs)
          }

          candidates[[length(candidates) + 1]] <- data.frame(
            obj_name = obj_name[best_class],
            col_left = center_col - width / 2,
            col_right = center_col + width / 2,
            row_bot = center_row + height / 2,
            row_top = center_row - height / 2,
            prob = current_prob,
            img_id = new_img_id,
            col = obj_col[best_class],
            stringsAsFactors = FALSE
          )
        }
      }
    }

    if (length(candidates) == 0) next

    sub_box_info <- do.call(rbind, candidates)
    sub_box_info <- sub_box_info[order(sub_box_info$prob, decreasing = TRUE), , drop = FALSE]

    if (remove_all_overlap) {
      box_info <- rbind(box_info, suppress_overlaps(sub_box_info))
    } else {
      for (obj in unique(sub_box_info$obj_name)) {
        obj_sub_box_info <- sub_box_info[sub_box_info$obj_name == obj, , drop = FALSE]
        box_info <- rbind(box_info, suppress_overlaps(obj_sub_box_info))
      }
    }
  }

  return(box_info)
}
  • 讓我們試著對一張圖的所有框進行編碼:
# Encode every ground-truth box of one image into the per-level target arrays.
selected_img <- "U0007"
sub_box_info <- box_info[box_info$img_id %in% selected_img, ]
encode_list <- Encode_fun(box_info = sub_box_info, img_ids = selected_img)
  • 讓我們還原這個編碼,並跟原圖做比較:
# Decode the target arrays back into boxes and draw them on the image; they
# should line up with the original annotations.
decode_box_info <- Decode_fun(
  encode_array_list = encode_list,
  anchor_boxs = anchor_boxs,
  img_id_list = selected_img
)

img <- readImage(paste0("image/", selected_img, ".jpg"))
Show_img(img = img, box_info = decode_box_info)

  • 有了這些基礎函數後,我們就能建構Iterator,以下是它的程式碼:
library(abind)
library(jpeg)
library(mxnet)

# Folder of the JPEG files and the ids of all annotated training images; both
# are read as globals by my_iterator_core() below.
img_dir <- "image/"
sample_ids <- unique(box_info$img_id)

# Closure-based batch iterator for training.
#
# Reads the globals `sample_ids`, `box_info` and `img_dir`. One epoch walks
# through floor(length(sample_ids) / batch_size) consecutive batches in order.
# value() returns list(data = image NDArray, label1 = ..., label2 = ...), where
# the labels come from Encode_fun on grids of 1/64 and 1/128 of the (cropped)
# image size.
#
# aug_flip   random horizontal / vertical flips (box centres are mirrored).
# crop_size  when >= 128, rounded to a multiple of 128; that many pixels are
#            removed from each dimension at a random offset, and the boxes are
#            rescaled, clipped and dropped when they fall outside the crop.
# aug_col    random per-image brightness shift, gain and gamma.
my_iterator_core <- function (batch_size, aug_col = TRUE, crop_size = 256, aug_flip = TRUE) {

  batch <- 0

  batch_per_epoch <- floor(length(sample_ids) / batch_size)

  reset <- function() {batch <<- 0}

  iter.next <- function() {

    batch <<- batch + 1
    if (batch > batch_per_epoch) {return(FALSE)} else {return(TRUE)}

  }

  value <- function() {

    # Indices of this batch. iter.next() stops after the last full batch, so the
    # refill of out-of-range indices is only a safeguard.
    idx <- seq_len(batch_size) + (batch - 1) * batch_size
    n_over <- sum(idx > length(sample_ids))
    if (n_over > 0) {
      idx[idx > length(sample_ids)] <- sample(seq_len(idx[1] - 1), n_over)
    }
    idx <- sort(idx)

    batch.box_info <- box_info[box_info[, 'img_id'] %in% sample_ids[idx], ]

    img_array_list <- lapply(idx, function(id) {
      readJPEG(paste0(img_dir, sample_ids[id], '.jpg'))
    })

    img_array <- abind(img_array_list, along = 4)

    # Flips only mirror the box centres; the edge columns are recomputed from
    # the centres in the crop step (Encode_fun only uses centres and sizes).
    if (aug_flip) {

      if (sample(c(TRUE, FALSE), 1)) {

        img_array <- img_array[, dim(img_array)[2]:1, , , drop = FALSE]
        batch.box_info[, 'bbox_center_col'] <- 1 - batch.box_info[, 'bbox_center_col']

      }

      if (sample(c(TRUE, FALSE), 1)) {

        img_array <- img_array[dim(img_array)[1]:1, , , , drop = FALSE]
        batch.box_info[, 'bbox_center_row'] <- 1 - batch.box_info[, 'bbox_center_row']

      }

    }

    if (crop_size >= 128) {

      crop_size <- round(crop_size / 128) * 128
      revised_dim <- dim(img_array)[1:2] - crop_size

      random.row <- sample(0:crop_size, 1)
      random.col <- sample(0:crop_size, 1)

      img_array <- img_array[random.row + 1:revised_dim[1], random.col + 1:revised_dim[2], , , drop = FALSE]

      # Re-express coordinates relative to the crop: scale by full / cropped
      # size, then shift by the crop offset.
      batch.box_info[, c('bbox_center_col', 'bbox_width', 'anchor_width')] <- batch.box_info[, c('bbox_center_col', 'bbox_width', 'anchor_width')] * (revised_dim[2] + crop_size) / revised_dim[2]
      batch.box_info[, c('bbox_center_row', 'bbox_height', 'anchor_height')] <- batch.box_info[, c('bbox_center_row', 'bbox_height', 'anchor_height')] * (revised_dim[1] + crop_size) / revised_dim[1]

      batch.box_info[, 'bbox_center_col'] <- batch.box_info[, 'bbox_center_col'] - random.col / revised_dim[2]
      batch.box_info[, 'bbox_center_row'] <- batch.box_info[, 'bbox_center_row'] - random.row / revised_dim[1]

      batch.box_info[, 'col_left'] <- batch.box_info[, 'bbox_center_col'] - batch.box_info[, 'bbox_width'] / 2
      batch.box_info[, 'col_right'] <- batch.box_info[, 'bbox_center_col'] + batch.box_info[, 'bbox_width'] / 2
      batch.box_info[, 'row_bot'] <- batch.box_info[, 'bbox_center_row'] + batch.box_info[, 'bbox_height'] / 2
      batch.box_info[, 'row_top'] <- batch.box_info[, 'bbox_center_row'] - batch.box_info[, 'bbox_height'] / 2

      # Drop boxes entirely outside the crop, clip the rest to [0, 1].
      save_pos <- which(batch.box_info[, 'col_left'] <= 1 & batch.box_info[, 'col_right'] >= 0 & batch.box_info[, 'row_top'] <= 1 & batch.box_info[, 'row_bot'] >= 0)
      batch.box_info <- batch.box_info[save_pos, ]
      batch.box_info[batch.box_info[, 'col_left'] < 0, 'col_left'] <- 0
      batch.box_info[batch.box_info[, 'col_right'] > 1, 'col_right'] <- 1
      batch.box_info[batch.box_info[, 'row_top'] < 0, 'row_top'] <- 0
      batch.box_info[batch.box_info[, 'row_bot'] > 1, 'row_bot'] <- 1

      batch.box_info[, 'bbox_center_col'] <- (batch.box_info[, 'col_right'] + batch.box_info[, 'col_left']) / 2
      batch.box_info[, 'bbox_center_row'] <- (batch.box_info[, 'row_bot'] + batch.box_info[, 'row_top']) / 2

      batch.box_info[, 'bbox_width'] <- (batch.box_info[, 'col_right'] - batch.box_info[, 'col_left'])
      batch.box_info[, 'bbox_height'] <- (batch.box_info[, 'row_bot'] - batch.box_info[, 'row_top'])

      save_pos <- which(batch.box_info[, 'bbox_width'] > 0 & batch.box_info[, 'bbox_height'] > 0)
      batch.box_info <- batch.box_info[save_pos, ]

    }

    img_array <- mx.nd.array(img_array)

    if (aug_col) {

      add_item <- mx.nd.array(array(runif(batch_size * 1, min = -0.05, max = 0.05), dim = c(1, 1, 1, batch_size)))
      img_array <- mx.nd.broadcast.add(img_array, add_item)

      mul_item <- mx.nd.array(array(exp(runif(batch_size * 1, min = -0.2, max = 0.2)), dim = c(1, 1, 1, batch_size)))
      img_array <- mx.nd.broadcast.mul(img_array, mul_item)

      # The additive shift can push dark pixels below 0, and a negative base
      # raised to a non-integer exponent is NaN, so clamp at 0 before the gamma.
      img_array <- mx.nd.broadcast.maximum(img_array, mx.nd.array(array(0, dim = c(1, 1, 1, 1))))

      pow_item <- mx.nd.array(array(exp(runif(batch_size * 1, min = -0.2, max = 0.2)), dim = c(1, 1, 1, batch_size)))
      img_array <- mx.nd.broadcast.power(img_array, pow_item)

    }

    img_array <- mx.nd.broadcast.maximum(img_array, mx.nd.array(array(0, dim = c(1, 1, 1, 1))))
    img_array <- mx.nd.broadcast.minimum(img_array, mx.nd.array(array(1, dim = c(1, 1, 1, 1))))

    label <- Encode_fun(box_info = batch.box_info,
                        n.grid.row = dim(img_array)[1] / c(64, 128),
                        n.grid.col = dim(img_array)[2] / c(64, 128),
                        img_ids = sample_ids[idx],
                        eps = 1e-8, n.anchor = 1, n.obj = 8)

    result_list <- list()
    result_list[[1]] <- img_array
    for (k in seq_along(label)) {result_list[[k + 1]] <- mx.nd.array(label[[k]])}
    names(result_list) <- c('data', paste0('label', seq_along(label)))

    return(result_list)

  }

  return(list(reset = reset, iter.next = iter.next, value = value, batch_size = batch_size, batch = batch, sample_ids = sample_ids))

}

# Reference class that wraps my_iterator_core() so it can be handed to mxnet as
# a data iterator (it extends Rcpp_MXArrayDataIter). Each method forwards to the
# closure stored in the `iter` field.
my_iterator_func <- setRefClass(
  "Custom_Iter",
  fields = c("iter", "batch_size", "aug_col", "crop_size", "aug_flip"),
  contains = "Rcpp_MXArrayDataIter",
  methods = list(
    initialize = function(iter, batch_size = 4, aug_col = TRUE, crop_size = 256, aug_flip = TRUE) {
      # The `iter` argument is not used; the closure is always (re)built here.
      .self$iter <- my_iterator_core(
        batch_size = batch_size,
        aug_col = aug_col,
        crop_size = crop_size,
        aug_flip = aug_flip
      )
      .self
    },
    reset = function() {
      .self$iter$reset()
    },
    iter.next = function() {
      .self$iter$iter.next()
    },
    value = function() {
      .self$iter$value()
    },
    finalize = function() {
    }
  )
)
  • 讓我們試用這個Iterator:取出一個批次,將其中的標籤解碼還原,並畫在對應的圖片上做比較:
# Build the iterator (2 images per batch, colour / crop / flip augmentation).
my_iter <- my_iterator_func(iter = NULL, batch_size = 2, aug_col = TRUE, crop_size = 256, aug_flip = TRUE)

# Fetch the first batch.
my_iter$reset()
my_iter$iter.next()
## [1] TRUE
test <- my_iter$value()

# Take one image of the batch (layout: row, col, channel, sample).
img_seq <- 1
iter_img <- as.array(test$data)[, , , img_seq]

# Cropping (crop_size = 256) stretches every normalised size by
# full / cropped size, so the anchors used for decoding must be stretched the
# same way (2048 x 1920 is the image size assumed elsewhere in this script).
revised_anchor_boxs <- anchor_boxs
revised_anchor_boxs$width <- anchor_boxs$width * 2048 / (2048 - 256)
revised_anchor_boxs$height <- anchor_boxs$height * 1920 / (1920 - 256)

# Everything after the image data is the list of per-level labels.
label_list <- test[-1]
iter_box_info <- Decode_fun(encode_array_list = label_list, anchor_boxs = revised_anchor_boxs)

# Show the decoded boxes of the selected image on top of it.
Show_img(img = iter_img, box_info = iter_box_info[iter_box_info$img_id %in% img_seq, ])

模型結構

  • 這裡當然還是先使用預訓練的影像模型進行分析,我們在這裡提供了一個已經預訓練好的模型列表

– 這個範例使用「resnet-50」進行,這是一個50層深的ResNet,非常標準的深度神經網路。我們要建構新的最後一層的函數:

library(magrittr)
library(mxnet)

# Load the pretrained ResNet-50 checkpoint and keep the graph up to the
# internal output named 'relu1_output', which serves as the backbone feature
# map (presumably the last activation of the network; confirm in the symbol file).
res_model <- mx.model.load(prefix = "resnet-50", iteration = 0)
all_layers <- res_model$symbol$get.internals()
relu1_output <- all_layers$get.output(which(all_layers$outputs == "relu1_output"))

# Convolution layer for specific mission and training new parameters

# Convolution (no bias) -> BatchNorm -> optional ReLU, as an mxnet symbol.
# Layer names are derived from `name` ('<name>', '<name>_bn', '<name>_relu').
CONV_Unit <- function (name, indata, num_filter, num_group,
                       kernel, stride, pad,
                       with_relu, bn_momentum) {

  conv <- mx.symbol.Convolution(
    name = name, data = indata,
    num.filter = num_filter, num.group = num_group,
    kernel = kernel, stride = stride, pad = pad,
    no.bias = TRUE
  )
  bn <- mx.symbol.BatchNorm(
    name = paste0(name, '_bn'), data = conv,
    fix_gamma = FALSE, momentum = bn_momentum, eps = 2e-5
  )

  if (!with_relu) {
    return(bn)
  }
  mx.symbol.Activation(data = bn, act.type = 'relu', name = paste0(name, '_relu'))
}

# Squeeze-and-excitation gate: average `indata` over axes 2 and 3 (presumably the
# spatial axes in mxnet's NCHW order), pass it through a 1x1 bottleneck
# (num_filter * se_coef channels, ReLU) and a 1x1 expansion (num_filter
# channels, sigmoid), then rescale `indata` channel-wise by the result.
SE_unit <- function (name, indata, num_filter, se_coef) {

  squeezed <- mx.symbol.mean(name = paste0(name, '_pool'), data = indata,
                             axis = c(2, 3), keepdims = TRUE)

  hidden <- mx.symbol.Convolution(
    name = paste0(name, '_conv1'), data = squeezed,
    num.filter = num_filter * se_coef, num.group = 1,
    kernel = c(1, 1), stride = c(1, 1), pad = c(0, 0),
    no.bias = FALSE
  )
  hidden <- mx.symbol.Activation(data = hidden, act.type = 'relu', name = paste0(name, '_relu'))

  gate <- mx.symbol.Convolution(
    name = paste0(name, '_conv2'), data = hidden,
    num.filter = num_filter, num.group = 1,
    kernel = c(1, 1), stride = c(1, 1), pad = c(0, 0),
    no.bias = FALSE
  )
  gate <- mx.symbol.Activation(data = gate, act_type = "sigmoid", name = paste0(name, '_sigmoid'))

  mx.symbol.broadcast_mul(lhs = indata, rhs = gate, name = paste0(name, '_scaled_data'))
}

# Bottleneck block: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand (no ReLU),
# with an optional squeeze-and-excitation gate (se_coef) and optional shortcut.
# dim_match = TRUE: stride 1, and the shortcut adds `indata` to the result.
# dim_match = FALSE: the 3x3 conv uses stride 2, and the shortcut (if any)
# concatenates a stride-2 max-pooled `indata` with the result along axis 1.
# (`shotcut` is the argument's original spelling.)
CONV_Module <- function(name, indata, num_filter, num_group, squeeze_coef,
                        shotcut = TRUE, dim_match = TRUE, bn_momentum,
                        se_coef = NULL) {

  bottleneck_filters <- num_filter * squeeze_coef

  reduced <- CONV_Unit(name = paste0(name, '_conv1'), indata = indata,
                       num_filter = bottleneck_filters, num_group = 1,
                       kernel = c(1, 1), stride = c(1, 1), pad = c(0, 0),
                       with_relu = TRUE, bn_momentum = bn_momentum)

  grouped <- CONV_Unit(name = paste0(name, '_conv2'), indata = reduced,
                       num_filter = bottleneck_filters, num_group = num_group,
                       kernel = c(3, 3), stride = c(2, 2) - dim_match, pad = c(1, 1),
                       with_relu = TRUE, bn_momentum = bn_momentum)

  expanded <- CONV_Unit(name = paste0(name, '_conv3'), indata = grouped,
                        num_filter = num_filter, num_group = 1,
                        kernel = c(1, 1), stride = c(1, 1), pad = c(0, 0),
                        with_relu = FALSE, bn_momentum = bn_momentum)

  if (!is.null(se_coef)) {
    expanded <- SE_unit(name = paste0(name, '_se'), indata = expanded,
                        num_filter = num_filter, se_coef = se_coef)
  }

  if (!shotcut) {
    return(expanded)
  }

  if (dim_match) {
    return(mx.symbol.broadcast_plus(lhs = indata, rhs = expanded, name = paste0(name, '_plus')))
  }

  pool <- mx.symbol.Pooling(name = paste0(name, '_pool'), data = indata, global_pool = FALSE, pool_type = "max",
                            kernel = c(3, 3), pad = c(1, 1), stride = c(2, 2))
  mx.symbol.concat(list(pool, expanded), dim = 1, num.args = 2, name = paste0(name, '_concat'))
}

# Upsample the coarser map `updata` 2x with a learned 2x2 / stride-2
# deconvolution (-> BatchNorm -> ReLU) and add it element-wise to the finer map
# `downdata`. Both must end up with `num_filters` channels and matching sizes.
DECONV_function <- function (updata, downdata, num_filters = 256, name = 'lvl1') {

  upsampled <- mx.symbol.Deconvolution(
    data = updata, kernel = c(2, 2), stride = c(2, 2), pad = c(0, 0),
    no.bias = TRUE, num.filter = num_filters,
    name = paste0(name, '_deconv')
  )
  upsampled <- mx.symbol.BatchNorm(data = upsampled, fix_gamma = FALSE, name = paste0(name, '_deconv_bn'))
  upsampled <- mx.symbol.Activation(data = upsampled, act.type = 'relu', name = paste0(name, '_deconv_relu'))

  mx.symbol.broadcast_plus(lhs = upsampled, rhs = downdata, name = paste0(name, "_plus_map"))
}

# Which of the `final_map` output channels of a YOLO head stay linear.
# The channels form `num_box` consecutive groups of final_map / num_box channels
# (one group per anchor, laid out like Encode_fun's targets). Within each group,
# channels 4 and 5 hold log(width / anchor) and log(height / anchor); these must
# be able to go negative, so they are left linear. All other channels get a sigmoid.
.yolo_linear_channel_mask <- function(final_map, num_box) {
  per_box <- final_map / num_box
  if (per_box != round(per_box) || per_box < 5) {
    stop("final_map must be a multiple of num_box with at least 5 channels per box",
         call. = FALSE)
  }
  position_in_box <- (seq_len(final_map) - 1) %% per_box + 1
  position_in_box %in% c(4, 5)
}

# YOLO output head: dropout -> 1x1 convolution to `final_map` channels, then a
# sigmoid on every channel except the log-size channels of each anchor (see
# .yolo_linear_channel_mask). The previous test `k %% num_box` never matched
# 4 or 5 for num_box = 1 or 3, so the size channels were wrongly squashed into (0, 1).
YOLO_map_function <- function (indata, final_map = 75, num_box = 3, drop = 0.2, name = 'lvl1') {

  linear_channel <- .yolo_linear_channel_mask(final_map, num_box)

  dp <- mx.symbol.Dropout(data = indata, p = drop, name = paste0(name, '_drop'))

  conv <- mx.symbol.Convolution(data = dp, kernel = c(1, 1), stride = c(1, 1), pad = c(0, 0),
                                no.bias = FALSE, num.filter = final_map, name = paste0(name, '_linearmap'))

  inter_split <- mx.symbol.SliceChannel(data = conv, num_outputs = final_map,
                                        axis = 1, squeeze_axis = FALSE, name = paste0(name, "_inter_split"))

  new_list <- vector("list", final_map)

  for (k in seq_len(final_map)) {
    if (linear_channel[k]) {
      new_list[[k]] <- inter_split[[k]]
    } else {
      new_list[[k]] <- mx.symbol.Activation(inter_split[[k]], act.type = 'sigmoid', name = paste0(name, "_yolomap_", k))
    }
  }

  yolomap <- mx.symbol.concat(data = new_list, num.args = final_map, dim = 1, name = paste0(name, "_yolomap"))

  return(yolomap)

}

# Extra layers on top of the pretrained backbone. Each module built with
# dim_match = FALSE uses a stride-2 3x3 convolution, so lvl1 sits one
# downsampling step below relu1_output and lvl2 one step below lvl1.
lvl1_conv1 <- CONV_Module(
  name = "lvl1_conv1", indata = relu1_output,
  num_filter = 2048, num_group = 32, squeeze_coef = 1/2,
  shotcut = FALSE, dim_match = FALSE,
  bn_momentum = 0.9, se_coef = 1/16
)

lvl2_conv1 <- CONV_Module(
  name = "lvl2_conv1", indata = lvl1_conv1,
  num_filter = 2048, num_group = 32, squeeze_coef = 1/2,
  shotcut = FALSE, dim_match = FALSE,
  bn_momentum = 0.9, se_coef = 1/16
)

lvl2_out <- CONV_Module(
  name = "lvl2_conv2", indata = lvl2_conv1,
  num_filter = 2048, num_group = 32, squeeze_coef = 1/2,
  shotcut = TRUE, dim_match = TRUE,
  bn_momentum = 0.9, se_coef = 1/16
)

# Top-down path: upsample the level-2 features and add them to level 1.
lvl1_plus_map <- DECONV_function(updata = lvl2_out, downdata = lvl1_conv1, num_filters = 2048, name = "lvl1")

lvl1_out <- CONV_Module(
  name = "lvl1_conv2", indata = lvl1_plus_map,
  num_filter = 2048, num_group = 32, squeeze_coef = 1/2,
  shotcut = TRUE, dim_match = TRUE,
  bn_momentum = 0.9, se_coef = 1/16
)

# YOLO heads: 13 = 5 box channels + 8 class channels, one anchor per level.
yolomap_list <- list(
  YOLO_map_function(indata = lvl1_out, final_map = 13, num_box = 1, drop = 0, name = "lvl1"),
  YOLO_map_function(indata = lvl2_out, final_map = 13, num_box = 1, drop = 0, name = "lvl2")
)

out_pred <- mx.symbol.Group(do.call(c, yolomap_list))
  • 接著要編寫損失函數,這個損失函數相當複雜:
# Masked mean-squared-error loss symbol, scaled by `lambda`.
#
# indata, inlabel : prediction and target symbols.
# obj             : mask symbol, multiplied element-wise into the squared error.
# lambda          : loss weight. It is applied the same way as in
#                   LOGCOSH_loss_function and CE_loss_function; previously it
#                   was accepted but silently ignored.
# pre_sqrt        : if TRUE, prediction and target both go through sqrt()
#                   before the difference is taken.
# Returns the weighted error averaged over axes 0-3.
MSE_loss_function <- function (indata, inlabel, obj, lambda, pre_sqrt = FALSE) {
  
  if (pre_sqrt) {
    indata <- mx.symbol.sqrt(indata)
    inlabel <- mx.symbol.sqrt(inlabel)
  }
  
  diff_pred_label <- mx.symbol.broadcast_minus(lhs = indata, rhs = inlabel)
  square_diff_pred_label <- mx.symbol.square(data = diff_pred_label)
  obj_square_diff_loss <- mx.symbol.broadcast_mul(lhs = obj, rhs = square_diff_pred_label)
  MSE_loss <- mx.symbol.mean(data = obj_square_diff_loss, axis = 0:3, keepdims = FALSE)
  
  return(MSE_loss * lambda)
  
}

# Masked log-cosh regression loss symbol, scaled by `lambda`.
#
# indata, inlabel : prediction and target symbols.
# obj             : mask symbol, multiplied element-wise into the per-element loss.
# lambda          : loss weight.
# Returns lambda * mean over axes 0-3 of obj * log(cosh(indata - inlabel)).
LOGCOSH_loss_function <- function (indata, inlabel, obj, lambda) {
  
  diff_pred_label <- mx.symbol.broadcast_minus(lhs = indata, rhs = inlabel)
  
  # log(cosh(x)) evaluated as |x| + log(1 + exp(-2|x|)) - log(2).
  # The direct form overflows: in float32, cosh(x) is Inf once |x| exceeds about
  # 89, which makes the loss Inf and its gradient NaN. The rewritten form only
  # exponentiates non-positive values, so it stays finite.
  abs_diff <- mx.symbol.abs(data = diff_pred_label)
  exp_term <- mx.symbol.exp(data = abs_diff * (-2))
  logcosh_diff_pred_label <- abs_diff + mx.symbol.log(data = exp_term + 1) - log(2)
  
  obj_logcosh_diff_pred_label <- mx.symbol.broadcast_mul(lhs = obj, rhs = logcosh_diff_pred_label)
  logcosh_loss <- mx.symbol.mean(data = obj_logcosh_diff_pred_label, axis = 0:3, keepdims = FALSE)
  
  return(logcosh_loss * lambda)
  
}

# Masked, class-balanced (and optionally focal) binary cross-entropy symbol.
#
# indata   : predicted probabilities.
# inlabel  : binary targets.
# obj      : mask symbol multiplied element-wise into the per-element loss.
# lambda   : loss weight.
# gamma    : focal exponent; 0 leaves the plain weighted cross-entropy.
# pos_freq : share of positives. Positives are weighted by 2 * (1 - pos_freq)
#            and negatives by 2 * pos_freq, so rarer targets weigh more.
# eps      : added inside the logs (and the focal factors) for stability.
# Returns -lambda * mean over axes 0-3 of the weighted log-likelihood.
CE_loss_function <- function (indata, inlabel, obj, lambda, gamma = 0, pos_freq = 0.5, eps = 1e-4) {
  
  pos_weight <- (1 - pos_freq) * 2
  neg_weight <- pos_freq * 2
  
  # y * log(p + eps)
  pos_term <- mx.symbol.broadcast_mul(lhs = mx.symbol.log(data = indata + eps), rhs = inlabel)
  # (1 - y) * log(1 - p + eps)
  neg_term <- mx.symbol.broadcast_mul(lhs = mx.symbol.log(data = 1 - indata + eps), rhs = 1 - inlabel)
  
  weighted_terms <- (1 - indata + eps)^gamma * pos_term * pos_weight +
    (indata + eps)^gamma * neg_term * neg_weight
  
  masked_terms <- mx.symbol.broadcast_mul(lhs = obj, rhs = weighted_terms)
  
  0 - mx.symbol.mean(data = masked_terms, axis = 0:3, keepdims = FALSE) * lambda
}

# YOLO training loss for one output level, returned as an MXNet symbol.
#
# `indata` (predictions) and `inlabel` (targets) are split along axis 1 into
# `final_map` single channels. They are laid out anchor by anchor: each of the
# `num_box` anchors owns `final_map / num_box` consecutive channels. Within one
# anchor:
#   * position 1     : objectness. Cross-entropy is added twice: once masked
#                      by the objectness label (weight `lambda`) and once masked
#                      by its complement (weight 1). Both use pos_freq = 0.5.
#   * positions 2..5 : log-cosh regression loss masked by objectness, weight `lambda`.
#   * positions 6..  : per-class cross-entropy masked by objectness, weight
#                      `lambda * weight_classification`, balanced with
#                      pos_freq_list[position - 5].
# `gamma` is forwarded to every cross-entropy term.
# `name` prefixes the slice symbols.
# Returns the summed loss symbol.
YOLO_loss_function <- function (indata, inlabel, final_map = 33, num_box = 3, lambda = 10, gamma = 0, weight_classification = 0.2,
                                pos_freq_list = NULL, name = 'lvl1') {
  
  num_feature <- final_map/num_box
  
  # A non-integer split would silently mis-align every channel index below.
  if (num_feature != round(num_feature) || num_feature < 5) {
    stop("final_map / num_box must be a whole number >= 5 (got ", final_map, " / ", num_box, ")", call. = FALSE)
  }
  
  if (is.null(pos_freq_list)) {pos_freq_list <- rep(0.5, num_feature - 5)}
  
  # Too few entries would make pos_freq_list[k - 5] NA and the loss NaN.
  if (length(pos_freq_list) < num_feature - 5) {
    stop("pos_freq_list needs at least ", num_feature - 5, " entries, got ", length(pos_freq_list), call. = FALSE)
  }
  
  my_loss <- 0
  
  # paste0 (not paste) so the symbol names do not contain an embedded space.
  yolomap_split <- mx.symbol.SliceChannel(data = indata, num_outputs = final_map, 
                                          axis = 1, squeeze_axis = FALSE, name = paste0(name, '_yolomap_split'))
  
  label_split <- mx.symbol.SliceChannel(data = inlabel, num_outputs = final_map, 
                                        axis = 1, squeeze_axis = FALSE, name = paste0(name, '_label_split'))
  
  for (j in seq_len(num_box)) {
    offset <- (j - 1) * num_feature
    obj_channel <- offset + 1
    for (k in seq_len(num_feature)) {
      channel <- offset + k
      if (k == 1) {
        # Objectness on cells whose label marks an object ...
        my_loss <- my_loss + CE_loss_function(indata = yolomap_split[[channel]],
                                              inlabel = label_split[[channel]],
                                              obj = label_split[[obj_channel]],
                                              pos_freq = 0.5,
                                              lambda = lambda,
                                              gamma = gamma,
                                              eps = 1e-4)
        # ... and on the remaining cells, with unit weight.
        my_loss <- my_loss + CE_loss_function(indata = yolomap_split[[channel]],
                                              inlabel = label_split[[channel]],
                                              obj = 1 - label_split[[obj_channel]],
                                              pos_freq = 0.5,
                                              lambda = 1,
                                              gamma = gamma,
                                              eps = 1e-4)
      } else if (k <= 5) {
        my_loss <- my_loss + LOGCOSH_loss_function(indata = yolomap_split[[channel]],
                                                   inlabel = label_split[[channel]],
                                                   obj = label_split[[obj_channel]],
                                                   lambda = lambda)
      } else {
        my_loss <- my_loss + CE_loss_function(indata = yolomap_split[[channel]],
                                              inlabel = label_split[[channel]],
                                              obj = label_split[[obj_channel]],
                                              pos_freq = pos_freq_list[k-5],
                                              lambda = lambda * weight_classification,
                                              gamma = gamma,
                                              eps = 1e-4)
      }
    }
  }
  
  return(my_loss)
  
}

# Label placeholders, one per YOLO output level.
label1 <- mx.symbol.Variable(name = "label1")
label2 <- mx.symbol.Variable(name = "label2")

# Share of annotated boxes per obj_id, floored at 0.02. CE_loss_function turns
# these into class-balancing weights for the classification channels.
pos_freq_list <- table(box_info[, 'obj_id']) / nrow(box_info)
pos_freq_list[pos_freq_list < 0.02] <- 0.02

lvl1_loss <- YOLO_loss_function(
  name = 'lvl1', indata = yolomap_list[[1]], inlabel = label1,
  final_map = 13, num_box = 1,
  lambda = 10, gamma = 0, weight_classification = 0.5,
  pos_freq_list = pos_freq_list
)

lvl2_loss <- YOLO_loss_function(
  name = 'lvl2', indata = yolomap_list[[2]], inlabel = label2,
  final_map = 13, num_box = 1,
  lambda = 10, gamma = 0, weight_classification = 0.5,
  pos_freq_list = pos_freq_list
)

# Both levels are summed into the single training objective.
final_yolo_loss <- mx.symbol.MakeLoss(data = lvl1_loss + lvl2_loss)

開始訓練

  • 最難的過程已經過去,訓練就一點都不難了,這跟之前的語法很像。

– 我們可以將最後那幾層以外的部分填入ResNet-50的預訓練參數,並以此為基礎開始訓練:

# Initialise every parameter of the training graph (data batch of 2 plus the two
# label maps), then overwrite those that also exist in the pretrained
# ResNet-50 (`res_model`) with identical dimensions.
new_arg <- mxnet:::mx.model.init.params(symbol = final_yolo_loss,
                                        input.shape = list(data = c(1664, 1792, 3, 2), label1 = c(26, 28, 13, 2), label2 = c(13, 14, 13, 2)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

# Return `target` with each entry replaced by the `source` entry of the same
# name, provided that name occurs exactly once in `source` and both arrays have
# the same dimensions. Entries without such a match keep their initial value.
copy_matching_params <- function (target, source) {
  for (i in seq_along(target)) {
    pos <- which(names(source) == names(target)[i])
    # isTRUE(): on mismatch all.equal() returns a character vector, which may
    # have several elements; `== TRUE` on that errors inside if() in R >= 4.2.
    if (length(pos) == 1 && isTRUE(all.equal(dim(source[[pos]]), dim(target[[i]])))) {
      target[[i]] <- source[[pos]]
    }
  }
  target
}

new_arg$arg.params <- copy_matching_params(new_arg$arg.params, res_model$arg.params)
new_arg$aux.params <- copy_matching_params(new_arg$aux.params, res_model$aux.params)
  • 可以開始訓練了,如果你的電腦沒有GPU,請你將「mx.gpu()」改成「mx.cpu()」
# Adam with weight decay.
my_optimizer <- mx.opt.create(
  name = "adam",
  learning.rate = 1e-3,
  beta1 = 0.9, beta2 = 0.999, epsilon = 1e-08,
  wd = 1e-4
)

# The network output is the MakeLoss value, so the metric just reports it.
my.eval.metric.loss <- mx.metric.custom(
  name = "my-loss",
  function(real, pred) as.array(pred)
)

mx.set.seed(0)

# Train the full graph from the partly pretrained parameters in `new_arg`.
my_model <- mx.model.FeedForward.create(
  symbol = final_yolo_loss,
  X = my_iter,
  ctx = mx.gpu(0),
  num.round = 30,
  array.batch.size = 2,
  optimizer = my_optimizer,
  eval.metric = my.eval.metric.loss,
  arg.params = new_arg$arg.params,
  aux.params = new_arg$aux.params,
  input.names = c('data'),
  output.names = c('label1', 'label2')
)
  • 這個模型要訓練非常久,並且在預測時還必須使用特殊的預測函數(因為有兩個輸出):
# The same training code appears just above and starts from the same initial
# parameters (new_arg$arg.params / new_arg$aux.params). Running it a second time
# only repeats the 30 rounds and doubles the already very long training time.
# It therefore runs only when no `my_model` exists yet (e.g. when this cell is
# executed on its own).
if (!exists('my_model')) {
  
  my_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-08, wd = 1e-4)
  
  my.eval.metric.loss <- mx.metric.custom(
    name = "my-loss", 
    function(real, pred) {
      return(as.array(pred))
    }
  )
  
  mx.set.seed(0)
  
  my_model <- mx.model.FeedForward.create(symbol = final_yolo_loss, X = my_iter, optimizer = my_optimizer,
                                          arg.params = new_arg$arg.params, aux.params = new_arg$aux.params,
                                          array.batch.size = 2, num.round = 30, ctx = mx.gpu(0),
                                          input.names = c('data'), output.names = c('label1', 'label2'),
                                          eval.metric = my.eval.metric.loss)
  
}
  • 接著我們能進行預測,你可以先將這麼不容易訓練好的模型存下來:
# Replace the training (loss) graph by the two-output prediction graph, then
# save the checkpoint under prefix 'BMSNet', iteration 0.
my_model[["symbol"]] <- out_pred
mx.model.save(model = my_model, prefix = 'BMSNet', iteration = 0)
  • 這是特殊的預測函數:
# Run both YOLO heads of `model` on one input and return their raw maps.
#
# model : trained model whose `symbol` contains internal outputs named
#         'lvl1_yolomap_output' and 'lvl2_yolomap_output', plus the
#         `arg.params` / `aux.params` to load.
# img   : input array (or MXNDArray); its dim() is used as the bound data shape.
# ctx   : MXNet device to run on.
# Returns an unnamed list: [[1]] the lvl1 map, [[2]] the lvl2 map, as R arrays.
my_predict <- function (model, img, ctx = mx.gpu()) {
  
  all_layers <- model$symbol$get.internals()
  
  # Fetch one internal output by exact name. Fail loudly instead of passing
  # an empty index to get.output(). Plain calls avoid needing magrittr's %>%.
  get_named_output <- function (output_name) {
    idx <- which(all_layers$outputs == output_name)
    if (length(idx) != 1) {
      stop("expected exactly one '", output_name, "' in model$symbol, found ", length(idx), call. = FALSE)
    }
    all_layers$get.output(idx)
  }
  
  lvl1_output <- get_named_output('lvl1_yolomap_output')
  lvl2_output <- get_named_output('lvl2_yolomap_output')
  
  out <- mx.symbol.Group(c(lvl1_output, lvl2_output))
  executor <- mx.simple.bind(symbol = out, data = dim(img), ctx = ctx)
  
  mx.exec.update.arg.arrays(executor, model$arg.params, match.name = TRUE)
  mx.exec.update.aux.arrays(executor, model$aux.params, match.name = TRUE)
  if (!inherits(img, 'MXNDArray')) {img <- mx.nd.array(img)}
  mx.exec.update.arg.arrays(executor, list(data = img), match.name = TRUE)
  mx.exec.forward(executor, is.train = FALSE)
  
  pred_list <- list()
  
  pred_list[[1]] <- as.array(executor$ref.outputs$lvl1_yolomap_output)
  pred_list[[2]] <- as.array(executor$ref.outputs$lvl2_yolomap_output)
  
  return(pred_list)
  
}
  • 讓我們快速地對測試組中所有的樣本進行分析吧,最終我們能獲得0.3286的mAP50,而分數偏低的原因主要是分類不準確。
img_dir <- 'image/'

# Predict every test image one at a time and decode both YOLO maps into boxes.
pred_box_info <- list()

pb <- txtProgressBar(style = 3, max = nrow(test_dat))

for (k in seq_len(nrow(test_dat))) {
  img <- readJPEG(paste0(img_dir, test_dat[k, 'img_id'], '.jpg'))
  dim(img) <- c(dim(img), 1)  # trailing batch axis of size 1
  encode_list <- my_predict(img = img, model = my_model, ctx = mx.gpu(0))
  pred_box_info[[k]] <- Decode_fun(img_id_list = test_dat[k, 'img_id'],
                                   encode_array_list = encode_list,
                                   anchor_boxs = anchor_boxs)
  setTxtProgressBar(pb, k)
}

close(pb)

# Stack the per-image results and write the first seven columns as the submission.
pred_box_info <- do.call(rbind, pred_box_info)
write.csv(pred_box_info[, 1:7], file = 'my_submission.csv', row.names = FALSE, quote = FALSE, na = '')
  • 如果你想看看你的預測結果,可以試著挑一張圖畫出來看看:
# Visual check on one image: predict, decode and draw the boxes.
# NOTE: this reuses (overwrites) `img`, `encode_list` and `pred_box_info` from
# the test-set loop; the submission file has already been written at this point.
selected_img <- 'U0059'

img <- readImage(paste0('image/', selected_img, '.jpg'))
img_array <- img
dim(img_array) <- c(dim(img_array), 1)  # batch axis of size 1

encode_list <- my_predict(img = img_array, model = my_model, ctx = mx.gpu(0))
pred_box_info <- Decode_fun(img_id_list = selected_img,
                            encode_array_list = encode_list,
                            anchor_boxs = anchor_boxs)

Show_img(box_info = pred_box_info, img = img)

  • 感覺至少在拉框上還可以吧?那現在其實已經變成了一個單純的分類任務囉。想一下還有哪些提升效能的方法呢?