演算法2-梯度下降法及MCMC

林嶔 (Lin, Chin)

Lesson 18

梯度下降法(1)

– 我們現在已經了解到一件事情,無論函數型態寫成什麼樣子,最大概似估計法重點就是能讓我們定義出一個求解函數,所以我們把問題簡化一點:我們希望有一個方法,能在某個不特定函數中找出該函數的極值。

– 我們從這個簡單的例子開始介紹「梯度下降法」

梯度下降法(2)

– 還是很難理解吧,我們來想想y = x^2的求解過程,我們已知y = x^2的微分方程是2x,意思是說在任何一個點的切線斜率是2x,而斜率的意思就是說「x每增加一個單位,y所改變的量」

– 想到這裡,我們就能了解,假設我們想要求得y = x^2的最小值,我們可以隨機的給一個x的起始點,並且讓這個點以「切線斜率」的反方向移動,這樣就能找出最小值

– 我了解到實在太難理解了,我們用R語言實現一下他的過程吧

original.fun = function(x) {
  return(x^2)
}

differential.fun = function(x) {
  return(2*x)
}

x = seq(-6, 6, by = 0.01)
y = original.fun(x)

start.value = 5
learning.rate = 0.1
num.iteration = 20

result.x = rep(NA, num.iteration)
result.y = rep(NA, num.iteration)

par(mfcol = c(4, 5))

for (i in 1:num.iteration) {
  if (i == 1) {
    result.x[1] = start.value
    result.y[1] = original.fun(start.value)
  } else {
    result.x[i] = result.x[i-1] - learning.rate * differential.fun(result.x[i-1])
    result.y[i] = original.fun(result.x[i])
  }
  plot(x, y, xlim = c(-5, 5), ylim = c(0, 25), type = "l", main = paste0("iteration = ", i))
  col.points = rep("black", num.iteration)
  col.points[i] = "red"
  points(result.x, result.y, pch = 19, col = col.points)
}

F18_1

梯度下降法(3)