문과생도 이해하는 딥러닝 (7) - 오차역전파법 실습 2

2019. 10. 19. 10:34·IT기술 관련/A.I 인공지능
반응형

출처: https://sacko.tistory.com/41?category=632408

 

오차역전파법은 계층 형태로 순전파와 역전파를 메서드로 구현하여 효율적으로 기울기를 계산할 수 있도록 모듈화하여 신경망의 layer를 자유롭게 쌓고 쉽게 만들며 계산 속도를 빠르게 해준다. 딥러닝의 신경망 학습 모형은 이러한 layer들과 그 안의 함수들을 모듈로서 레고 블럭을 조립하듯이 조립하여 신경망을 구현할 수 있다.

 

 

오차역전파법 실습 2

문과생도 이해하는 딥러닝 (7)

앞선 "신경망 학습 실습" 포스팅에서는 순전파만을 고려하였고 기울기 계산을 수치 미분으로 하면서 계산이 오래 걸려서 결국.... 결과를 내지 못했다. 다행히 이번 오차역전파를 적용한 실습에서는 계산이 빨라 신경망 학습에 대한 결과까지 그래프로 확인할 수 있었다. 학습의 정확도는 약 98%, 테스트의 정확도는 약 97%까지 나왔다.

 

코드에 대한 설명은 추후 추가

 

from scratch.common.layers import * from scratch.common.gradient import numerical_gradient from collections import OrderedDict

 

class TwoLayerNet: def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01): # initializa Weights self.params = {} self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size) self.params['b1'] = np.zeros(hidden_size) self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size) self.params['b2'] = np.zeros(output_size) # Build Layers self.layers = OrderedDict() self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1']) self.layers['Relu1'] = Relu() self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2']) self.lastLayer = SoftmaxWithLoss() def predict(self, x): for layer in self.layers.values(): x = layer.forward(x) return x def loss(self, x, t): y = self.predict(x) return self.lastLayer.forward(y, t) def accuracy(self, x, t): y = self.predict(x) y = np.argmax(y, axis=1) if t.ndim != 1 : t = np.argmax(t, axis=1) accuracy = np.sum(y==t) / float(x.shape[0]) return accuracy def numerical_gradient(self, x, t): loss_W = lambda W: self.loss(x, t) grads = {} grads['W1'] = numerical_gradient(loss_W, self.params['W1']) grads['b1'] = numerical_gradient(loss_W, self.params['b1']) grads['W2'] = numerical_gradient(loss_W, self.params['W2']) grads['b2'] = numerical_gradient(loss_W, self.params['b2']) return grads def gradient(self, x, t): # foward self.loss(x, t) # backward dout = 1 dout = self.lastLayer.backward(dout) layers = list(self.layers.values()) layers.reverse() for layer in layers: dout = layer.backward(dout) grads = {} grads['W1'] = self.layers['Affine1'].dW grads['b1'] = self.layers['Affine1'].db grads['W2'] = self.layers['Affine2'].dW grads['b2'] = self.layers['Affine2'].db return grads

 

from scratch.dataset.mnist import load_mnist

 

(x_train, t_train), (x_test, t_test) = load_mnist(normalize = True, one_hot_label = True) network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10) iters_num = 10000 print('number of iterations is', iters_num) train_size = x_train.shape[0] print('train size is', train_size) batch_size = 100 print('batch size is', batch_size) learning_rate = 0.1 train_loss_list = [] train_acc_list = [] test_acc_list = [] iter_per_epoch = max(train_size / batch_size, 1) print('='*20 + '>') epoch = 0 for i in range(iters_num): batch_mask = np.random.choice(train_size, batch_size) x_batch = x_train[batch_mask] t_batch = t_train[batch_mask] grad = network.gradient(x_batch, t_batch) for key in ('W1', 'b1', 'W2', 'b2'): network.params[key] -= learning_rate * grad[key] loss = network.loss(x_batch, t_batch) train_loss_list.append(loss) if i % iter_per_epoch == 0: epoch += 1 train_acc = network.accuracy(x_train, t_train) test_acc = network.accuracy(x_test, t_test) train_acc_list.append(train_acc) test_acc_list.append(test_acc) print('Epoch', epoch, ': ',train_acc,'\t', test_acc)

 

number of iterations is 10000 train size is 60000 batch size is 100 ====================> Epoch 1 : 0.102966666667 0.1055 Epoch 2 : 0.904416666667 0.9074 Epoch 3 : 0.9235 0.9252 Epoch 4 : 0.936166666667 0.9344 Epoch 5 : 0.9447 0.9432 Epoch 6 : 0.951983333333 0.9506 Epoch 7 : 0.955983333333 0.9537 Epoch 8 : 0.960733333333 0.9562 Epoch 9 : 0.965166666667 0.9602 Epoch 10 : 0.967366666667 0.9616 Epoch 11 : 0.969633333333 0.9652 Epoch 12 : 0.9721 0.966 Epoch 13 : 0.973833333333 0.9676 Epoch 14 : 0.975983333333 0.9686 Epoch 15 : 0.976883333333 0.9703 Epoch 16 : 0.9778 0.9694 Epoch 17 : 0.97895 0.9702

 

import matplotlib.pyplot as plt plt.figure(figsize=(20,8)) plt.plot(train_loss_list[:1000], linewidth=0.5) plt.title('Train Loss Graph') plt.xlabel('iteration') plt.ylabel('loss') plt.show()

 

 

plt.figure(figsize=(20,8)) plt.plot(train_acc_list, linewidth=1) plt.plot(test_acc_list, '-.' ,linewidth=1) plt.legend(['train acc', 'test acc'], loc=0) plt.xlabel('epochs') plt.ylabel('accuracy') plt.show()

 

 

 

 

반응형

'IT기술 관련 > A.I 인공지능' 카테고리의 다른 글

문과생도 이해하는 딥러닝 (9) - 신경망 초기 가중치 설정  (0) 2019.10.21
문과생도 이해하는 딥러닝 (8) - 신경망 학습 최적화  (0) 2019.10.20
문과생도 이해하는 딥러닝 (6) - 오차역전파법 실습 1  (0) 2019.10.18
bagging , boosting, stacking (배깅,부스팅,스태깅)  (0) 2019.10.16
문과생도 이해하는 딥러닝 (5) - 신경망 학습 실습  (0) 2019.10.16
'IT기술 관련/A.I 인공지능' 카테고리의 다른 글
  • 문과생도 이해하는 딥러닝 (9) - 신경망 초기 가중치 설정
  • 문과생도 이해하는 딥러닝 (8) - 신경망 학습 최적화
  • 문과생도 이해하는 딥러닝 (6) - 오차역전파법 실습 1
  • bagging , boosting, stacking (배깅,부스팅,스태깅)
호레
호레
창업 / IT / 육아 / 일상 / 여행
    반응형
  • 호레
    Unique Life
    호레
  • 전체
    오늘
    어제
    • 분류 전체보기
      • 법률
        • 기본
        • 개인정보보호법
        • 정보통신망법
        • 전자금융거래법
        • 전자금융감독규정
        • 신용정보법
        • 온라인투자연계금융업법
      • 창업
        • 외식업 관련
        • 임대업 관련
        • 유통업 관련
        • 세무 관련
        • 마케팅 관련
        • 기타 지식
        • 트렌드
        • Youtube
      • IT기술 관련
        • 모바일
        • 윈도우
        • 리눅스
        • MAC OS
        • 네트워크
        • 빅데이터 관련
        • A.I 인공지능
        • 파이썬_루비 등 언어
        • 쿠버네티스
        • 기타 기술
      • 퍼블릭 클라우드 관련
        • Azure
        • GCP
        • AWS
      • 정보보안 관련
        • QRadar
        • Splunk
        • System
        • Web
      • 기타
        • 세상 모든 정보
        • 서적
      • 게임 관련
        • 유니티
      • 부동산
      • 맛집 찾기
        • 강남역
        • 양재역
        • 판교역
        • ★★★★★
        • ★★★★
        • ★★★
        • ★★
        • ★
      • 결혼_육아 생활
        • 리얼후기
        • 일상
        • 육아
        • 사랑
        • Food
      • 영어
        • 스피킹
        • 문법
        • 팝송
        • 영화
      • K-컨텐츠
        • 드라마
        • 영화
        • 예능
      • 독서
      • 프로젝트 관련 조사
        • 시스템 구축
        • 로그 관련
        • 웹
        • APT
        • 모의 해킹
        • DB
        • 허니팟
        • 수리카타
        • 알고리즘
        • FDS
      • 기업별 구내 식당 평가
        • 한국관광공사
        • KT telecop
        • KT M&S
        • KT powertel
        • KT cs 연수원
        • 진에어
      • 대학 생활
        • 위드윈연구소
        • 진로 고민
        • 채용정보
        • 자동차
        • 주식
        • 악성코드
        • 게임 보안
      • 쉐어하우스
  • 블로그 메뉴

    • 홈
    • 게임 관련
    • IT 기술 관련
    • 태그
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    돈까스
    무역전쟁
    상호관세
    대통령
    쥬쥬랜드
    이재곧죽습니다
    보안가이드
    유니티
    AWS
    수제버거존맛
    점심
    판교맛집
    마케팅
    판교역
    복리후생
    맛집
    런치
    수제버거맛집
    수제버거
    판교
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
호레
문과생도 이해하는 딥러닝 (7) - 오차역전파법 실습 2
상단으로

티스토리툴바