pytorch convert a conv2d layer to tensorrt results in fp16 != fp32


I tried to convert a Conv2d layer to TensorRT, and I found that different layer parameters can change how closely the fp16 output matches the fp32 output. Could anyone give me some suggestions?

You can reproduce my result using the following code by changing the line in_chans, embed_dim, patch_HW = 3, 128, (14,14) to in_chans, embed_dim, patch_HW = 3, 384, (14,14).

TensorRT version: 8.5.3.1; PyTorch version: 1.13.1+cu117; CUDA version: 12.1

import torch
import torch.nn as nn
import onnx
import onnxruntime
import numpy as np
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver
stream = pycuda.driver.Stream()

# define torch model
class test_conv2d_fp16_trt(nn.Module):
    def __init__(self, in_chans, embed_dim, patch_HW):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        nn.init.xavier_uniform_(self.proj.weight)
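        # note: xavier_uniform_ scales weights by sqrt(6 / (fan_in + fan_out));
        # here fan_in = in_chans*14*14 and fan_out = embed_dim*14*14, so a
        # larger embed_dim yields smaller weights and smaller output values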

    def forward(self, x):
        return self.proj(x)

# convert to onnx
def convert_to_onnx(x, torch_model, onnx_model_path):
    # Export the model
    torch.onnx.export(torch_model,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      onnx_model_path,   # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=14,          # the ONNX version to export the model to
                      do_constant_folding=False,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def onnx_model_infer(x, onnx_model_path):
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(onnx_model_path)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
    
def check_output_match(output1, output2):
    np.testing.assert_allclose(output1, output2, rtol=1e-03, atol=1e-05)
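
# NOTE (suggestion): rtol=1e-03 / atol=1e-05 is a tight bar for fp16, whose
# unit roundoff is about 5e-4 (~3 significant decimal digits), so some
# mismatch in the final fp16-vs-fp32 check is expected. A minimal sketch of
# a looser, fp16-aware comparison (tolerances chosen by assumption):
def check_output_match_fp16(output1, output2):
    np.testing.assert_allclose(output1, output2, rtol=1e-02, atol=1e-03)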

def build_trt_model(logger, input_path, output_path, input_shape, is_fp16):
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    profile.set_shape(
        'input', min=input_shape,
        opt=input_shape,
        max=input_shape)
    config.add_optimization_profile(profile)
    config.max_workspace_size = 4 << 30
    config.min_timing_iterations = 1
    if is_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    with trt.OnnxParser(network, logger) as parser:
        with open(input_path, 'rb') as fp:
            # surface ONNX parser errors instead of failing silently
            if not parser.parse(fp.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError('failed to parse the ONNX model')
        engine = builder.build_engine(network, config=config)
    
    assert engine is not None
    with open(output_path, 'wb') as fp:
        fp.write(engine.serialize())
    
    print('TRT{} engine is built.'.format(' (fp16)' if is_fp16 else ''))
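
# NOTE (suggestion, untested sketch): to narrow down where fp16 loses
# precision, individual layers can be pinned to fp32 before building the
# engine, e.g. inside build_trt_model after parsing:
#
#   config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
#   for i in range(network.num_layers):
#       layer = network.get_layer(i)
#       layer.precision = trt.float32
#       layer.set_output_type(0, trt.float32)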

def test_trt_model(logger, trt_model_path, input_shape=(1,3,504,378), output_shape=(1,384,36,27)):
    with trt.Runtime(logger) as runtime, open(trt_model_path, 'rb') as fp:
        engine = runtime.deserialize_cuda_engine(fp.read())
    
    image = np.ones(input_shape, dtype=np.float32)
    image_data = np.ascontiguousarray(image).astype(np.float32)
    
    # prepare GPU memory buffer
    image_gdata = pycuda.driver.mem_alloc(image_data.nbytes)
    output_data = np.zeros(output_shape, np.float32)
    output_gdata = pycuda.driver.mem_alloc(output_data.nbytes)
    with engine.create_execution_context() as context:
        context.active_optimization_profile = 0
        context.set_binding_shape(0, image_data.shape)
    
        # run the network
        pycuda.driver.memcpy_htod_async(
            image_gdata, image_data, stream)
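        # NOTE: this assumes binding 0 is the input and binding 1 is the
        # output, matching the ONNX export; the order can be confirmed with
        # engine.get_binding_name(i)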
        context.execute_async_v2(
            bindings=[int(image_gdata), int(output_gdata)],
            stream_handle=stream.handle)
        pycuda.driver.memcpy_dtoh_async(
            output_data, output_gdata, stream)
        stream.synchronize()
    
    image_gdata.free()
    output_gdata.free()
    return output_data

if __name__ == "__main__":
    # fixed params
    im_h, im_w = 504, 378
    input_shape = (1,3,im_h,im_w)
    
    onnx_model_path = "check_conv2d_128.onnx"
    trt_fp16_model_path = "check_conv2d_128_fp16.ts"
    trt_fp32_model_path = "check_conv2d_128_fp32.ts"

    # debug params
    in_chans, embed_dim, patch_HW = 3, 384, (14,14) # failed case
    # in_chans, embed_dim, patch_HW = 3, 128, (14,14) # success case
    # output spatial dims: 504//14 = 36, 378//14 = 27
    output_shape = (input_shape[0], embed_dim, im_h//patch_HW[0], im_w//patch_HW[1])
    torch_model = test_conv2d_fp16_trt(in_chans, embed_dim, patch_HW)
    torch_model.eval()

    image = np.ones(input_shape, dtype=np.float32)
    x = torch.from_numpy(image)
    
    torch_output = to_numpy(torch_model(x))
    print("build onnx")
    convert_to_onnx(x, torch_model, onnx_model_path)
    onnx_output = onnx_model_infer(x, onnx_model_path)

    check_output_match(torch_output, onnx_output)
    print("torch_output match onnx_output success!")

    logger = trt.Logger(trt.Logger.WARNING)
    print("build trt")
    build_trt_model(logger, onnx_model_path, trt_fp16_model_path, input_shape=input_shape, is_fp16=True)
    build_trt_model(logger, onnx_model_path, trt_fp32_model_path, input_shape=input_shape, is_fp16=False)

    trt_fp16_output = test_trt_model(logger, trt_fp16_model_path, input_shape=input_shape, output_shape=output_shape)
    trt_fp32_output = test_trt_model(logger, trt_fp32_model_path, input_shape=input_shape, output_shape=output_shape)

    check_output_match(onnx_output, trt_fp32_output)
    print("trt_fp32_output match onnx_output success!")

    check_output_match(trt_fp16_output, trt_fp32_output)
    print("trt_fp16_output match trt_fp32_output success!")
