Python/NLP Code

Transformers 모델 병렬화(Model Parallelism) 간단하게 하는 방법

jimmy_AI 2022. 3. 21. 23:02
반응형

Transformers T5, GPT2 등 Model Parallelism

Transformers 라이브러리 내 T5, GPT-2 등 파라미터 사이즈가 큰 일부 모델에 대하여

모델 파라미터 병렬처리를 간단하게 할 수 있는 parallelize 함수와 device map의

사용 방법에 대해서 다루어보도록 하겠습니다.

 

이 글은 HuggingFace의 공식 document 내용을 바탕으로 작성되었습니다.

 

 

T5 모델 병렬처리 예시

Transformers에서 제공하는 T5 모델 크기의 attention module의 개수에 따라

아래처럼 device마다 할당할 module의 번호를 지정해주시면 됩니다.

 

t5-small은 6개, t5-base는 12개, t5-large, t5-3b 및 t5-11b는 24개

attention module을 가집니다.

 

예를 들어, GPU 3개에 t5-base의 어텐션 모듈을 4개씩 나누려면 아래처럼 작성하면 됩니다.

model = T5ForConditionalGeneration.from_pretrained("t5-base")
device_map = {
    0: [0, 1, 2, 3],
    1: [4, 5, 6, 7],
    2: [8, 9, 10, 11]
}
model.parallelize(device_map)

만일, 위에서 지정한 모델 병렬처리를 해제하기 위해서는 model.deparallelize()

함수를 다시 적용해주시면 됩니다.

 

GPT-2도 해당 기능을 지원하며, HuggingFace 공식 document 내에서 parallelize 메소드를

제공하고 있는 모델 종류라면 같은 방식으로 적용이 가능합니다.