Transformer text generation 원리, 코드 구현 예제
트랜스포머 구조의 모델에서 텍스트를 생성하는 원리를 간단히 정리해보고,
허깅페이스에서 지원하는 Transformers 모듈을 활용하여
텍스트를 생성하는 코드를 구현하는 과정에 대하여 다루어보도록 하겠습니다.
트랜스포머 모델 텍스트 생성 원리(인코더-디코더 기반)
Transformer 구조 모델에서 텍스트를 생성하는 원리를 간단하게 먼저 정리해보겠습니다.
Encoder-Decoder 기반 구조로 이루어져있는 트랜스포머 기반 모델의 특징을 활용하여
인코더에서 input text를 임베딩한 결과와
이전 단계까지에서 생성된 output token을 디코더에서 받아들여 예측된 토큰 확률 분포에서
가장 확률이 높은 토큰 or 확률에 따른 샘플링 방법 등으로 다음 토큰을 예측하여
end token이 등장하거나 지정한 max length에 도달할 때까지 토큰을 생성하는 방식으로
텍스트 생성이 이루어지게 됩니다.
위 과정에 대한 청사진을 그림으로 요약한 결과는 다음과 같이 나타낼 수 있습니다.
여기서 0번 토큰을 start token, 4번 토큰을 end token으로 가정해보겠습니다.
초반에는 인코더 임베딩 결과 + 시작 토큰(0번)만 디코더의 input으로 들어가고 있습니다.
최대 확률을 가지는 토큰으로 생성하는 상황을 가정한다면,
위의 그림에서는 1번 토큰의 확률이 가장 높으므로, 1번 토큰이 생성되게 됩니다.
다음 토큰을 이어서 생성할 때는 인코더 임베딩 결과는 동일하게 input으로 들어가고,
전까지 생성된 토큰들의 sequence(0 1)가 새롭게 디코더의 input에 적용되어
다음 토큰이 생성되는 것을 볼 수 있습니다.
같은 과정을 반복하다가, end token(여기서는 4번)이 생성되면서
텍스트 생성 과정이 종료되게 되고, 최종 생성된 sequence는 0 1 3 4번 토큰이 됩니다.
트랜스포머 모델 텍스트 생성 과정 코드 직접 구현해보기
먼저, 텍스트 생성 과정을 직접 구현해보는 예시를 살펴보겠습니다.
다른 구조의 트랜스포머 구조의 모델에 모두 적용이 가능하지만,
여기서는 BART 구조의 모델에서 텍스트를 생성하는 모델 구현 코드 스니펫을 살펴보겠습니다.
(각 줄 코드의 설명은 주석을 참고해주세요.)
import torch
import torch.nn as nn
import torch.nn.Functional as F
from transforemrs import BartForConditionalGeneration
max_length = 50
batch_size = 8
start_token = 0
end_token = 2
class BART_generator(nn.Module):
def __init__(self):
super(BART_generator, self).__init__()
# 학습할 bart 구조 모델 가져오기(ForConditionalGeneration 형태)
self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
def forward(self, input_tokens, max_length=max_length, batch_size=batch_size):
# 배치 사이즈 만큼의 시작 토큰 텐서 생성
decoder_result = torch.tensor([[start_token]]).repeat(batch_size, 1)
# 배치 내 각 문장의 생성이 종료되었는지 여부를 알려주는 텐서(1 : 미종료, 0: 종료)
is_end = torch.tensor([1] * batch_size).reshape(-1, 1)
for i in range(max_length):
# input token과 decoder_input(이전까지 생성된 토큰)의 id를 지정해주기
outputs = self.bart(input_ids=input_tokens.input_ids, decoder_input_ids=decoder_result)
logits = outputs.logits
# 결과 logits에 softmax를 취하기
logits_softmax = F.softmax(logits, 2)[:, i]
# 최대 value를 가지는 위치의 토큰을 생성, 단 종료 시에는 is_end = 0이므로 0번인 padding 생성
generated_token = torch.argmax(logits_softmax, 1).reshape(-1, 1) * is_end
# 생성된 토큰 이어 붙이기
decoder_result = torch.cat([decoder_result, generated_token], dim=1)
is_end = is_end * (generated_token != end_token) # end token 생성->종료 알림
if is_end.sum().item() == 0: # 배치 내 모든 문장 생성 종료 시 for문 종료
break
return decoder_result
다른 모델에서도 전반적인 구조는 거의 동일하게 가져갈 수 있으며,
(start, end 토큰 번호 지정 및 모델 호출 메소드 종류를 변경해주시면 됩니다.)
ForConditionalGeneration 형태 함수가 토큰 확률 분포의 logits을 반환하므로
해당 메소드를 가져와주시면 용이하게 구현이 가능합니다.
배치 내 각 문장의 생성이 종료된 경우는 0번 토큰인 padding이 생성되도록
is_end로 지정을 해주었으며,
input_ids는 계속 고정되고 decoder_ids 부분만 하나씩 추가된 형태로 변경되는 구조를
눈여겨보시면 좋습니다.
허깅페이스의 기능을 활용하여 간단하게 텍스트 생성해보기
사실 위 과정에서 직접 구현했던 코드는 트랜스포머 모델의 generate 메소드의 원리와
거의 비슷하며, 여기서는 확률 기반 샘플링 등 다양한 기능을 사용할 수 있어
매우 편리하게 다양한 조건의 텍스트 생성 결과를 받아볼 수 있습니다.
generate 메소드의 사용법은 간단한데, 아래의 허깅페이스 공식 페이지를 참고하시면 됩니다.
(이 글에서는 추가 설명은 생략하도록 하겠습니다.)
혹은, pipeline 기능을 활용해서도 쉽게 텍스트 생성을 진행해볼 수 있습니다.
from transformers import pipeline
# 파이프라인 모드(텍스트 생성), 모델 종류 지정
generator = pipeline('text-generation', model = 'facebook/bart-base')
# input 문장 및 최대 길이 등 조건 지정
generator("Hello, I am ", max_length = 50)
'Python > NLP Code' 카테고리의 다른 글
CLS 토큰이란? / 파이썬 BERT CLS 임베딩 벡터 추출 예제 (3) | 2022.09.03 |
---|---|
파이썬 BERT 모델 활용 IMDB 데이터셋 감성 분석 classification 예제 (0) | 2022.06.18 |
파이썬 텍스트 데이터 증강 모듈 : nlpaug (0) | 2022.05.13 |