Week 4: Automatic Differentiation in Pytorch
Pytorch + NPU 온라인 모임 #4 | 2025-01-08
Agenda
- Background — Supervised Learning, Backprop
- Automatic Differentiation — Pytorch에서의 구현: Autograd (1.0), AOTAutograd (2.0)
Background: Supervised Learning
머신러닝에서 지도학습(Supervised Learning)은 주어진 데이터로부터 최적의 함수를 찾아가는 과정입니다. 가장 간단한 예로 라는 모델이 있을 때, 훈련 데이터 쌍이 주어지면, 기울기()를 조금 조정하고 bias()를 조금 이동시키면서 데이터를 가장 잘 표현하는 weight 값을 찾아갑니다. 이처럼 training data가 주어졌을 때 loss를 최소화하는 weight 값으로 조금씩 변해가는 과정이 바로 학습(training)입니다.
Forward & Backward
네트워크의 가중치(예: )는 랜덤하게 초기화되거나 학습 중간 체크포인트에서 시작됩니다. 현재 가중치를 기준으로, 훈련 데이터에 대한 성능을 평가하고 개선하는 과정이 학습의 핵심입니다.
- Forward Propagation (순전파): 입력 데이터를 네트워크에 통과시켜 각 층의 가중치를 사용한 연산을 거쳐 출력값을 계산합니다.
- Loss 계산: 기대 출력(ground truth)과 네트워크 예측값의 차이를 loss function을 통해 계산합니다.
- Backward Propagation (역전파): 손실을 줄이기 위해 loss에서 출발하여 입력층까지 역방향으로 gradient를 계산합니다. 중간층의 모든 가중치에 대한 보정값을 구하는 이 과정을 backpropagation이라 하며, 이것이 자동화된 것이 Automatic Differentiation입니다.
- Optimizer: 계산된 gradient를 기반으로 optimizer가 가중치를 조정합니다. 이때 모델 아키텍처 자체는 변하지 않고, 가중치만 업데이트됩니다.
이 과정을 반복하면 손실이 점차 줄어들며, 목표 성능에 수렴하면 학습이 종료됩니다. 이것이 **Supervised Learning(지도 학습)**의 일반적인 흐름입니다.
(Reverse-Mode) Automatic Differentiation
Reverse-mode automatic differentiation은 역전파(backpropagation)를 구현하는 알고리즘으로, 다음 세 가지로 구성됩니다:
- Computation graph의 구성
- Computation graph 각 노드에서의 gradient function 도출
- Gradient를 전파하기 위한 graph의 backward walking
이 과정의 수학적 기반은 Chain Rule입니다.
(i) 일 때:
(ii) 이고 일 때:
예제: Autodiff 계산
다음 함수를 예로 들어보겠습니다:
이 함수의 computation graph를 구성하면, 각 중간 연산 노드에서의 gradient를 chain rule을 통해 역방향으로 전파할 수 있습니다. 자세한 내용은 Baydin et al. (2018)을 참고하세요.
예제: Forward & Backward 계산
Pytorch 1.0: Autograd
PyTorch에서 reverse-mode automatic differentiation을 구현한 것이 바로 Autograd입니다. “Autograd”라는 이름은 하버드 대학에서 NumPy 기반의 automatic differentiation을 연구하던 팀이 먼저 사용했으며, 해당 팀의 핵심 멤버들은 이후 Google에 입사하여 JAX 팀에 합류했습니다. PyTorch도 같은 이름을 사용하면서 초기에 이름 충돌에 대한 논의가 있었습니다.
PyTorch의 Autograd는 텐서 연산의 미분 계산을 자동화하여, 개발자가 복잡한 계산 그래프를 수동으로 정의하지 않고도 backpropagation을 수행할 수 있게 합니다. 핵심 특징은 다음 세 가지입니다:
1. Eager Mode를 위한 Autograd
PyTorch는 Eager Mode를 기본으로 동작하며, 각 연산이 실행될 때마다 computation graph를 incremental하게 구축합니다 (반대 개념인 graph mode에서는 먼저 그래프를 모두 생성한 뒤 연산을 진행합니다).
이를 위해 PyTorch는 Dispatcher와 operator overloading을 활용합니다:
- Dispatcher는 연산 요청이 들어오면 해당 연산을 어떤 방식으로 처리할지 결정하고, 적합한 구현(Autograd dispatch key 등)을 호출합니다.
- Operator overloading을 통해 Python의 기본 연산자(예:
+,*,@)를 PyTorch 텐서 객체에서 사용할 때, Autograd와 함께 처리되도록 수정된 함수가 실행됩니다.
이를 통해 직관적이고 명령형 프로그래밍 스타일로 모델을 만들 수 있습니다.
2. Tensor별 Gradient 선택
각 텐서별로 gradient 계산 여부를 선택할 수 있습니다. 딥러닝 모델에서 모든 텐서에 대해 gradient가 필요한 것은 아닙니다. 예를 들어 입력 데이터(input)는 값이 변하지 않으므로 backprop이 필요 없지만, 학습 대상인 weight는 backprop이 필요합니다.
텐서에는 requires_grad라는 flag가 있어서, 이 flag가 켜져 있는 텐서들만 computation graph에 연산 기록을 남기게 됩니다. Backpropagation을 실행하기 전에 gradient가 필요한 텐서를 transitive하게 선별하여 연산 비용을 줄이고 불필요한 리소스 낭비를 방지합니다. 이 기능은 학습 중 특정 텐서를 freeze하거나 모델의 일부만 미세 조정(fine-tuning)할 때 유용합니다.
3. Custom Differentiable Function
PyTorch는 기본 제공되는 연산자(log, sin, polynomial 등) 외에도, 개발자가 직접 custom differentiable function을 정의할 수 있는 유연성을 제공합니다. 잘 알려진 수학 함수의 미분 형태는 PyTorch 내부에 이미 구현되어 있지만, 사용자가 custom op을 만든 경우 Autograd는 그 미분 방법을 알 수 없습니다. 이때 custom op을 추가한 개발자가 미분 규칙을 함께 기술하여 등록할 수 있도록 장치가 마련되어 있습니다.
Autograd 동작 예제
단순한 수학적 연산을 통해 backpropagation 과정을 단계별로 살펴보겠습니다. 머신러닝 모델이라고 보기는 어렵지만, Autograd의 동작 원리를 이해하기에 충분한 예제입니다.
- 변수: 는 상수이지만 gradient가 필요하도록 마킹, ,
- 목표: 에 대한 의 gradient 를 구하기
x = torch.tensor(2.0)
x.requires_grad = True
y = x * 2
z = y ** 2
z.backward()
print(x) # tensor(2., requires_grad=True)
print(y) # tensor(4., grad_fn=<MulBackward0>)
print(z) # tensor(16., grad_fn=<PowBackward0>)
print(x.grad) # tensor(16.)
각 값을 확인하면: 이고 requires_grad flag가 True, , 입니다. 실제로 뒤에서부터 미분을 수행하면 이 됩니다. 이 원리를 기반으로 신경망에서는 수많은 연산의 미분을 체계적으로 계산하여 학습이 진행됩니다.
Step 1: Tensor 생성
x = torch.tensor(2.0)
x.requires_grad = True
로 초기화하고 requires_grad = True로 gradient 필요 여부를 마킹합니다. 이 시점의 computation graph에는 x:2 노드만 존재합니다.
Step 2: Forward 연산 (y = x * 2)
y = x * 2
먼저 forward 계산으로 가 저장됩니다. 여기서 끝나는 것이 아니라, 가 미분이 필요하다고 마킹되어 있으므로 그 output인 도 gradient가 필요하게 됩니다. 이때 MulBackward0라는 grad function이 에 연결되어, multiplication에 의한 backward 과정이 필요하다는 것이 등록됩니다. Computation graph 상태는 다음과 같습니다:
[참고] Forward에서 수행되는 AutoGrad kernel 예제: add (old style)
이 코드는 사람이 직접 작성한 것이 아니라, Autograd 라이브러리가 자동으로 생성(codegen)한 커널입니다. Eager mode에서의 dispatch 개념을 떠올리면, PyTorch에는 dispatch key가 있어서 텐서의 상태에 따라 어떤 함수를 dispatch할지 결정합니다. Forward용 커널과 backward용 커널이 각각 존재하고, 하드웨어(CPU, GPU 등)에 따라 다른 함수를 dispatch할 수 있도록 설계되어 있습니다.
requires_grad가 켜진 텐서가 관여하는 연산에서는 Autograd dispatch key를 사용하여 dispatch가 이루어집니다. 이 커널은 backward에서 호출되는 함수가 아니라, forward에서 add를 처리하는 함수입니다. Forward를 수행하면서 나중에 backward가 될 때 어떻게 처리할지를 미리 기록해 두는 것입니다:
compute_requires_grad로 현재 계산이 나중에 gradient를 구해야 하는 계산인지 판별합니다.- 필요하다면 해당 연산의
grad_fn(예:AddBackward0)을 찾아 등록합니다.add라는 함수가 정해져 있으므로 PyTorch는 대응하는grad_fn이 무엇인지 알고 있습니다. next_edges를 통해 backward graph를 구성합니다. 마치 computation graph를 만드는 것처럼 역전파 경로를 기록해 둡니다.- 이 모든 기록이 끝난 후에 실제 forward 계산을 수행합니다.
또한 특정 텐서에 requires_grad=True가 설정되어 있으면, 그 텐서로부터 영향을 받는 모든 output 텐서들도 자동으로 requires_grad=True가 됩니다. PyTorch가 forward 과정에서 이를 추적하며 자동으로 설정합니다.
[참고] Grad Functions
Grad function들(예: AddBackward0)도 결국 PyTorch Op입니다. 이 미분 함수들은 derivatives.yaml에 정의된 Op들의 미분 규칙으로부터 자동으로 코드 생성(codegen)됩니다. PyTorch에는 Op을 정의하는 YAML 파일이 여러 개 있으며, 그 중 derivatives.yaml에 기본 Op들에 대한 미분 규칙이 정리되어 있고, PyTorch가 이를 읽어서 backward 함수들을 자동으로 생성합니다.
Step 3: Forward 연산 (z = y ** 2)
z = y ** 2
Step 2와 동일한 과정이 반복됩니다. 이 계산되면서, Autograd는 와 도 gradient가 필요하다는 것을 인지합니다. 앞서 설명한 compute_requires_grad의 if문 안으로 들어가 PowBackward grad function을 bookkeeping해 둡니다. PowBackward에 들어가는 입력은 이며, 미분식 를 반환합니다. 이 스텝까지 완료되면 forward 계산이 모두 끝나고 backward graph 준비가 완료된 상태입니다:
Step 4: Backward
z.backward()
Step 4.1 - Backward 시작: Forward에서 기록해 둔 과정을 거꾸로 복귀합니다. 출력 자신에 대한 편미분은 항상 1이므로, 상수 부터 시작합니다. 이 값이 PowBackward로 전달되면, 의 미분 함수인 가 적용됩니다. 이때 는 위에서 들어오는 입력과는 상관없이 값에만 의존하며, 이므로 이 됩니다.
Step 4.2 - Gradient 전파: 다음 단계에서 에 대한 의 편미분은 입니다. 위에서 온 민감도(gradient)가 계속 누적되면서 곱해지는 과정이므로, 이전 단계에서 내려온 민감도 에 편미분 함수인 를 곱하면 이 됩니다. 이 이 최종적으로 x.grad에 저장됩니다.
따라서 입니다. 실제 신경망에서는 이보다 훨씬 복잡하지만, 근본적으로 이 과정이 반복적으로 일어난다고 생각하면 됩니다.
전체 backward 흐름을 정리하면 다음과 같습니다:
검증: 이므로, ✓
[참고] backward(): Python → C++
Python에서 backward()를 호출하면 C++ 레벨의 run_backward()로 연결되고, 이것이 Autograd Engine의 Engine::execute()를 호출합니다. execute()는 내부적으로 graph_task를 생성하고, 그래프에 등록된 함수들을 확인하며 역전파 계산에 필요한 함수들을 준비(init_to_execute())합니다. 이후 evaluate_function()에서 실제 backward 연산이 수행되고 최종적으로 gradient가 계산됩니다.
backward()
└→ run_backward()
└→ Engine::execute()
└→ graph_task->init_to_execute()
execute_with_graph_task()
└→ Engine::evaluate_function()
└→ call_function() ← 실제 grad function이 불리는 곳
Python 레벨에서는 wrapper만 존재하고, 실제 구현은 대부분 C++로 작성되어 있습니다. GPU 환경에서는 CUDA 코드가 실행되지만, 전체 실행 흐름을 관리하는 큰 틀은 C++ 코드입니다. 각 단계에서 grad_fn을 찾아 실행하며, dispatch key를 통해 CPU 커널 또는 CUDA 커널을 선택합니다.
Eager mode에서는 매번 backward() 호출 시 이 준비 과정이 반복 수행되지만, gradient 함수를 찾아 등록하는 과정에 불과하므로 큰 오버헤드를 유발하지는 않습니다. 반면 graph mode에서는 gradient 계산을 포함한 전체 그래프를 미리 생성하여 최적화된 방식으로 실행합니다.
PyTorch 1.0에서 eager mode 기반 Autograd가 핵심 기능으로 도입된 이후, 기본적인 큰 틀은 현재 최신 PyTorch에서도 변하지 않고 유지되고 있습니다. 내부 구현이 조금씩 업데이트되었을 가능성은 있지만, 전반적인 메커니즘은 크게 달라지지 않았습니다.
Pytorch 2.0: AOTAutograd
지금까지는 eager mode에서 automatic differentiation이 Autograd를 통해 어떻게 구현되는지 살펴보았습니다. 그렇다면 PyTorch 2.0의 graph mode에서는 automatic differentiation을 어떻게 구현할 수 있을까요?
Eager mode와 graph mode의 가장 큰 차이는 forward 과정에 있습니다. Graph mode에서는 fake tensor를 이용하여 forward를 빠르게 에뮬레이션(trace)하면서 계산이 어떻게 이루어질지 전체적인 스케치를 먼저 합니다. 이 스케치가 FX Graph로 캡처되고, FX Graph를 backend compiler(예: Inductor)에 전달하면 컴파일 후 실제 계산이 이루어집니다.
이러한 graph mode에서 backward graph를 생성하는 방법으로 두 가지를 생각해 볼 수 있습니다:
-
Dynamo trace 중에 Autograd를 실행하여 backward graph 생성: Fake tensor를 이용한 tracing 과정에서 기존 eager mode의 Autograd를 함께 돌리는 방법입니다. 그러나 Autograd는 실제 tensor 값을 기반으로 동작하도록 설계되었기 때문에, fake tensor를 사용하는 tracing 과정에서 그대로 동작할지 확신하기 어렵습니다. 또한 Dynamo trace는 전체 계산을 하나의 trace로 만드는 것이 아니라, eager mode와 graph mode가 번갈아가면서 실행되므로 backward graph 생성이 복잡해질 수 있습니다.
-
FX Graph를 먼저 만든 후 backward graph를 생성: FX Graph로 변환되고 나면 복잡한 부분들은 정리된 상태이므로, 이 시점에서 backward graph를 만드는 것이 더 깔끔합니다. FX Graph로 묶인 부분을 eager mode 관점에서 **하나의 거대한 합성 연산(composite Op)**으로 보면, 그 Op에 대한 미분 함수만 알고 있으면 eager mode에서 자연스럽게 backward가 수행될 수 있습니다.
이 두 번째 아이디어를 바탕으로 **AOTAutograd(Ahead-of-Time Autograd)**가 개발되었습니다. 이제부터 AOTAutograd가 실제로 어떻게 동작하는지 하나하나 살펴보겠습니다.
What Happens When Training with torch.compile()?
Eager Mode
각 연산이 즉시 실행됩니다. requires_grad가 켜져 있으므로 forward 중 Autograd가 backward graph를 자동 생성합니다.
outputs = model(images)
loss = F.cross_entropy(outputs, labels)
loss.backward()- Forward 중 Autograd가 backward graph를 자동 생성
loss.backward()호출 시 바로 gradient 계산
Graph Mode
torch.compile()로 Dynamo tracing 준비가 된 모델을 생성합니다. 실행 시 연산을 FX Graph로 캡처합니다.
compiled_model = torch.compile(model)
outputs = compiled_model(images)
loss = F.cross_entropy(outputs, labels)
loss.backward()- AOTAutograd가 compiled forward/backward 생성
- Backend compiler가 최적화된 실행 수행
겉에서 보면 코드 차이가 거의 없지만, 내부에서 일어나는 과정은 상당히 복잡합니다. torch.compile()이 알아서 최적화된 방식으로 forward와 backward를 처리해 주는데, 그 과정이 어떻게 이루어지는지 살펴보겠습니다. 아래 슬라이드에서 왼쪽은 사용자가 작성하는 코드이고, 오른쪽은 AOTAutograd에 의해 생성된 compiled forward/backward function들(torch.autograd.Function으로 감싸짐)입니다.
torch.compile()을 적용한 학습의 동작 과정은 세 단계로 나눌 수 있습니다:
Step 1 — compiled_model = torch.compile(model) → outputs = compiled_model(images)
torch.compile(model)을 호출하면 Dynamo가 모델을 tracing하여 FX Graph를 생성합니다. 이 FX Graph가 AOTAutograd에 의해 torch.autograd.Function 객체로 변환됩니다. 이것이 바로 앞서 PyTorch 1.0 Autograd에서 언급했던 custom differentiable function 기능을 활용하는 것입니다. 이 객체는 마치 하나의 거대한 미분 가능한 custom function처럼 동작하며, compiled forward와 compiled backward를 모두 가지고 있습니다 (슬라이드 오른쪽의 def forward와 def backward 참조).
compiled_model(images)를 호출하면 이 객체의 forward 부분만 실행되어 실제 forward 계산이 이루어집니다. 동시에 output의 grad_fn으로 compiled backward가 등록됩니다. Backward 부분은 이 시점에서는 호출되지 않고, 나중에 backward()가 호출될 때 한 스텝으로 실행됩니다.
Step 2 — loss = F.cross_entropy(outputs, labels)
Output에는 실제 계산된 값이 들어 있으므로, 이를 가지고 cross entropy 등의 loss를 계산할 수 있습니다. 이 과정은 requires_grad가 켜져 있는 eager mode로 실행됩니다. Output에 등록되어 있던 grad_fn(compiled backward)이 loss function의 grad_fn으로 자연스럽게 연결됩니다.
Step 3 — loss.backward()
loss.backward()를 호출하면, loss function의 backward 부분이 먼저 실행되고, 이어서 그 입력이었던 output의 grad_fn, 즉 Step 1에서 등록된 compiled backward가 하나의 큰 Op으로서 호출됩니다. 이 두 backward가 순차적으로 역방향(concatenation)으로 실행되면서 model weight의 gradient가 계산됩니다.
AOTAutograd의 역할
AOTAutograd는 PyTorch 2.0에서 새로 추가된 기능으로, training mode에서 Autograd의 forward 및 backward 과정을 처리하는 핵심 엔진입니다. 앞 페이지에서 설명한 과정을 좀 더 쪼개서 살펴보겠습니다.
핵심은 training mode에서 FX Graph가 backend compiler로 바로 전달되지 않는다는 점입니다. AOTAutograd가 먼저 개입하여 다음 과정을 수행합니다:
- nn.Module → Dynamo가 tracing하여 FX Graph를 생성합니다. (이 부분은 graph mode에서 다뤘던 내용과 동일합니다.)
- FX Graph → AOTAutograd가 forward와 backward를 분리하고 최적화합니다.
- Forward와 backward를 각각 따로 backend compiler로 컴파일합니다.
- 컴파일된 forward와 backward를 모두 가지고 있는
torch.autograd.Function객체를 생성하여 반환합니다.
이렇게 반환된 torch.autograd.Function을 호출하면, 이미 컴파일된 형태의 forward와 backward가 실행되므로 가속된 training이 가능해집니다.
Graph Mode Backward의 도전 과제
Graph mode에서 backward를 처리하는 데에는 두 가지 큰 도전 과제가 있으며, AOTAutograd는 각각에 대한 해결책을 제시합니다:
| Graph mode backward의 challenges | AOT Autograd의 해결책 |
|---|---|
문제 1. Autograd engine은 **C++**로 구현되어 있음
| → 해결 1. Backward를 C++ dispatcher 수준에서 tracing
|
문제 2. Dynamo가 생성한 FX Graph의 op들이 꽤 복잡
| → 해결 2. FX Graph를 먼저 normalization한 이후에 tracing
|
AOTAutograd Architecture
위 그림은 AOTAutograd의 전체 흐름을 보여줍니다. requires_grad가 켜져 있는 input tensor들이 들어오면:
- Dynamo에 의해 FX Graph로 변환됩니다.
- FX Graph에 Functionalization, Decomposition 등의 normalization을 적용하여 더 간단한 그래프를 만듭니다.
- 정규화된 그래프를 C++ 수준에서 tracing하면, forward 그래프와 backward 그래프가 동시에 생성됩니다.
- Forward와 backward를 각각 backend compiler로 컴파일합니다.
- 컴파일된 forward/backward를
torch.autograd.Function으로 매핑하여 반환합니다.
이 과정을 통해 AOTAutograd는 들어오는 FX Graph를 forward와 backward를 모두 가진 최적화된 torch.autograd.Function 객체로 변환합니다. 이 과정을 좀 더 자세히 breakdown해서 살펴보겠습니다.
AOTAutograd 수행 순서
AOTAutograd의 전체 수행 순서는 다음과 같습니다:
| 단계 | 과정 |
|---|---|
| A | Dynamo가 생성한 FX Graph를 input으로 받음 |
| B | Forward와 backward를 합친 joint 함수를 생성 |
| C | Joint 함수를 normalize (Functionalization 적용 등) |
| D | Joint 함수를 수행하며 실행된 op을 C++ dispatcher 수준에서 trace |
| E | Decomposition 적용 |
| F | Joint trace로부터 joint FX Graph 생성 |
| G | Joint FX Graph를 forward/backward로 분리 (Min-Cut 알고리즘) |
| H | Forward/backward를 각각 backend compiler로 컴파일 |
| I | Compiled function들을 묶어 최종 torch.autograd.Function 생성 |
왜 joint 함수를 만드는가? Forward와 backward를 따로 최적화하면 중간 텐서의 관계를 파악하기 어렵습니다. 학습 시 backward는 forward에서 계산된 중간 텐서 값이 필요한데, 이 텐서들이 매우 클 수 있어 메모리를 많이 차지합니다. 모든 중간 텐서를 저장하는 대신, 필요할 때 다시 계산하는 recomputation 전략을 적용할 수 있는데, 이를 효율적으로 수행하려면 forward에서 만들어진 텐서가 backward의 어떤 입력으로 연결되는지를 알아야 합니다. Joint 함수를 만들면 이 input-output 관계가 모두 살아있는 하나의 계산 그래프가 만들어지므로, recomputation을 포함한 최적화를 훨씬 쉽게 적용할 수 있습니다.
Joint 함수 생성 과정의 구체적 동작: Eager mode에서는 forward를 한 단계씩 실행하면서 incremental하게 grad_fn을 텐서에 등록합니다. AOTAutograd에서는 이 과정을 먼저 한 번 실행하여 모든 grad_fn을 등록해 두고, 이어서 Autograd engine이 등록된 grad_fn들을 수행하는 부분까지 합쳐서 하나의 함수로 만듭니다. 이 함수를 C++ dispatcher 수준에서 trace하면, forward에서 계산된 텐서들과 backward에서 이를 사용하는 부분이 모두 연결된 하나의 trace가 한 번에 만들어집니다.
Normalization이란? Eager mode에서 암묵적(implicit)으로 처리되던 복잡한 연산들을 명시적(explicit)으로 드러내고, 이를 단순한 작은 Op들로 분해하는 과정입니다. Functionalization과 Decomposition이 여기에 해당합니다.
Backend compiler(H 단계): H 단계에서의 컴파일은 Dynamo의 tracing이 아니라, 실제 backend compiler(예: Inductor)가 최적화된 코드를 생성하는 과정입니다. Backend-aware 최적화는 forward/backward가 분리된 이 단계에서만 적용됩니다.
AOTAutograd 코드 레벨 동작
위의 단계들을 코드 레벨에서 살펴보면 다음과 같습니다. 가장 top-level 함수는 aot_dispatch_autograd()이며, 크게 세 부분으로 나뉩니다: (1) joint FX Graph 생성, (2) partitioning 및 컴파일, (3) torch.autograd.Function으로 wrap하여 반환.
# A: 최상위 함수
aot_dispatch_autograd(flat_fn)
# Joint FX Graph 생성
fx_g = aot_dispatch_autograd_graph(flat_fn, ...)
# B: Joint 함수 생성 — forward와 backward를 붙여서
# forward만 실행해도 둘 다 계산되는 함수를 만듦
joint_fn_to_trace = create_joint(flat_fn, ...)
def inner_fn(...)
outs = fn(*primals) # Forward 실행
backward_out = torch.autograd.grad(...) # Backward 실행
return outs, backward_out
# C: Normalization 수행 — eager mode의 복잡성을 제거
joint_fn_to_trace = create_functionalize_fn(joint_fn_to_trace, ...)
joint_fn_to_trace = aot_dispatch_subclass(joint_fn_trace, ...)
# D, E, F: 실제 tracing 수행 + Decomposition + FxGraph 생성
# D: joint_fn_to_trace에 대한 실제 tracing이 일어남
# E: Decomposition (Torch IR → ATen IR → Core ATen IR → Prims IR)
# F: 최종 Joint FX Graph 생성
fx_g = _create_graph(joint_fn_to_trace, ...)
# G: Partitioning — Joint FX Graph를 forward/backward로 분리
fw_module, bw_module = partition_fn(fx_g, ...)
# H: Compile 수행 — 각각 backend compiler로 컴파일
compiled_fw_func = fw_compiler(fw_module, ...)
compiled_bw_func = bw_compiler(bw_module, ...)
# I: 최종 결과물 — torch.autograd.Function으로 wrap하여 반환
compiled_fn = AOTDispatchAutograd.post_compile(compiled_fw_func, ...)
이 중 가장 복잡한 부분은 joint FX Graph를 만들어내는 과정(B~F)입니다:
- B 단계: Forward 부분과 backward 부분을 하나의 함수로 결합합니다.
- C 단계: Normalization을 적용하여 eager mode의 복잡성을 제거합니다.
- D~F 단계:
_create_graph()를 통해 실제 tracing이 이루어지고, decomposition을 거쳐 최종 joint FX Graph가 생성됩니다.
여기서 decomposition은 eager mode의 복잡성 제거(normalization)와는 별개로, 수많은 Torch Op을 정해진 소수의 낮은 수준 Op들로 분해하는 과정입니다 (Torch IR → ATen IR → Core ATen IR → Prims IR).
AOTAutograd 실행 예제
다음 코드를 통해 AOTAutograd가 실제로 어떻게 동작하는지 살펴보겠습니다:
@torch.compile()
def func(a, b):
return torch.clamp_min(a, b) * 3
p = torch.tensor([0.4, -0.2], requires_grad=True, device='cuda')
loss = func(p, 0).sum()
loss.backward()
print(p.grad)
Joint Graph 생성

Dynamo가 만들어낸 FxGraph는 joint graph의 forward 부분에 해당합니다.

f()를 수행한 후 grad()를 호출하면 forward와 backward를 모두 수행하는 함수가 됩니다.

이것을 trace하면 joint graph를 생성할 수 있습니다.
Joint Function의 Input은 다음과 같이 구성됩니다:
- Forward Input: 원래 모델의 입력값
- Backward Input: Loss로부터 전달되는 gradient (grad_outs)
- Forward에서 계산된 중간 tensor는 Joint Graph에 이미 포함됩니다.
Joint Graph Example
Forward와 backward가 하나의 그래프에 포함되어 있으며, Min-Cut 알고리즘으로 partition할 수 있습니다.
Decomposition (+ Functionalization)
Eager mode의 복잡한 연산을 명시적으로 표현하고, 단순한 Op들로 분해하는 과정입니다. IR 계층은 다음과 같이 세분화됩니다:
Torch IR → ATen IR → Core ATen IR → Prims IR
Min-Cut Algorithm으로 Forward/Backward 분리


Joint FxGraph를 Min-Cut Algorithm으로 forward와 backward로 분리한 후, 각각을 별도로 컴파일합니다.
Activation Checkpointing

Batch size가 커질수록 중간 activation도 함께 커집니다.

Segment별 input activation만 저장하고, backward 수행시 segment별로 forward를 다시 수행합니다.
Activation checkpointing은 recomputation 전략으로, 중간 activation을 저장하지 않고 필요할 때 다시 계산하여 메모리를 절약합니다. 메모리 사용량 감소와 추가 연산 비용 사이의 trade-off가 존재하며, joint graph를 통해 forward와 backward 간의 텐서 관계를 파악하여 효율적으로 적용할 수 있습니다.
Compiled Model = torch.autograd.Function
Compiled model은 forward와 backward를 모두 포함하는 custom backend 함수, 즉 torch.autograd.Function입니다.
PyTorch는 다음 두 가지 개념을 torch.autograd.Function으로 통합합니다:
- PyTorch operations을 포함하지 않는 코드(C++, CUDA, numpy)가 function transforms와 함께 동작하는 것
- Custom gradient rules (JAX의
custom_vjp/custom_jvp와 유사)
결과적으로
전체 과정을 정리하면:
- Compile된
forward()가 호출됩니다. - 추가로 compile된
backward()가 output tensor의 grad function으로 등록됩니다. loss.backward()수행 과정에서 compile된backward()를 실행합니다.
등록된 backend가 forward/backward를 각각 컴파일하여, 최종적으로 가속된 학습이 이루어집니다.