Related to: Machine Learning

개요

PyTorch를 사용한 딥러닝 모델 학습의 표준적인 순서를 코드와 함께 정리합니다.

핵심 개념

PyTorch 딥러닝 학습의 기본 순서

  1. Gradient 초기화: optimizer.zero_grad()
  2. 예측값 계산: output = model(inputs)
  3. Loss 계산: loss = loss_function(output, ground_truth)
  4. Gradient 계산(역전파): loss.backward()
  5. 파라미터 갱신: optimizer.step()

예시 / 코드

# Optimize 대상인 각 parameter들의 gradient 값 초기화
optimizer.zero_grad()
 
# 예측값 계산
output = model(inputs)
 
# Ground_truth와 예측 값 사이의 loss 계산
loss = loss_function(output, ground_truth)
 
# loss 값으로 각 parameter의 gradient 값 계산
loss.backward()
 
# gradient 값으로 각 parameter 갱신
optimizer.step()

관련 개념

참조

Week 2