1. Load a pre-trained model.
2. Export it to ONNX.
3. Load the ONNX model and run inference with ONNX Runtime.
Refer to the code below:
.
# Silence the TracerWarnings emitted while torch.onnx.export traces the model
# in the export step further down.
import warnings
from torch.jit import TracerWarning
warnings.filterwarnings("ignore", category=TracerWarning)
#------------------
#swin-transformer v2 pretrained model
#------------------
from transformers import AutoImageProcessor, Swinv2Model
import torch
from datasets import load_dataset

MODEL_ID = "microsoft/swinv2-tiny-patch4-window8-256"

dataset = load_dataset("huggingface/cats-image")
# Index the row first so only one image is decoded, instead of materializing
# the whole "image" column and then taking its first element.
image = dataset["test"][0]["image"]
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Swinv2Model.from_pretrained(MODEL_ID)
# Dict-like batch of torch tensors; the ONNX section reads its "pixel_values" key.
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    # Reference PyTorch forward pass (no autograd graph needed); its output is
    # what the ONNX Runtime result should reproduce.
    outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
# Convert last_hidden_states to numpy for printing / comparison with ONNX Runtime
last_hidden_states_numpy = last_hidden_states.detach().numpy()
print(f"Shape of last_hidden_states: {last_hidden_states_numpy.shape}")
print(last_hidden_states)
#----------------
#onnx export
#------------------
import torch
from torch.autograd import Variable
# ensure the model is in evaluation mode
model.eval()
# Example input used only to trace the graph. A plain tensor is sufficient:
# torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4.
# Assumes the processor yields [1, 3, 256, 256] for this checkpoint — TODO
# confirm. No dynamic_axes are given, so the exported graph has this fixed shape.
dummy_input = torch.randn(1, 3, 256, 256)
# specify the file path
file_path = "./swinv2_tiny.onnx"
# Give the graph meaningful I/O names (matching the Hugging Face argument and
# output field) instead of anonymous auto-generated ones. Only the first
# output is named; any remaining outputs keep default names.
torch.onnx.export(
    model,
    dummy_input,
    file_path,
    input_names=["pixel_values"],
    output_names=["last_hidden_state"],
)
#------------------
#onnx inference
#------------------
import onnxruntime as ort
# Load the exported model. Since ONNX Runtime 1.9, builds that ship more than
# one execution provider (e.g. onnxruntime-gpu) raise ValueError unless
# `providers` is passed explicitly; passing all available providers keeps
# ORT's own priority order.
ort_session = ort.InferenceSession(
    file_path, providers=ort.get_available_providers()
)
# onnxruntime consumes numpy arrays, not torch tensors
inputs_numpy = inputs["pixel_values"].numpy()
# Feed by the graph's own input name rather than hard-coding it.
ort_inputs = {ort_session.get_inputs()[0].name: inputs_numpy}
# forward; None => return every graph output (index 0 is expected to be
# last_hidden_state)
ort_outs = ort_session.run(None, ort_inputs)
print(f"Shape of ort_outs: {ort_outs[0].shape}")
print(ort_outs)
# Sanity check: the ONNX result should match the PyTorch reference computed
# earlier, up to small floating-point differences.
max_abs_diff = abs(ort_outs[0] - last_hidden_states_numpy).max()
print(f"Max abs diff ONNX vs PyTorch: {max_abs_diff:.3e}")
..
Thank you.
www.marearts.com
🙇🏻‍♂️
No comments:
Post a Comment