Week 5: Distributed Programming in Pytorch

Pytorch + NPU 온라인 모임 #5 | 2025-01-15

소개

이번 강의에서는 PyTorch의 분산 프로그래밍(Distributed Programming)에 대해 다룹니다. Week 1에서 PyTorch 2.0의 핵심 특징 중 하나로 MPI-like distributed programming model을 언급한 바 있는데, 이번 주에 그 내용을 본격적으로 살펴봅니다. 대규모 모델 학습이 일반화되면서 분산 프로그래밍은 PyTorch를 사용하는 거의 모든 실무 환경에서 필수적인 기술이 되었습니다.

강의는 크게 네 부분으로 구성됩니다:

  1. Overview — MPI와 OpenMP의 차이, PyTorch 분산 프로그래밍의 기본 개념과 아키텍처
  2. torchrun을 활용한 distribute matmul 예제 — 실제 분산 행렬 곱셈을 수행하는 전체 과정
  3. Device와 연결: CUDA 예제 — CUDA Device 선택, 동기화, Process Group 관리의 내부 동작
  4. PyTorch가 제공하는 모델 병렬화 패키지들 — DDP, FSDP, TP 등 고수준 병렬화 도구

MPI vs OpenMP

분산 프로그래밍을 이해하려면 먼저 두 가지 대표적인 병렬 프로그래밍 모델인 MPI와 OpenMP의 차이를 알아야 합니다. 이 둘은 메모리 모델부터 근본적으로 다릅니다.

MPI

Distributed Memory

OpenMP

Shared Memory

OpenMP는 Shared Memory 모델 기반으로, 모든 프로세서가 동일한 메모리에 접근할 수 있습니다. 반면 MPI는 Distributed Memory 모델 기반으로, 각 프로세서가 자신만의 메모리를 가지고 있으며 프로세서 간에는 메시지 전달을 통해 통신합니다. PyTorch는 MPI 스타일을 따릅니다. GPU마다 독립적인 메모리를 가지고 있기 때문에, 분산 메모리 모델이 GPU 기반 학습 환경에 자연스럽게 맞아떨어지는 것입니다.

Task 구성 방식의 차이

두 모델은 Task를 구성하는 방식에서도 큰 차이가 있습니다.

MPI

독립적인 프로세스들의 묶음

OpenMP

Master-Worker (Fork-Join) 모델

MPI에서는 각 프로세스가 독립적으로 시작하며, 메시지 패싱으로 통신하고 Barrier를 통해 동기화합니다. OpenMP에서는 Master 태스크가 시작한 뒤 Fork(병렬 태스크 생성) → Join(종료)을 반복하는 구조입니다. PyTorch의 분산 학습에서 각 GPU 프로세스가 독립적으로 동일한 학습 스크립트를 실행하고, 필요한 시점에 gradient를 동기화하는 패턴은 바로 MPI의 이러한 구조에서 비롯된 것입니다.

MPI 프로그램 구조

MPI 프로그램은 여러 프로세스에서 동일한 프로그램을 실행합니다. 각 프로세스는 **초기화(initialization)**를 수행한 후 병렬로 작업하며, rank(프로세스 번호)에 따라 다르게 동작하도록 코드를 작성합니다. 예를 들어 rank 0은 데이터를 로드하고, 나머지 rank는 데이터를 전달받는 식입니다. 모든 작업이 완료되면 finalize 단계를 거쳐 종료합니다. PyTorch도 이러한 구조와 유사한 패턴을 따르며, 뒤에서 살펴볼 torchrun 예제에서 이를 직접 확인할 수 있습니다.

PyTorch Distributed Programming: 기본 개념

PyTorch 분산 프로그래밍에서 사용하는 핵심 용어들을 먼저 정리합니다.

Node는 물리적인 인스턴스로, 하나 또는 여러 개의 GPU를 가지고 있습니다. Process는 각 node에서 생성되며, 일반적으로 1 process : 1 GPU 관계를 맺습니다. World는 해당 job에 참여하는 모든 process의 집합을 말합니다.

Rank는 프로세스의 고유 식별자인데, 두 가지 종류가 있습니다. Global rank는 World 전체에서의 순서이고, local rank는 하나의 Node 내에서의 순서입니다. 아래 다이어그램에서 보듯이, 2개의 Node에 각각 3개의 프로세스가 있다면 global rank는 05, local rank는 각 Node 내에서 02가 됩니다.

Rendezvous는 학습 job에 참여하는 node들을 한데 모으는 과정으로, 동적으로 관리할 수 있습니다. 노드가 추가되거나 빠지는 상황에서도 유연하게 대응할 수 있는 메커니즘입니다.

World

Node 2

Rank 3
(local rank 0)

Rank 4
(local rank 1)

Rank 5
(local rank 2)

Node 1

Rank 0
(local rank 0)

Rank 1
(local rank 1)

Rank 2
(local rank 2)

Rendezvous

Overall Architecture

PyTorch의 분산 프로그래밍 아키텍처는 세 개의 레이어로 구성됩니다.

torch.distributed (process)

Backends

c10d

collective

p2p

High-level packages

DDP

FSDP

TP

SP

PP

ProcessGroupGloo

ProcessGroupNCCL
(optional)

ProcessGroupMPI
(optional)

Custom
(optional)

CPU

GPU

MPI lib

최상위에는 DDP, FSDP, TP, SP, PP 등의 고수준 패키지가 있습니다. 이 패키지들은 사용자가 복잡한 분산 통신 로직을 직접 작성하지 않아도 모델 병렬화를 쉽게 적용할 수 있도록 해줍니다. 중간에는 collective/p2p 통신을 담당하는 c10d 레이어가 있는데, 이것이 PyTorch 분산 통신의 핵심 API입니다. 그리고 하단에는 Gloo, NCCL, MPI 등의 통신 backend가 위치하여 실제 하드웨어 수준의 통신을 수행합니다.

Process Group

Process Group은 서로 통신하는 프로세스들의 집합입니다. 기본(default) process group으로 모든 프로세스를 하나로 묶을 수도 있고, 필요에 따라 여러 개의 subgroup으로 나눌 수도 있습니다. 예를 들어, 서로 다른 종류의 병렬화를 동시에 적용하려면 프로세스들을 여러 subgroup으로 나누어 각각 독립적으로 통신하도록 구성합니다.

Default process_group

process_group (default)

Rank 0

Rank 1

Rank 2

Rank 4

Rank 5

Rank 6

process_group with 2 subgroups

process_group

subgroup 2

Rank 4

Rank 5

Rank 6

subgroup 1

Rank 0

Rank 1

Rank 2

Distributed Communication Layer (c10d)

c10d는 PyTorch에서 제공하는 distributed communication API로, 분산 프로그래밍의 통신 기반을 담당합니다. 크게 두 가지 통신 방식을 제공합니다.

Collective communication은 group 내에 있는 모든 process들이 협력하여 data를 공유하고 처리하는 방식입니다. DDP의 gradient 동기화, FSDP의 parameter sharding 등 대부분의 PyTorch 분산 패키지가 이 방식을 사용합니다. P2P communication은 하나의 process에서 다른 process로 data를 직접 전송하는 방식으로, 개별 프로세스 간 직접적인 데이터 교환이 필요할 때 사용됩니다. Pipeline Parallel 같은 경우가 대표적입니다.

Collective Communication 연산들

Collective communication에서 가장 많이 사용되는 연산들을 살펴봅니다.

Scatter / Gather — Scatter는 하나의 프로세스가 가진 데이터를 여러 프로세스에 분산하는 연산이고, Gather는 반대로 여러 프로세스의 데이터를 한 곳으로 모으는 연산입니다. 데이터를 쪼개서 나눠주고 다시 모으는 가장 기본적인 패턴입니다.

Reduce / All-Reduce — Reduce는 여러 프로세스의 데이터를 하나의 연산(예: 합산)으로 합쳐 하나의 프로세스에 저장하는 것이고, All-Reduce는 같은 연산을 수행하되 결과를 모든 프로세스에 공유합니다. 분산 학습에서 gradient를 합산할 때 All-Reduce가 핵심적으로 사용됩니다.

Broadcast / All-Gather — Broadcast는 한 프로세스의 데이터를 모든 프로세스에 동일하게 전달하는 연산입니다. 예를 들어 rank 0이 가진 모델 가중치를 모든 프로세스에 복사할 때 사용됩니다. All-Gather는 각 프로세스가 가진 서로 다른 데이터를 모든 프로세스에 수집하는 연산으로, FSDP에서 sharded parameter를 복원할 때 활용됩니다.

Communication Backends

c10d는 interface만 제공하고, 실제 통신 동작은 개별 하드웨어에 대한 communication backend에서 구현됩니다. PyTorch가 기본으로 제공하는 backend는 다음과 같습니다:

Backend지원 Device비고
GlooCPU / GPUCPU 분산 학습에 주로 사용
NCCLGPUNVIDIA GPU 전용, GPU 간 통신 최적화
MPICPU일반적인 클러스터 분산 환경

이러한 구조 덕분에 third-party backend를 추가하는 것도 가능합니다. c10d/Backend.hpp를 구현하면 되는데, broadcast, allreduce, reduce, allgather21개 virtual function으로 구성되어 있습니다. 리벨리온 NPU와 같은 새로운 하드웨어도 이 인터페이스를 구현함으로써 PyTorch의 분산 프로그래밍 생태계에 통합될 수 있습니다.

torchrun을 활용한 분산 행렬 곱셈 예제

지금까지 개념을 살펴봤으니, 이제 실제로 분산 행렬 곱셈을 수행하는 예제를 통해 전체 과정을 따라가 봅니다.

torchrun이란

torchrun은 PyTorch 분산 프로그래밍에서 각 node가 수행할 일을 template화한 utility script입니다. 개발자는 단일 프로세스에서 실행되는 것처럼 PyTorch 스크립트를 작성하고, torchrun이 이를 여러 프로세스에서 자동으로 실행해줍니다. 각 node(physical instance)에서 argument를 설정하여 torchrun을 실행하면 됩니다.

torchrun --nnodes=2 --nproc_per_node=8 \
    --rdzv_id=job1 --rdzv_backend=c10d \
    --rdzv_endpoint=node1:29500 \
    dist_matmul_allreduce.py

주요 파라미터는 다음과 같습니다:

파라미터설명
nnodes참여하는 node의 개수
nproc_per_nodenode당 process의 개수
node_ranknode ID
--rdzv-backendrendezvous backend
--rdzv-endpointrendezvous host:port
test.py각 process가 실행할 pytorch 코드

클러스터 준비

torchrun을 시작하기 전에 클러스터 설정이 필요합니다. 가장 단순한 방법은 SSH 기반 설정으로, /etc/hosts에 클러스터의 모든 노드의 IP와 hostname을 등록하고, 모든 노드들 사이에 암호 없이 로그인이 가능하도록 세팅하는 것입니다. 이 외에도 Ray cluster, Kubernetes, Slurm workload manager, Horovod 등 다양한 클러스터 관리 도구를 사용할 수 있습니다.

이 강의에서 사용할 예제 클러스터는 GPU가 8개씩 달려 있는 두 개의 노드입니다:

  • 192.168.0.2 (node_rank 0)
  • 192.168.0.3 (node_rank 1)

torchrun 실행

각 노드에서 다음과 같이 실행합니다. 두 노드 모두 동일한 스크립트(dist_matmul_allreduce.py)를 실행하지만, node_rank만 다르게 지정합니다. 이것이 바로 앞서 설명한 MPI 스타일의 “동일한 프로그램을 여러 프로세스에서 실행”하는 패턴입니다.

@ node_rank: 0

$ torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
    --rdzv_id=job1 --rdzv_backend=c10d \
    --rdzv_endpoint="192.168.0.2:29500" \
    dist_matmul_allreduce.py

@ node_rank: 1

$ torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \
    --rdzv_id=job1 --rdzv_backend=c10d \
    --rdzv_endpoint="192.168.0.2:29500" \
    dist_matmul_allreduce.py

torchrun 수행 과정

torchrun을 실행하면 내부적으로 다음과 같은 단계를 거칩니다:

  1. torchrun 명령어 인자 parsing 및 초기화 — 위에서 지정한 파라미터들을 읽어들입니다.
  2. 프로세스 생성 — 각 노드에서 nproc_per_node만큼의 Python process를 생성합니다. 이 예제에서는 노드당 8개, 총 16개의 프로세스가 만들어집니다.
  3. 랑데뷰 (rendezvous) — 모든 프로세스가 서로를 발견하고 동기화합니다. 각 프로세스에 global rank가 할당됩니다.
  4. Process별로 PyTorch 스크립트 실행 — 프로세스 그룹을 생성하면서 communication backend를 선택(nccl 혹은 gloo)하고, 데이터를 로딩 및 sharding한 뒤, 계산을 수행하고, torch.distributed.all_reduce를 이용해 동기화한 후, 프로세스 그룹을 제거합니다.
  5. 부가 기능 — fault tolerance를 위해 checkpointing과 logging을 지원합니다. --max-restarts=N 옵션으로 실패 시 N번까지 재시작을 시도할 수 있습니다.

랑데뷰 (Rendezvous) 상세

랑데뷰는 각 프로세스에서 PyTorch 스크립트를 실행하기 전에 수행되는 중요한 준비 단계입니다. 구체적으로 다음과 같은 일이 일어납니다.

먼저 랑데뷰 instance가 만들어지고 IPC를 위해 c10d backend가 초기화됩니다. PyTorch의 DynamicRendezvousHandler class를 사용하며, rdzv-endpoint로 지정된 노드에서 rendezvous backend가 생성되어 hosting됩니다. 모든 프로세스는 이 랑데뷰 endpoint에 연결하고, 다른 프로세스들이 모두 join하기를 기다립니다.

모든 프로세스가 모이면, 각 프로세스에 고유한 rank를 할당하고, 모든 프로세스에서 consistent한 distributed environment를 세팅합니다. 이 과정이 완료되어 모든 프로세스가 준비가 되면 비로소 PyTorch script의 실행이 시작됩니다.

fault tolerance 측면에서도 랑데뷰는 중요한 역할을 합니다. 만약 한 노드가 실패하면, 남아 있는 다른 노드들로 랑데뷰를 다시 시도하여 학습을 이어갈 수 있습니다.

Task 수행 코드

이제 실제 분산 행렬 곱셈 코드를 살펴봅니다. 먼저 메인 코드입니다:

if __name__ == "__main__":
    rank = int(os.environ.get("RANK"))
    if rank == 0:
        A, B = torch.ones(n, n), torch.ones(n, n)
    else:
        A, B = torch.empty(n, n), torch.empty(n, n)

    result = dist_matmul_allreduce(
        A, B,
        int(os.environ.get("LOCAL_RANK")),
        int(os.environ.get("WORLD_SIZE")))

    if rank == 0:
        print(result)

이 코드에서 MPI 패턴이 명확히 드러납니다. Rank 0 프로세스만이 초기 데이터를 실제로 로드하고(torch.ones), 다른 프로세스들은 빈 텐서(torch.empty)를 만들어 놓습니다. 이후 dist_matmul_allreduce 함수에서 broadcast를 통해 rank 0의 데이터가 모든 프로세스에 전달됩니다. 최종 결과도 Rank 0에서만 출력합니다.

분산 행렬 곱셈의 핵심 함수들은 다음과 같습니다:

def dist_matmul_allreduce(
        A, B, local_rank, world_size):
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    dist.broadcast(A, 0)
    dist.broadcast(B, 0)
    local_A, local_B = distributed_data(
        A, B, local_rank, world_size)
    local_result = local_matmul(
        local_A, local_B)
    dist.all_reduce(
        local_result, op=dist.ReduceOp.SUM)
    dist.destroy_process_group()
    return local_result

이 함수의 흐름은 MPI 프로그램 구조에서 설명한 것과 정확히 일치합니다. CUDA device를 선택하고(초기화), process group을 생성한 뒤(통신 채널 설정), broadcast로 데이터를 나누어 계산하고, all_reduce로 결과를 합산한 후, process group을 제거합니다(finalize).

def distributed_data(A, B, lrank, world_size):
    n = 16000
    k = n // world_size
    device = torch.device(f"cuda:{lrank}")
    # split A and B
    local_A = A[:, lrank*k:(lrank+1)*k].to(
        device)
    local_B = B[lrank*k:(lrank+1)*k, :].to(
        device)
    return local_A, local_B

def local_matmul(local_A, local_B):
    return torch.matmul(local_A, local_B)

distributed_data 함수는 행렬 A와 B를 각 프로세스가 담당할 부분으로 쪼개는 역할을 합니다. 행렬 A는 열(column) 방향으로, 행렬 B는 행(row) 방향으로 분할하여 각각의 GPU에 올립니다. 이렇게 분할된 부분 행렬끼리의 곱셈 결과를 all_reduce로 합산하면 전체 행렬 곱셈의 결과를 얻을 수 있습니다.

Device와 연결: CUDA 예제

앞서 살펴본 dist_matmul_allreduce 함수에는 CUDA와의 접점이 네 군데 있습니다. 하나씩 상세히 살펴봅니다.

def dist_matmul_allreduce(A, B, local_rank, world_size):
    # 1. CUDA Device 선택
    torch.cuda.set_device(local_rank)
    # 3. Process Group 생성 및 통신 채널 설정
    dist.init_process_group(backend="nccl")
    dist.broadcast(A, 0)
    dist.broadcast(B, 0)
    local_A, local_B = distributed_data(A, B, local_rank, world_size)
    local_result = local_matmul(local_A, local_B)
    # 2. 동기화 (CUDA Runtime)
    dist.all_reduce(local_result, op=dist.ReduceOp.SUM)
    # 4. Process Group 제거
    dist.destroy_process_group()
    return local_result

1: CUDA Device 선택

torch.cuda.set_device()를 통해 각 프로세스가 사용할 GPU를 명시적으로 선택합니다. 여기서 Rendezvous가 지정해 준 LOCAL_RANK 환경변수를 사용하는데, 이는 하나의 node 내에서만 unique한 rank입니다 (전체 World에서 unique한 것은 RANK 환경변수). 이를 통해 node 내에서 process : device == 1 : 1 관계를 만듭니다.

이 호출은 torch.distributed.init_process_group()보다 먼저 수행되어야 합니다. 명시적으로 device를 지정하지 않으면 모든 프로세스가 0번 device를 사용하게 되어 GPU 자원이 낭비됩니다. 한 가지 알아둘 점은, torch.cuda.set_device()를 호출하는 것만으로는 CUDA 초기화가 발생하지 않는다는 것입니다. PyTorch는 lazy 초기화 방식을 사용하여, 실제로 Tensor를 최초로 GPU에 만들 때 비로소 CUDA 초기화가 일어납니다.

아래 표는 2개 노드, 노드당 8개 GPU 환경에서 RANK와 LOCAL_RANK의 관계를 보여줍니다:

0178915
RANK0178915
LOCAL_RANK017017

2: 동기화 (CUDA Runtime)

torch.distributed.all_reduce()가 호출되면 내부적으로 두 단계를 거칩니다. 먼저 PyTorch의 ProcessGroup C++ binding으로 전달되어 초기화된 ProcessGroup을 확인하고, NCCL backend를 사용하는지 확인합니다. 그 다음 ProcessGroupNCCL에서 NCCLComm 객체를 사용하여 NCCL 통신을 준비하고, NCCL library의 ncclAllReduce() 함수를 호출합니다.

all_reduce는 기본적으로 blocking 연산이지만, async_op=True를 지정하면 non-blocking으로 실행할 수 있어 통신과 계산을 겹칠 수 있습니다:

# non-blocking all_reduce() 예시
work = dist.all_reduce(
    local_result,
    op=dist.ReduceOp.SUM,
    async_op=True
)
do_something()  # 통신이 진행되는 동안 다른 작업 수행
work.wait()     # 통신 완료를 기다림

실제 PyTorch 내부 구현을 보면 이 동작이 명확히 드러납니다:

# torch/distributed/distributed_c10d.py
def all_reduce(tensor, op=ReduceOp.SUM,
               group=None, async_op=False):
    ...
    work = group.allreduce([tensor], opts)
    if async_op:
        return work
    else:
        work.wait()

3: Process Group 생성 및 통신 채널 설정

torch.distributed.init_process_group()은 프로세스들 간의 통신 채널을 생성하는 함수입니다. NCCL(GPU), GLOO(GPU/CPU), MPI(CPU) 등의 backend를 지정할 수 있으며, 초기화 방법(init_method), timeout, world size, rank 등의 metadata도 여기서 설정합니다.

초기화 방법은 두 가지가 있습니다. TCP 주소를 직접 명시하는 방법과 공유 filesystem을 사용하는 방법입니다:

# 주소를 직접 명시한 초기화
dist.init_process_group(
    backend="nccl",
    init_method='tcp://127.0.0.1:23456',
    world_size=world_size,
    rank=rank
)
# 공유 filesystem을 통한 초기화
dist.init_process_group(
    backend="nccl",
    init_method='file:///mnt/nfs/sharedfile',
    world_size=world_size,
    rank=rank
)

내부적으로는 torch.distributed가 Rendezvous를 통해 server/client 정보를 수집합니다. MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK 등의 환경변수를 참고해 정보를 수집하고, 이를 바탕으로 TCPStore를 생성함으로써 UNIX process 간 통신 환경을 준비합니다. 이때 rank == 0인 프로세스만 daemon을 생성하고, 나머지 프로세스들은 이 daemon에 connect하는 구조입니다.

# torch/distributed/rendezvous.py
if "rank" in query_dict:
    rank = int(query_dict["rank"])
else:
    rank = int(_get_env_or_raise("RANK"))
if "world_size" in query_dict:
    world_size = int(query_dict["world_size"])
else:
    world_size = int(_get_env_or_raise("WORLD_SIZE"))
master_addr = _get_env_or_raise("MASTER_ADDR")
master_port = int(_get_env_or_raise("MASTER_PORT"))
use_libuv = _get_use_libuv_from_query_dict(query_dict)
store = _create_c10d_store(
    master_addr, master_port, rank, world_size, timeout, use_libuv
)

4: Process Group 제거

torch.distributed.destroy_process_group()은 통신 자원을 해제하고 memory 등의 resource를 정리하는 함수입니다. 중요한 점은, group에 참여했던 모든 process들이 이를 호출해야 비로소 종료될 수 있다는 것입니다. Collective operation처럼 동작하므로, 일부 프로세스만 호출하면 나머지 프로세스들이 영원히 기다리게 됩니다.

NCCL backend의 경우, 내부적으로 NCCLComm 객체가 소멸되면서 ncclCommDestroy()를 호출하여 GPU 간 통신에 사용된 resource를 정리합니다.

PyTorch가 제공하는 모델 병렬화 패키지들

지금까지 c10d 레벨의 저수준 분산 통신을 살펴봤습니다. 실제 모델 학습에서는 이런 저수준 API를 직접 사용하기보다, PyTorch가 제공하는 고수준 병렬화 패키지를 사용하는 경우가 대부분입니다. 각 패키지는 서로 다른 병렬화 전략을 구현하고 있습니다.

**Distributed Data Parallel (DDP)**은 가장 기본적인 병렬화 방식으로, 동일한 모델을 여러 GPU에 복제하여 각 GPU가 서로 다른 데이터 배치를 처리합니다. 각 GPU에서 계산된 gradient를 all_reduce로 합산한 뒤 모든 GPU의 모델 파라미터를 동일하게 업데이트합니다.

**Fully Sharded Data Parallel (FSDP)**은 DDP의 메모리 효율성을 크게 개선한 방식입니다. parameter(weight, bias)를 world_size로 나누어서 각 GPU에 sharding하여 저장합니다. 계산이 필요할 때만 all_gather로 전체 parameter를 복원(unsharding)하고, 계산이 끝나면 다시 resharding하는 방식으로, GPU 메모리를 훨씬 효율적으로 사용할 수 있습니다.

**Tensor Parallel (TP)**은 하나의 parameter 자체를 여러 GPU에 분할하는 방식입니다. row parallel 또는 column parallel 방식으로 parameter를 GPU에 sharding하며, input/parameter가 분할 계획(plan)에 맞게 개별 GPU에 할당되어 계산한 후 output을 합산합니다.

이 외에도 **Sequence Parallel (SP)**과 **Pipeline Parallel (PP)**이 있으며, 실제 대규모 모델 학습에서는 이러한 기법들을 조합하여 사용합니다.

FSDP 예제 코드

FSDP를 사용하는 것은 놀라울 정도로 간단합니다. 일반적인 PyTorch 모델을 정의한 뒤, FSDP()로 감싸기만 하면 됩니다.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

모델 자체는 일반적인 CNN과 동일합니다. FSDP를 적용하려면 다음 두 줄이면 충분합니다:

def fsdp_main(rank, world_size, args):
    setup(rank, world_size)
    ……
    model = Net().to(rank)
    model = FSDP(model)
    ……

(PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel)

https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html

FSDP 내부 동작

FSDP의 내부 동작은 initialization, forward, backward 세 단계로 나눌 수 있습니다.

Initialization 단계에서는 먼저 FlatParameter를 설정합니다. 이 과정에서 개별 process의 sharding 정보를 결정하고, sharded memory를 할당하며, 불필요한 unsharded memory를 해제합니다. 그리고 forward/backward 실행 시 자동으로 호출될 hook들을 등록합니다. pre hook에서는 all_gather를, post hook에서는 unsharding을 수행하도록 설정합니다.

Forward 단계에서는 각 layer마다 all_gather로 sharded parameter를 복원하고, computation을 수행한 뒤, 다시 unsharding하는 과정을 반복합니다. 한 layer의 계산이 끝나면 해당 layer의 전체 parameter를 다시 버려서 메모리를 절약합니다.

Backward 단계에서는 forward와 유사하게 각 layer에서 all_gather → computation → unsharding 과정을 거치지만, 추가로 gradient를 동기화하기 위한 reduce_scatter 연산이 수행됩니다. 이 hook은 _register_post_backward_hook을 통해 등록됩니다.

Tensor Parallel 예제 코드

Tensor Parallel은 FSDP와 달리 개발자가 **분할 계획(plan)**을 직접 설계해야 합니다. 어떤 layer의 어떤 parameter를 column 방향으로 쪼갈지, row 방향으로 쪼갈지를 명시적으로 지정해야 하는 것입니다.

from torch.distributed.device_mesh \
    import init_device_mesh
from torch.distributed.tensor.parallel \
    import ColwiseParallel, \
           RowwiseParallel, \
           parallelize_module

tp_mesh = init_device_mesh("cuda", (8,))

layer_tp_plan = {
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

for layer_id, transformer_block \
        in enumerate(model.layers):
    layer_tp_plan = {...}
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

(Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism)

https://pytorch.org/tutorials/intermediate/TP_tutorial.html

위 코드에서 볼 수 있듯이, feed_forward의 w1은 column 방향으로, w2는 row 방향으로 분할하는 식의 계획을 세웁니다. 이러한 분할 전략은 Megatron-LM 등의 연구에서 제안된 기법을 활용하여 추상화한 것으로, PyTorch의 parallelize_module API를 통해 비교적 간결하게 적용할 수 있습니다.

Next Week Preview: Beyond PyTorch

다음 주에는 PyTorch의 경계를 넘어서는 두 가지 주제를 다룰 예정입니다.

Custom kernel은 두 가지 경우에 필요합니다. 첫째는 기존 op으로 cover되지 않는 새로운 op이 필요한 경우이고, 둘째는 PyTorch에서 제공하는 op의 성능이 부족한 경우입니다. CUDA, cuBLAS, cuTLASS, Triton 등을 사용하여 작성할 수 있습니다.

vLLM은 대규모 언어 모델 추론을 위한 핵심 프레임워크입니다. vLLM의 전체 구조, vLLM이 제공하는 최적화 기능들, 그리고 vLLM이 요구하는 custom op들을 살펴볼 예정입니다.