How do I export my fastai resnet50/vision_learner trained model into torchserve?


My goal is to deploy a model I trained with fastai to TorchServe. I was following this tutorial but got stuck at the part where the author creates the model class for PyTorch.

He mentions that to run our model in TorchServe, we need the following:

  1. A model class
  2. The weights exported from PyTorch (a .pth file)
  3. A handler

Of these, I have two: the weights and the handler. Where I'm stuck is the model class. He created one class file, but I have no idea where he got the DynamicUnet to use as a base for the class, or how he combined that class with unet_learner to create a custom PyTorch model class. Can you help me build a model class for a model trained with vision_learner and a resnet50 backbone?

Answer by Antonio Tapia:

Found the fix myself. It turns out you can inspect the source of create_vision_model (with ??create_vision_model in a notebook) and of add_head (??add_head), and inline that code as the model construction inside the initialize method of handler.py. You should end up with something like this:

# These names come from fastai / torch; in handler.py you need roughly:
#   import torch
#   from torch import nn
#   from fastai.vision.all import (resnet50, model_meta, create_body,
#                                  create_head, num_features_model)
#   from fastai.vision.learner import _default_meta
#   from fastcore.basics import ifnone

state_dict = torch.load(model_pt_path, map_location=self.device)

# Defaults copied from create_vision_model / create_head
head = None
concat_pool = True
pool = True
lin_ftrs = None
ps = 0.5
first_bn = True
bn_final = False
lin_first = False
y_range = None
init = nn.init.kaiming_normal_
arch = resnet50
n_out = 9  # number of classes your learner was trained on
pretrained = True
cut = None
n_in = 3
custom_head = None

# Rebuild the body + head exactly as vision_learner does
meta = model_meta.get(arch, _default_meta)
model = arch(pretrained=pretrained)
body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))
nf = num_features_model(nn.Sequential(*body.children())) if custom_head is None else None
if head is None:
    head = create_head(nf, n_out, concat_pool=concat_pool, pool=pool,
                       lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn, bn_final=bn_final, lin_first=lin_first,
                       y_range=y_range)
self.model = nn.Sequential(body, head)
self.model.load_state_dict(state_dict)
self.model.to(self.device)
self.model.eval()

logger.debug("Model file {0} loaded successfully".format(model_pt_path))
self.initialized = True
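With handler.py in place, the pieces are packaged into a .mar archive and served. A sketch, assuming the weights file and model names below (adjust paths and names to yours):

```shell
# Package weights + handler into model_store/fastai_resnet50.mar
torch-model-archiver --model-name fastai_resnet50 \
    --version 1.0 \
    --serialized-file model.pth \
    --handler handler.py \
    --export-path model_store

# Start TorchServe and register the archive
torchserve --start --model-store model_store \
    --models fastai_resnet50=fastai_resnet50.mar
```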