5/18/2023

Swin Transformer V2 - model forward pass and ONNX export


1. Load the pre-trained model and run a forward pass

2. Export the model to ONNX

3. Load the ONNX model and run inference


Refer to the code below:


.

import warnings
from torch.jit import TracerWarning

# Suppress every TracerWarning for the rest of the process so the printed
# results are not buried in tracer noise (presumably aimed at the ONNX export
# step later in this script, which traces the model).
warnings.filterwarnings(action="ignore", category=TracerWarning)

#------------------
# Swin Transformer V2 pretrained model: forward pass
#------------------

from transformers import AutoImageProcessor, Swinv2Model
import torch
from datasets import load_dataset

# Sample image used as the model input.
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# The image processor and the model are loaded from the same checkpoint.
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

# Preprocess the image into PyTorch tensors ("pt").
inputs = image_processor(image, return_tensors="pt")

# Inference only: disable gradient tracking for the forward pass.
# (The call must be indented under the `with`; otherwise the script does not parse.)
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

# Convert last_hidden_states to numpy so its shape can be reported
last_hidden_states_numpy = last_hidden_states.detach().numpy()
print(f"Shape of last_hidden_states: {last_hidden_states_numpy.shape}")
print(last_hidden_states)



#----------------
# ONNX export
#------------------
import torch
from torch.autograd import Variable  # no longer used below; import kept so other code relying on it keeps working

# ensure the model is in evaluation mode
model.eval()

# create a dummy input with the same size as the real input
# for this example, let's assume the input is of size [1, 3, 256, 256]
# (a plain tensor is enough: the Variable wrapper is deprecated and simply
# returns a tensor)
dummy_input = torch.randn(1, 3, 256, 256)

# specify the file path
file_path = "./swinv2_tiny.onnx"

# export the model; naming the graph input gives the ONNX file a stable,
# readable input name instead of an auto-generated one
torch.onnx.export(model, dummy_input, file_path, input_names=["pixel_values"])

#------------------
# ONNX inference
#------------------
import onnxruntime as ort

# load the ONNX model
# The execution provider is stated explicitly: onnxruntime >= 1.9 raises an
# error when `providers` is omitted on builds that include non-CPU providers.
# CPU is used here, like the PyTorch forward pass above in this script.
ort_session = ort.InferenceSession(file_path, providers=["CPUExecutionProvider"])

# convert the PyTorch tensor to a numpy array for onnxruntime
print(inputs.keys())
inputs_numpy = inputs["pixel_values"].numpy()

# create a dictionary from model input name to the actual input data
ort_inputs = {ort_session.get_inputs()[0].name: inputs_numpy}

# forward; passing None as the output list returns every model output
ort_outs = ort_session.run(None, ort_inputs)
print(f"Shape of ort_outs: {ort_outs[0].shape}")
print(ort_outs)

..


Thank you.

www.marearts.com

🙇🏻‍♂️

No comments:

Post a Comment