Sequential.forward() got an unexpected keyword argument 'pretrained'

175 Views Asked by At

I want to customize ResNet34. So I wrote code like this.

def get_labels(file_path):
    file_name = file_path.stem
    matching_row = train_df[train_df['Id'] == file_name.split('.')[0]]
    
    if not matching_row.empty:
        label_str = matching_row.iloc[0]['Target']
        labels = list(map(int, label_str.split()))
    else:
        labels = []
        
    return labels

#Create dataBlock

def get_dls(image_path, bs=64): 
    return DataBlock(
        blocks = (ImageBlock, MultiCategoryBlock),
        splitter = RandomSplitter(valid_pct=0.2), 
        get_items = lambda path: get_image_files(image_path), 
        get_y = get_labels,
        item_tfms = Resize(512),
        batch_tfms = [*aug_transforms(size=512), Normalize.from_stats(*imagenet_stats)]      
).dataloaders(image_path, bs=bs)

dls = get_dls(train_directory, bs=64).to(device)

import sklearn.metrics

def macro_f1_score(preds, targets):
    preds = (preds > 0.5).int().cpu().numpy()# 
    targets = targets.cpu().numpy()
    return sklearn.metrics.f1_score(targets, preds, average='macro', zero_division=1)

resnet34 = fastai.vision.models.resnet34(pretrained=False)
num_features = resnet34.fc.in_features
resnet34 = nn.Sequential(*list(resnet34.children())[:-2]) 

num_classes = 28

custom_layers = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(num_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes)
)

custom_resnet34 = nn.Sequential(
    resnet34,
    custom_layers
)

for param in custom_resnet34[0].parameters():
    param.requires_grad=False

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(custom_resnet34.parameters(), lr=0.001)


learn = vision_learner(dls, custom_resnet34, metrics=macro_f1_score)

And I've got this issue


TypeError Traceback (most recent call last) Cell In[115], line 1 ----> 1 learn = vision_learner(dls, custom_resnet34, metrics=macro_f1_score)

File /opt/conda/lib/python3.10/site-packages/fastai/vision/learner.py:228, in vision_learner(dls, arch, normalize, n_out, pretrained, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, cut, init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final, lin_first, y_range, **kwargs) 226 else: 227 if normalize: _add_norm(dls, meta, pretrained, n_in) --> 228 model = create_vision_model(arch, n_out, pretrained=pretrained, **model_args) 230 splitter = ifnone(splitter, meta['split']) 231 learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs, 232 metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn, moms=moms)

File /opt/conda/lib/python3.10/site-packages/fastai/vision/learner.py:164, in create_vision_model(arch, n_out, pretrained, cut, n_in, init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final, lin_first, y_range) 162 "Create custom vision architecture" 163 meta = model_meta.get(arch, _default_meta) --> 164 model = arch(pretrained=pretrained) 165 body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut'])) 166 nf = num_features_model(nn.Sequential(*body.children())) if custom_head is None else None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: Sequential.forward() got an unexpected keyword argument 'pretrained'

When I got ResNet34, argument 'pretrained' was False. But there's error about argument 'pretrained'. How can I resolve this issue?

0

There are 0 best solutions below