import io
import torch
import onnx
import torch.onnx
from onnx2pytorch import ConvertModel
def to_onnx(
    model,
    inp_size,
    device=torch.device("cpu"),
    do_constant_folding=False,
    opset_version=11,
):
    """Export ``model`` to an in-memory ONNX ``ModelProto`` by tracing it on random input.

    Args:
        model: The module to export. It is moved to ``device`` in place
            (via ``model.to(device)``) before export.
        inp_size: Shape of the dummy input. Either a single shape (e.g.
            ``(1, 3, 224, 224)``), or a sequence of shapes (e.g.
            ``[(1, 3), (1, 5)]``) whose first element is not an ``int``,
            in which case one random tensor is created per shape and they
            are passed to the exporter as a tuple.
        device: Device on which the dummy input(s) are created and to which
            the model is moved.
        do_constant_folding: Forwarded unchanged to ``torch.onnx.export``.
        opset_version: ONNX opset passed to ``torch.onnx.export``; defaults
            to 11, the value previously hard-coded here.

    Returns:
        The exported graph parsed into an ``onnx.ModelProto``.
    """
    # A non-empty sequence whose first element is not an int is treated as a
    # list of per-input shapes; anything else is a single shape. The length
    # check prevents an IndexError for a 0-d (scalar) shape such as ``()``,
    # which ``torch.rand`` accepts.
    multi_input = (
        isinstance(inp_size, (tuple, list))
        and len(inp_size) > 0
        and not isinstance(inp_size[0], int)
    )
    if multi_input:
        input_image = tuple(torch.rand(shape, device=device) for shape in inp_size)
    else:
        input_image = torch.rand(inp_size, device=device)

    model.to(device)

    # Export into an in-memory buffer instead of a file, then parse the bytes.
    bitstream = io.BytesIO()
    torch.onnx.export(
        model,
        input_image,
        bitstream,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        # NOTE(review): only one input/output name is supplied; for
        # multi-input models the remaining graph inputs presumably receive
        # exporter-assigned names — confirm if callers rely on them.
        input_names=["input"],
        output_names=["output"],
    )
    return onnx.ModelProto.FromString(bitstream.getvalue())
def to_converted(
    model,
    inp_size,
    device=torch.device("cpu"),
    do_constant_folding=False,
):
    """Round-trip ``model`` through ONNX and back into a PyTorch module.

    The model is first exported with ``to_onnx`` and the resulting ONNX graph
    is then rebuilt as a PyTorch module by ``onnx2pytorch.ConvertModel``.

    Args:
        model: The module to export. ``to_onnx`` moves it to ``device`` in place.
        inp_size: Dummy-input shape (or sequence of shapes) accepted by
            ``to_onnx``.
        device: Forwarded to ``to_onnx``; same default as there.
        do_constant_folding: Forwarded to ``to_onnx``; same default as there.

    Returns:
        The module produced by ``ConvertModel`` from the exported graph.
    """
    onnx_model = to_onnx(
        model,
        inp_size,
        device=device,
        do_constant_folding=do_constant_folding,
    )
    # Use a new name rather than rebinding the ``model`` parameter, since the
    # result is a different object from the input module.
    converted = ConvertModel(onnx_model)
    return converted