Python/Pytorch

[Pytorch] onnx -> pth 파일 변환 방법 정리

jimmy_AI 2025. 9. 22. 00:24
반응형

안녕하세요.

이번 글에서는 onnx 파일로 저장된 모델을 pth 파일로 변환하는 방법에 대하여

간략하게 정리해보도록 하겠습니다.

 

 

1. onnx2torch 모듈 활용

만일, onnx 파일 내의 모델이 Conv, ReLU 등 기본 연산자로만 구성되었고,

입력 및 출력 shape이 고정된 등 비교적 간단한 경우,

onnx2torch 모듈로 onnx -> pth 파일 변환을 매우 쉽게 할 수 있습니다.

사용 방법은 다음과 같습니다.

# (필요 시) 라이브러리 설치
# !pip install onnx2torch

# 변환
from onnx2torch import convert
torch_model = convert("your_model.onnx")

# 저장
torch.save(torch_model.state_dict(), "converted_model.pth")

 

 

2. 수동 변환 방법

앞에서 소개한 onnx2torch 모듈의 변환 방법이 실패했다면,

다음과 같은 수동 변환 방법을 적용해서 변환을 시도해야 합니다.

다만, 이 방법은 onnx 파일 내 모델이 구현된 구조의 코드를 알고 있어야 가능합니다.

import torch
import torch.nn as nn
import onnx
from collections import OrderedDict
import numpy as np
from onnx import numpy_helper

# 1. 다음과 같이 모델의 구조를 알고 있는 상황 가정
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# 2. ONNX 모델에서 가중치를 추출
onnx_model_path = 'your_model.onnx'
onnx_model = onnx.load(onnx_model_path)
onnx_weights = {init.name: numpy_helper.to_array(init) for init in onnx_model.graph.initializer}

# 3. PyTorch 모델에 맞게 가중치를 state_dict 형태로 변환

# ONNX 모델 -> PyTorch 모델 가중치 이름 매핑 필요
mapping = {
    'conv1_weight': 'conv1.weight',
    'conv1_bias': 'conv1.bias',
    'fc1_weight': 'fc1.weight',
    'fc1_bias': 'fc1.bias',
}

# 변환
new_state_dict = OrderedDict()
for onnx_name, pytorch_name in mapping.items():
    if onnx_name in onnx_weights:
        tensor = torch.from_numpy(onnx_weights[onnx_name])
        # Linear 레이어 가중치이면 전치(transpose)
        if 'fc' in pytorch_name and 'weight' in pytorch_name:
            tensor = tensor.T
        new_state_dict[pytorch_name] = tensor

# 4. PyTorch 모델을 생성하고 가중치를 로드
pytorch_model = MyModel()
# strict=False로 일부 불일치 허용할 수 있는데, 모델 구조가 원래 구현했던 구조와 다를 수 있으므로 유의 필요!!!
pytorch_model.load_state_dict(new_state_dict, strict=False)

# 5. .pth 파일로 저장
torch.save(pytorch_model.state_dict(), 'converted_model.pth')

 

만일, 모델의 구조 코드를 모르고 있거나, layer 이름 매핑을 맞추기 어려운 등

수동 변환이 까다로운 상황이라면 onnx 파일을 변환하기 보다는

onnxruntime 모듈을 활용하여 그대로 사용하는 방법을 권장드립니다.

 

이 글이 onnx 파일을 다루는 과정에 도움이 되셨다면 좋겠습니다.

잘 봐주셔서 감사드립니다.