본문 바로가기
Machine learning

Deep Learning numpy를 통한 기초이론_3. backpropagation 오차역전파법

by ahsung 2020. 1. 5.

오차역전파법은 왜 쓰이는가

 

기초이론_2에서 각각의 가중치와 편향의 손실함수에 대한 미분값을 구하면

경사하강법을 통해 최소가 되는 지점에 다가가는 법을 알아보았다.

 

하지만 n개의 입력을 받아 m개의 출력을 내뱉는 인공망의 한층이 있다고 해보자.

 

1개의 출력을 위해 n개의 입력마다 각각의 weight를 곱한후 bias를 더하면 한개의 출력이 나온다.

 

이런 출력이 m개 라는 것은  n*m개의 weight가,  m개의 bias가 존재한다.

화살표 한개가 weight 한개!!

화살표 한개가 하나의 weight라 할 수 있고, 

위 사진은 2층 형태의 구조로서 

 

1) 3개의 입력 4개의 출력

2) 4개의 입력 2개의 출력

 

으로 구성되어 있다, 1),2)는 각 넘어가는 화살표의 과정이다.

 

식은 상세하게 보지 않아도 된다. 인공망의 수식으로서 표현이 어색하다면 살펴보자.

hidden Node안에는 각 input에서 주어진 input1,2,3

 

hidden1 = input1*w11 + input2*w12 + input3*w13 + bias1

hidden2 = input1*w21 + input2*w22 + input3*w23 + bias2

hidden3 = input1*w31 + input2*w32 + input3*w33 + bias3

hidden4 = input1*w41 + input2*w42 + input3*w43 + bias4

 

out Node안에는 각 hidden에서 주어진 hidden1,2,3

 

out1 = hidden1*out_W11 + hidden2*out_W12 + hidden3*out_W13 + hidden4*out_W14 + out_b1

out2 = hidden1*out_W21 + hidden2*out_W22 + hidden3*out_W23 + hidden4*out_W24 + out_b2

 

 

만약 n*m 과 m*k 의 두개의 층에서

n,m,k가 각각 3자리 숫자만 되더라도 단 한번의 학습마다 미분값을 구해야 할 매개변수들이 너무 많아진다.

 

그래서 고안하여 나온 것이 오차역전파법이다.!!

 

오차역전파법은 각 매개변수의 변화율을 하나하나 수치적 미분을 통해서 구하는 것이 아닌,

 

역전파라는 말 답게, 결과쪽에서의 변화율에서부터 반대방향으로 weight들의 변화율을 구해나가는 것이다.

 

 

오차역전 파법의 간단한 이론

 

여기서부터는 고등학교 수준의 간단한 미분 이론이 들어간며 직관적인 수학적 이해를 바탕으로 설명하겠다.

 

w가 가중치이고 loss(w)가 w를 매개변수로 했을때의 손실함수 값이라하자.

그렇다면 w에 대한 loss의 미분값은 loss'(x)이다

 

g(x)라는 함수와 f(y)라는 함수가 있다하자.

가) y에 대한 f(y)의 미분값은 f'(y)이다.

나) x에 대한 g(x)의 미분값은 g'(x)이다.

 

자 다시한번, y=g(x)라 가정해보자.

그렇다면 f(y) = f(g(x))라는 합성함수가 생성된다.

 

자 위의 식 가)를 다시 써보면

가) g(x)에 대한 f(g(x))의 미분값은 f'(g(x)) 이다.

 

그리고!,

 

x에 대한 f(g(x))의 미분값은  f'(g(x))*g'(x) 이다.

이것은 고등학교 수학의 합성함수 미분법으로서도 알 수 있고

 

dt/dx  = dt/dy  * dy/dx    

t를 x로 미분한 값  =  t를 y로 미분한 값  *  y를 x로 미분한 값 

라는 식으로도 쉽게 알 수 있다.

 

즉 다시 해석하면,  

f'(g(x))는 g(x)의 변화율에 따른 f(g(x))의 증가량이고

g'(x)는 x의 변화율에 따른 g(x)의 증가량이다.

 

x의 변화율에 따른 f(g(x))의 변화율은    f'(g(x))*g'(x) 라는 사실은 당연하다.

 

 

오차역전파는 이런 합성함수의 특징을 통해 "연쇄적"으로 각 단계의 weight들의 편미분값들을 구할 수 있다.

 

g(x) = 3x , f(g(X)) = 2g(x) 라하자.

x는 입력값이고 g(x)는 1층(hidden),  f(g(x))는 2층(최종 결과) 값이라하자.

 

1층 g(x)에서의 가중치는 3 

2층 f(g(x))에서의 가중치는 2

 

g(x)는 1층 단계의 신경망 weight를 거쳐 나온 중간 결과값이자 2층으로의 입력값이다.

즉 2층단계에서의 입력값 g(x)에 대한 미분값은  f'(g(x)) = 2  이다.

 

그리고 1층 단계에서의 입력값 x에 대한 g'(x) = 3 = 1층의 weight이다.

 

2층 단계는 같은 방식으로 g(x)에 대한 f'(g(x)) = 2 = 2층의 weight이다.

 

입력값에 대한 총 함수의 변화율(미분값)을 역방향으로 순서대로 계산하면

 

f(g(x))에 대한 f(g(x)) 변화율은 당연하게도 1

 

g(x) 에 대한 f(g(x)) 변화율은 1*f'(g(x)) = 1 * 2

 

x에 대한 총 결과 함수 f(g(x))의 변화율은  1*f'(g(x))*g'(x) = 1*2*g'(x) = 1*2*3 = 6

 

단계적으로 앞에서 구했던 미분값 2에 이어서 g'(x)의 값을 곱해 나가면 미분값을 간단하게 구할 수 있다.

 

즉 전 단계에서 구한 입력값에 대한 미분값이 있으면, 다음 단계에서는 손쉽게 연쇄작용으로 구할 수 있다.

 

 

하지만

위에서 구한 식은 입력값에 대한 미분값이다.

우리가 구하고 싶은것은 weight와 bias에 대한 미분값이다.

 

역전파의 특성에 대해 알아보자.

 

f = a * b 라는 식이 있다. a는 1커질때 마다  f는 b만큼 커진다.  a에 대해  a 1당, f는 b의 변화율을 가지며 미분값이 b이다.

그 역 b에대한 f의 변화율은 a 도 성립한다. 한마디로,

weight의 변화율은 입력받은 값, 그게 바로 미분값이라는 뜻이 된다.

(이하 표현에서 변화율 = 증가량 = 미분값 을 뜻한다.)

 

위에서 입력값(x)에 대한 미분값이  weight *( 전달받은 미분값) 이었다.

그렇다면 weight에 대한 미분값은  입력값 * 전달받은 미분값이 된다.

 

 

__더욱 직관으로 간단하게 설명하자면__

#일단 bias와 활성화함수의 존재는 무시하고

1층에서의 가중치 W1이 있다하자,  2층은 W2 ...n층은 Wn

이때 마지막 결과는 input값이 x라면  result =  x*W1*W2*W3...*Wn일 것이다.

 

W1이 1 커질 때 마다 result는 얼마나 커지는가  x*W2*W3*.....Wn이다.

W2이 1 커질 때 마다 result는 얼마나 커지는가  x*W1*W3*W4*...Wn이다.

Wn이 1 커질 때 마다 result는 얼마나 커지는가  x*W1*W2*...Wn-1이다.

 

Wk가 1 커질때 reuslt는  ( x*W1...*Wk-1 )  *  ( Wk+1...*Wn )  두 부분으로 나누어 일반화 시킬 수 있다.

 

이제 직관적으로 앞부분은 k층에서 받는 입력값(Wk의 미분값)이고

 

앞선 예시로 각층의 Weight는 각층의 입력값의 변화율(미분값)이라는 것을 알고있다.

뒷부분 값은 k층 이후의 입력의 미분값들의 총곱이다.

 

매번 각층의 weight의 미분값을 구하기위해 그 다음층의 미분값들을 매번 구한다면

엄청난 손해가 아닐 수 없다. 어차피 미분값들의 총 곱이라면 뒷부분부터

차례대로 구하며 이미 계산한 곱들의 연산을 뒤로 전달하면되고

입력값의 경우는 처음 순방향으로 결과값을 뽑기위해 계산했을 때 (보통 손실함수는 결과값과 정답값을 비교함)

각 층마다 저장했다면 입력값, 다음층의 미분값 모두 엄청난 중복 연산을 줄일 수 있다!

이게 단순한 역전파의 발상이다.

 

 

 

 

예시를 들어보자.

그저 수식적인 판단이 아닌, 직관적으로 수학이 왜 이렇게 되는지 이해하기 위해 비슷한 내용이 반복됩니다.

 

위의 g(x) = 3x, f(g(x)) = 2g(x) 라하고

 

x = 100 이라는 입력값을 받았다고 해보자.

weight1 = 3,  weight2 =2

 

x        ->  g(x)        ->             f(g(x))   = 600

100   100*3 = 300    300*2  = 600

 

지금 이 과정은 순서대로 값을 구한 순전파라고 한다.

 

그렇다면 가장 뒤의 600부터 하여  각 위치의 증가율(미분값)을 거꾸로 weight에 대한 역전파를 구해보자.

600 → 2(변화율,가중치) * 300(입력값) →   3(변화율,가중치) * 100(입력값) →100(최초 입력값)

Df(g(x))=1 → (Dweight2 = Df(g(x))*300),  Dg(x) = Df(g(x))*2 (Dweight1 = Dg(x)*100),  Dx = Dg(x)*3     

# D변수는 변수가 결과값에 주는 증가량이란 뜻  (Dx =  df(g(x)) / dx  를 뜻함)

수식만으로는 이해하기 어려우니 아래 내용 참고.

 

이제 역전파 답게  최후의 출력값 노드부터 거꾸로 최초의 입력값 노드 방향으로 역전파 전달 과정을 보겠습니다.

 

1 (결과 값에 대한 결과값의 증가율 당연히 1입니다.) -->

weight2에 대한 출력값 증가율은 300(입력)이다. 전달받은 미분값이 1 이므로 총 결과값에 대한 증가율(미분값)은 1*300이다. -->

 

주의!! 다음 노드로 전달할 미분값은 300이 아니다.  여기서 구한것은 weight에 대한 미분값일 뿐이다.!

역전파로 전달할 값은 받은 값을 거꾸로 그 값에 대한 미분값을 보내줘야 한다! 

맨 뒤부터 시작..

현재 노드 f(g(x)가 전달 받았던 값은 입력값으로 받은 g(100) = 300이다.

현 노드로의 입력값g(x)에 대한 전체의 증가량은 당연하게도 현 노드에서 곱해진 weight와 그 후 노드에서 발생될 증가량이다!

즉 f(g(x))의 weight2 * 이후 노드에서 더욱 영향을 줄 증가값  (현 노드가 마지막 노드이므로 1이다.)  = 2* 1 = 2이다.

 

---> 다음 g(x) 노드로 이동.

 

이제 g(x) 노드에서는 역전파로 2를 전달 받았다. 이는 현 노드에서 1 출력이 증가할 때 총 결과값에 2의 영향을 준다는 것을 뜻한다

weight1의 미분값은 현 노드에서 입력값인 100만큼의 영향력을 지닌다. 100*2 = 200

weight1  1증가당 200의 영향력을 총 결과에 줄 수 있다. 

 

입력값(x)의 미분값은 현노드에서 g(x)의 weight = 3 만큼의 영향력을 지니고, 3 * 2 = 6

x가 1증가시 6의 영향력을 총 결과에 줄 수 있다.

 

다시 요약하면

 

즉 위의 합성함수에서 말했던 연쇄법칙처럼,  역전파에서 얻어온 증가율은 입력값에 대한 미분값 1*2가 된다.   -->

 

weight1에 대한 출력값 증가율은 100(입력)이다. 전달받은 미분값이 1*2이므로 총 결과값에 대한 미분값은 1*2*100 = 200 이다.

 

그리고 그 다음 노드에 전할 미분값은 입력(100)에 대한 미분값 3(weight)* 이전 연쇄값들(1*2) = 6 이다.

 

즉 마지막 노드 100은 증가율(미분값)  6을 가진다.

 

다시 정리해보면

input1 = 100

weight1 = 3

input2 = 100*3 = 300

weight2 = 2

result = 2*300 = 600

 

각각 결과값에 미치는 영향력을 직관으로 계산해보자. ( 핵심!! 이것만 읽어도 직관적인 이해는 가능..)

input1은 1커질 때 마다 결과값에 6에 해당하는 증가율을 보여준다.

weight1은 1커질 때 마다 결과값에 200에 해당하는 증가율을 보여준다.

input2은  1커질 때 마다 결과값에 2에 해당하는 증가율을 보여준다.

weight2는 1커질 때 마다 결과값에 300에 해당하는 증가율을 보여준다.

result는 결과값 그 자체이므로 당연히 1커질때마다 1에 해당하는 증가율이다.

 

d(~~)는 result를 ~~로 미분한 값이라 하면,

d(result)      =                       (당연히 본인이 1 증가하면 1커지는 것..)

d(input2)    =      * (1)  =           

d(weight2) = 300  * (1)  = 300       

d(input1)    = 3       * (2 * 1)   =  6    

d(weight1) = 100  * (2 * 1)  = 200 

 

만약 3층 구조였다면, d(input1)을 연쇄 전달값으로 계속 전해 나갔을 것이다.

현 노드에서만  입력값을 출력값으로 만드는 증가량은 단순하게, input이라면 곱해지는 weight,, weight라면 input일 것이다.

하지만 총 결과값에 영향을 미치는 증가량은, 그 이후에 통과할 노드들의 증가량도 곱해지기 때문에

위와 같은 식을 직관적으로 이해할 수 있을 것이다.

 

 

물론 각각의 값들은 서로 종속성을 가지고 있지만,

편미분을 한다는 가정이기 때문에, 현재 변수에대한 미분을 진행할 때 다른 변수들은 모두 상수라고 가정하고 진행한다.

 

다행히 bias의 미분값은 간단하다. bias는 현재 노드 자체에서는 증가율이 존재하지 않는다.

하지만 bias가 1증가할때마다 결과 값에는 얼만큼 영향을 주겠는가?

지금 노드에서 출력되어 여러 노드를 거치며 증가하게된 증가량 만큼만 결과값에 영향이 있을 것이다.

즉 역전파에서 받아온 미분값 만큼만 영향력이 존재한는 것이다.!! bias의 미분값은 전달받은 미분값 그 자체이다.

 

간단한 수학식으로 설명하자면 output = X*W + b 일때,  output을 b로 편미분하게되면,

Doutput/Db = 1이다. 즉 output에 대한 미분값이 1이라는 말이다. 하지만 우리가 구할것은 총 마지막 결과값의 변화율이므로 Dresult/Db이다.

Dresult/Doutput이 output에대한 결과값 변화율이고 전달받은 값이므로, 

Dresult/Doutput  * Doutput/Db = 그냥 output에 대한 결과값 변화율

즉 전달받은 미분값이다.

 

역전파로 weight의 미분값을 구하기 위해서는 전달받은 미분값과

순방향(순전파)일 때 자신의 노드에 입력 받았던 값이 필요하므로 기억 시켜놓는 것이 중요하다.

 

그리고 노드에서 다음 역방향 노드에 전달할 미분값은 , 자신이 전달받은 미분값 * 입력값에 대한 미분값이다.

입력값에 대한 미분값 = weight 이다.

 

dout, 전달받은 미분값

dx = 현 노드에 입력받은 x에 대한 결과 미분값

dw = 현 노드의 weight에 대한 결과 미분값             

이라면

 

dx = dout*w   (위의 합성함수 예시에서 dout 이 f'(g(x)) 부분 w가 g'(x) 부분으로 생각 할 수 있다.)

dw = dout*x

이다.

 

dw는 우리가 구하고 싶었던 w에 대한 결과값 미분값

dx는 역방향으로 전달해야할 전 노드로 부터 받은 입력값(x)에 대한 결과값 미분값

 

 

 

역전파를 이용하면, 그저 한 번의 행렬곱 연산으로 다음 layer 신경망으로의 미분값 전달이 되므로

일일이 weight마다 미분값을 구하기 위해 각 weight의 작은변화에 따른 결과값을 계산하던 방식에서 벗어날 수 있다.

 

역전파는 그리 어려운 미분 이론은 아니지만, 미분에 대한 직관적인 이해도가 있어야 쉽게 접근할 수 있다.

 

일반적인 weight와 입력값에 대해서는 선형(1차식) 모델의 미분이기에 간단하게 계수값이었지만,

 

활성함수를 같이쓰는 신경망에서는 각 활성함수에대한 역전파 미분값을 구해야 한다.

 

또한 실제 신경망에서 학습할 때의 결과 값은 손실함수의 값이다.

 

그렇기 때문에 손실함수의 미분값이 간단하게 떨어질 수 있는 활성함수 짝을 잘 선택하는 것도 중요하다.

 

 

 

# 위 식은 행렬로서도 확장될 수 있다.

# 단 행렬은 행과 열이 곱할 수 있는 순서로 곱해진다.

# 예) 순전파(순방향)으로  (1,3)행렬 * (3,2) 행렬 = (1,2)행렬   { 3개의 입력값으로 2개의 결과 출력 (3,2)의 Weight 행렬 }

#  역전파(역방향)에서는   (3,1) * (1,2)  = (3,2) 행렬  {dy결과의 미분값에서,  입력값을 곱하니,, (3,2)의 Weight 미분값이 나왔다.}

 

 def gradient(self, x, t):

         # x는 입력되는 값의 numpy 

         # t는 결과 출력과 비교할 정답 numpy
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        grads = {}
        
        #batch는 총 task 수

         # 한번에 많은 task를 학습시키는 수
        batch_num = x.shape[0]

        # 순전파, 순방향
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        

        # 역전파, 역방향

        # 행렬곱을 통해 각 task별 매개변수들 값이 모두 더해진다. 처음 batch_num으로 나누었으므로 평균이 된다.


        dy = (y - t) / batch_num        #마지막 활성함수로 softmax를 사용했다.
                       			#entropy 손실함수에 대해서 (y-t)라는 간단한 미분값을
                       			#가지기 때문에 dy를 바로 사용하였다..
        grads['W2'] = np.dot(z1.T, dy)   		  
        grads['b2'] = np.sum(dy, axis=0)    
        
        da1 = np.dot(dy, W2.T)
        dz1 = sigmoid_grad(a1) * da1           # sigmoid_grad()는 a1 입력값으로 sigmoid함수를 미분하는 함수이다.  
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)

        return grads

 

교차엔트로피 오차 (Cross Entropy Error, CEE)는  softmax 활성 함수와 함께 사용 할 때 (y-t)라는 역전파,

평균제곱오차 (Mean squared error - mse)는 항등함수와 함께 사용 할 때 (y-t)라는 역전파가 간단히 나옵니다.

 

 

 

 

 

 

댓글