반응형
안녕하세요.
이번 글에서는 onnx 파일로 저장된 모델을 pth 파일로 변환하는 방법에 대하여
간략하게 정리해보도록 하겠습니다.
1. onnx2torch 모듈 활용
만일, onnx 파일 내의 모델이 Conv, ReLU 등 기본 연산자로만 구성되었고,
입력 및 출력 shape이 고정된 등 비교적 간단한 경우,
onnx2torch 모듈로 onnx -> pth 파일 변환을 매우 쉽게 할 수 있습니다.
사용 방법은 다음과 같습니다.
# (필요 시) 라이브러리 설치
# !pip install onnx2torch
# 변환
from onnx2torch import convert
torch_model = convert("your_model.onnx")
# 저장
torch.save(torch_model.state_dict(), "converted_model.pth")
2. 수동 변환 방법
앞에서 소개한 onnx2torch 모듈의 변환 방법이 실패했다면,
다음과 같은 수동 변환 방법을 적용해서 변환을 시도해야 합니다.
다만, 이 방법은 onnx 파일 내 모델이 구현된 구조의 코드를 알고 있어야 가능합니다.
import torch
import torch.nn as nn
import onnx
from collections import OrderedDict
import numpy as np
from onnx import numpy_helper
# 1. 다음과 같이 모델의 구조를 알고 있는 상황 가정
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.fc1 = nn.Linear(64 * 28 * 28, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
# 2. ONNX 모델에서 가중치를 추출
onnx_model_path = 'your_model.onnx'
onnx_model = onnx.load(onnx_model_path)
onnx_weights = {init.name: numpy_helper.to_array(init) for init in onnx_model.graph.initializer}
# 3. PyTorch 모델에 맞게 가중치를 state_dict 형태로 변환
# ONNX 모델 -> PyTorch 모델 가중치 이름 매핑 필요
mapping = {
'conv1_weight': 'conv1.weight',
'conv1_bias': 'conv1.bias',
'fc1_weight': 'fc1.weight',
'fc1_bias': 'fc1.bias',
}
# 변환
new_state_dict = OrderedDict()
for onnx_name, pytorch_name in mapping.items():
if onnx_name in onnx_weights:
tensor = torch.from_numpy(onnx_weights[onnx_name])
# Linear 레이어 가중치이면 전치(transpose)
if 'fc' in pytorch_name and 'weight' in pytorch_name:
tensor = tensor.T
new_state_dict[pytorch_name] = tensor
# 4. PyTorch 모델을 생성하고 가중치를 로드
pytorch_model = MyModel()
# strict=False로 일부 불일치 허용할 수 있는데, 모델 구조가 원래 구현했던 구조와 다를 수 있으므로 유의 필요!!!
pytorch_model.load_state_dict(new_state_dict, strict=False)
# 5. .pth 파일로 저장
torch.save(pytorch_model.state_dict(), 'converted_model.pth')
만일, 모델의 구조 코드를 모르고 있거나, layer 이름 매핑을 맞추기 어려운 등
수동 변환이 까다로운 상황이라면 onnx 파일을 변환하기 보다는
onnxruntime 모듈을 활용하여 그대로 사용하는 방법을 권장드립니다.
이 글이 onnx 파일을 다루는 과정에 도움이 되셨다면 좋겠습니다.
잘 봐주셔서 감사드립니다.
'Python > Pytorch' 카테고리의 다른 글
| [Pytorch] model.eval() vs torch.no_grad()의 차이 (1) | 2025.06.03 |
|---|---|
| [Pytorch] 파이토치 ReLU 함수 종류 총정리(ReLU, LeakyReLU, PReLU, RReLU, ReLU6) (0) | 2025.05.18 |
| [Pytorch] checkpoint vs torchscript vs onnx 모델 속도 비교 (0) | 2023.09.14 |