Related to: Machine Learning
개요
PyTorch에서 추론(inference) 시 불필요한 gradient 계산을 방지하여 메모리를 절약하는 torch.no_grad() 사용법을 설명합니다.
핵심 개념
torch.no_grad()
- backward를 사용하지 않을 때(inference 시점) 메모리 소비를 줄임
- 모델 내 tensor 중
required_grad = True인 경우에도required_grad=False처리하여 계산
- 모델 내 tensor 중
- 주로 모델 평가(validation/test) 시 사용
예시 / 코드
with torch.no_grad():
output = model(input)관련 개념
- 자동 미분(Autograd) - no_grad로 비활성화하는 자동 미분 메커니즘
- torch.tensor의 requires_grad param의 기능 - requires_grad 파라미터 설명
- PyTorch 딥러닝 학습의 기본 순서 - 학습 및 추론 흐름
- Pytorch Performance Tuning Practices - 성능 최적화 기법
참조
https://pytorch.org/docs/stable/generated/torch.no_grad.html