I tried to convert a Conv2d layer to TensorRT, and I found that different parameters can lead to different levels of agreement between the fp16 and fp32 engines. Could anyone give me some suggestions?
You can reproduce my result with the following code by changing the line in_chans, embed_dim, patch_HW = 3, 128, (14,14) to in_chans, embed_dim, patch_HW = 3, 384, (14,14).
TensorRT version: 8.5.3.1, PyTorch version: 1.13.1+cu117, CUDA version: 12.1
import torch
import torch.nn as nn
import onnx
import onnxruntime
import numpy as np
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver
stream = pycuda.driver.Stream()
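# pycuda.autoinit creates the CUDA context; this stream is shared by the
# async copies and engine executions below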
# define torch model
class test_conv2d_fp16_trt(nn.Module):
    def __init__(self, in_chans, embed_dim, patch_HW):
        super().__init__()
        # kernel_size == stride makes this a ViT-style patch-embedding conv
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        nn.init.xavier_uniform_(self.proj.weight)

    def forward(self, x):
        return self.proj(x)
# convert to onnx
def convert_to_onnx(x, torch_model, onnx_model_path):
    # Export the model
    torch.onnx.export(torch_model,                # model being run
                      x,                          # model input (or a tuple for multiple inputs)
                      onnx_model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,         # store the trained parameter weights inside the model file
                      opset_version=14,           # the ONNX version to export the model to
                      do_constant_folding=False,  # whether to execute constant folding for optimization
                      input_names=['input'],      # the model's input names
                      output_names=['output'],    # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                    'output': {0: 'batch_size'}})
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def onnx_model_infer(x, onnx_model_path):
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
def check_output_match(output1, output2):
    np.testing.assert_allclose(output1, output2, rtol=1e-03, atol=1e-05)
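# Note: rtol=1e-03/atol=1e-05 is a tolerance suited to fp32-vs-fp32 comparisons.
# fp16 carries only ~10 mantissa bits (~3 decimal digits), so a looser check,
# as in the sketch below, may be more appropriate for the fp16 engine
# (this helper is illustrative and not part of the original repro):
def check_output_match_fp16(output1, output2):
    np.testing.assert_allclose(output1, output2, rtol=1e-02, atol=1e-03)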
def build_trt_model(logger, input_path, output_path, input_shape, is_fp16):
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    profile.set_shape(
        'input', min=input_shape,
        opt=input_shape,
        max=input_shape)
    config.add_optimization_profile(profile)
    config.max_workspace_size = 4 << 30
    config.min_timing_iterations = 1
    if is_fp16:
        config.flags = 1 << int(trt.BuilderFlag.FP16)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    with trt.OnnxParser(network, logger) as parser:
        with open(input_path, 'rb') as fp:
            # raise on parser errors instead of silently continuing
            assert parser.parse(fp.read()), parser.get_error(0)
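    # A possible experiment (sketch, not part of the original repro): with the
    # FP16 flag set, TensorRT may pick fp16 conv kernels; pinning layers back to
    # fp32 would confirm whether those kernels cause the deviation:
    #   config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    #   for i in range(network.num_layers):
    #       network.get_layer(i).precision = trt.float32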
    engine = builder.build_engine(network, config=config)
    assert engine is not None
    with open(output_path, 'wb') as fp:
        fp.write(engine.serialize())
    print('TRT{} engine is built.'.format(' (fp16)' if is_fp16 else ''))
def test_trt_model(logger, trt_model_path, input_shape=(1,3,504,378), output_shape=(1,384,36,27)):
    with trt.Runtime(logger) as runtime, open(trt_model_path, 'rb') as fp:
        engine = runtime.deserialize_cuda_engine(fp.read())
    image = np.ones(input_shape, dtype=np.float32)
    image_data = np.ascontiguousarray(image).astype(np.float32)
    # prepare GPU memory buffer
    image_gdata = pycuda.driver.mem_alloc(image_data.nbytes)
    output_data = np.zeros(output_shape, np.float32)
    output_gdata = pycuda.driver.mem_alloc(output_data.nbytes)
    with engine.create_execution_context() as context:
        context.active_optimization_profile = 0
        context.set_binding_shape(0, image_data.shape)
        # run the network
        pycuda.driver.memcpy_htod_async(
            image_gdata, image_data, stream)
        context.execute_async_v2(
            bindings=[int(image_gdata), int(output_gdata)],
            stream_handle=stream.handle)
        pycuda.driver.memcpy_dtoh_async(
            output_data, output_gdata, stream)
        stream.synchronize()
    image_gdata.free()
    output_gdata.free()
    return output_data
if __name__ == "__main__":
    # fixed params
    im_h, im_w = 504, 378
    input_shape = (1,3,504,378)
    onnx_model_path = "check_conv2d_128.onnx"
    trt_fp16_model_path = "check_conv2d_128_fp16.ts"
    trt_fp32_model_path = "check_conv2d_128_fp32.ts"
    # debug params
    in_chans, embed_dim, patch_HW = 3, 384, (14,14)  # failed case
    # in_chans, embed_dim, patch_HW = 3, 128, (14,14)  # success case
    output_shape = (input_shape[0], embed_dim, im_h//patch_HW[0], im_w//patch_HW[1])  # height uses patch_HW[0], width uses patch_HW[1]
    torch_model = test_conv2d_fp16_trt(in_chans, embed_dim, patch_HW)
    torch_model.eval()
    image = np.ones(input_shape, dtype=np.float32)
    x = torch.from_numpy(image)
    torch_output = to_numpy(torch_model(x))
    print("build onnx")
    convert_to_onnx(x, torch_model, onnx_model_path)
    onnx_output = onnx_model_infer(x, onnx_model_path)
    check_output_match(torch_output, onnx_output)
    print("torch_output match onnx_output success!")
    logger = trt.Logger(trt.Logger.WARNING)
    print("build trt")
    build_trt_model(logger, onnx_model_path, trt_fp16_model_path, input_shape=input_shape, is_fp16=True)
    build_trt_model(logger, onnx_model_path, trt_fp32_model_path, input_shape=input_shape, is_fp16=False)
    trt_fp16_output = test_trt_model(logger, trt_fp16_model_path, input_shape=input_shape, output_shape=output_shape)
    trt_fp32_output = test_trt_model(logger, trt_fp32_model_path, input_shape=input_shape, output_shape=output_shape)
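    # Diagnostic sketch (not in the original repro): report the raw fp16-vs-fp32
    # deviation before the assertions below, since assert_allclose aborts on the
    # first mismatch and hides the overall error magnitude.
    abs_err = np.abs(trt_fp16_output - trt_fp32_output)
    rel_err = abs_err / (np.abs(trt_fp32_output) + 1e-8)
    print("fp16 vs fp32: max abs err {:.6f}, max rel err {:.6f}".format(abs_err.max(), rel_err.max()))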
    check_output_match(onnx_output, trt_fp32_output)
    print("trt_fp32_output match onnx_output success!")
    check_output_match(trt_fp16_output, trt_fp32_output)
    print("trt_fp16_output match trt_fp32_output success!")