Related to: Machine Learning

개요

PyTorch에서 추론(inference) 시 불필요한 gradient 계산을 방지하여 메모리를 절약하는 torch.no_grad() 사용법을 설명합니다.

핵심 개념

torch.no_grad()

  • backward를 사용하지 않을 때(inference 시점) 메모리 소비를 줄임
    • 모델 내 tensor 중 required_grad = True인 경우에도 required_grad=False처리하여 계산
  • 주로 모델 평가(validation/test) 시 사용

예시 / 코드

with torch.no_grad():
    output = model(input)

관련 개념

참조

https://pytorch.org/docs/stable/generated/torch.no_grad.html

Week 2