How to resolve the stride assert error produced by torch.compile()?


I am currently trying to run the code from section 3.2 of "A Quick PyTorch 2.0 Tutorial" (https://www.learnpytorch.io/pytorch_2_intro/#27-create-training-and-testing-loops). When I run the code, I get the following error:

AssertionError: expected size 64==64, stride 3136==1 at dim=1


Partial output and full error below:

...
/home/isaac-aktam/anaconda3/lib/python3.10/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/home/isaac-aktam/anaconda3/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/home/isaac-aktam/anaconda3/lib/python3.10/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/home/isaac-aktam/anaconda3/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
Training Epoch 0:  99%|█████████▉| 195/196 [01:53<00:00,  1.72it/s, train_loss=0.671, train_acc=0.769]
  0%|          | 0/5 [01:53<?, ?it/s]

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[25], line 43
     39 print(f"Time to compile: {compile_time} | Note: the first time you compile a model, the first epoch may take longer due optimizations happening behind the scenes")
     41 # Train the compiled model
---> 43 single_run_compile_results = train(model = compiled_model.to(device),
     44                        train_dataloader = train_dataloader_CIFAR10,
     45                        test_dataloader = test_dataloader_CIFAR10,
     46                        optimizer = optimizer,
     47                        loss_fn = loss_fn,
     48                        epochs = NUM_EPOCHS,
     49                        device = device)

Cell In[24], line 204, in train(model, train_dataloader, test_dataloader, optimizer, loss_fn, epochs, device, disable_progress_bar)
    200 for epoch in tqdm(range(epochs), disable=disable_progress_bar):
    201 
    202     # Perform training step and time it
    203     train_epoch_start_time = time.time()
--> 204     train_loss, train_acc = train_step(epoch=epoch, 
    205                                       model=model,
    206                                       dataloader=train_dataloader,
    207                                       loss_fn=loss_fn,
    208                                       optimizer=optimizer,
    209                                       device=device,
    210                                       disable_progress_bar=disable_progress_bar)
    211     train_epoch_end_time = time.time()
    212     train_epoch_time = train_epoch_end_time - train_epoch_start_time

Cell In[24], line 60, in train_step(epoch, model, dataloader, loss_fn, optimizer, device, disable_progress_bar)
     57 optimizer.zero_grad()
     59 # 4. Loss backward
---> 60 loss.backward()
     62 # 5. Optimizer step
     63 optimizer.step()

File ~/anaconda3/lib/python3.10/site-packages/torch/_tensor.py:483, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    436 r"""Computes the gradient of current tensor wrt graph leaves.
    437 
    438 The graph is differentiated using the chain rule. If the tensor is
   (...)
    480         used to compute the attr::tensors.
    481 """
    482 if has_torch_function_unary(self):
--> 483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
    486         self,
    487         gradient=gradient,
    488         retain_graph=retain_graph,
    489         create_graph=create_graph,
    490         inputs=inputs,
    491     )
    492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/anaconda3/lib/python3.10/site-packages/torch/overrides.py:1560, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1556 if _is_torch_function_mode_enabled():
   1557     # if we're here, the mode must be set to a TorchFunctionStackMode
   1558     # this unsets it and calls directly into TorchFunctionStackMode's torch function
   1559     with _pop_mode_temporarily() as mode:
-> 1560         result = mode.__torch_function__(public_api, types, args, kwargs)
   1561     if result is not NotImplemented:
   1562         return result

File ~/anaconda3/lib/python3.10/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/function.py:288, in BackwardCFunction.apply(self, *args)
    282     raise RuntimeError(
    283         "Implementing both 'backward' and 'vjp' for a custom "
    284         "Function is not allowed. You should only implement one "
    285         "of them."
    286     )
    287 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 288 return user_fn(self, *args)

File ~/anaconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:3232, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
   3230     out = CompiledFunctionBackward.apply(*all_args)
   3231 else:
-> 3232     out = call_compiled_backward()
   3233 return out

File ~/anaconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:3204, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
   3199     with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
   3200         CompiledFunction.compiled_bw = aot_config.bw_compiler(
   3201             bw_module, placeholder_list
   3202         )
-> 3204 out = call_func_with_args(
   3205     CompiledFunction.compiled_bw,
   3206     all_args,
   3207     steal_args=True,
   3208     disable_amp=disable_amp,
   3209 )
   3211 out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
   3212 return tuple(out)

File ~/anaconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1506, in call_func_with_args(f, args, steal_args, disable_amp)
   1504 with context():
   1505     if hasattr(f, "_boxed_call"):
-> 1506         out = normalize_as_list(f(args))
   1507     else:
   1508         # TODO: Please remove soon
   1509         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
   1510         warnings.warn(
   1511             "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
   1512             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
   1513             "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
   1514         )

File ~/anaconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    326 dynamic_ctx.__enter__()
    327 try:
--> 328     return fn(*args, **kwargs)
    329 finally:
    330     set_eval_frame(prior)

File ~/anaconda3/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs)
     15 @functools.wraps(fn)
     16 def inner(*args, **kwargs):
---> 17     return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torch/_inductor/codecache.py:374, in CompiledFxGraph.__call__(self, inputs)
    373 def __call__(self, inputs) -> Any:
--> 374     return self.get_current_callable()(inputs)

File ~/anaconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:628, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
    626 def run(new_inputs):
    627     copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 628     return model(new_inputs)

File ~/anaconda3/lib/python3.10/site-packages/torch/_inductor/codecache.py:401, in _run_from_cache(compiled_graph, inputs)
    391     from .codecache import PyCodeCache
    393     compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
    394         compiled_graph.cache_key,
    395         compiled_graph.artifact_path,
   (...)
    398         else (),
    399     ).call
--> 401 return compiled_graph.compiled_artifact(inputs)

File /tmp/torchinductor_isaac-aktam/fi/cfignuiw5cmdhtpeow6axfwfjocu44zvgnrf4rlockxh2k5th7f3.py:5638, in call(args)
   5636 del primals_4
   5637 buf473 = buf472[0]
-> 5638 assert_size_stride(buf473, (s0, 64, 56, 56), (200704, 1, 3584, 64))
   5639 buf474 = buf472[1]
   5640 assert_size_stride(buf474, (64, 64, 1, 1), (64, 1, 64, 64))

AssertionError: expected size 64==64, stride 3136==1 at dim=1
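
For what it's worth, the strides that the generated code checks for in that assertion, (200704, 1, 3584, 64) on a (N, 64, 56, 56) tensor, correspond to a channels_last (NHWC) layout, while 3136 (= 56*56) at dim=1 is what a contiguous (NCHW) tensor would have. A quick check (my own interpretation, not something from the tutorial):

import torch

x = torch.empty(2, 64, 56, 56)
print(x.stride())
# (200704, 3136, 56, 1)  -> contiguous NCHW; 3136 at dim=1 is the stride reported in the assertion
print(x.to(memory_format=torch.channels_last).stride())
# (200704, 1, 3584, 64)  -> channels_last NHWC; the strides the compiled backward checks for

So it looks like the compiled backward graph expects a channels_last tensor at that point but receives a contiguous one, although I don't know why that happens or how to fix it.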


I found a relevant post on GitHub (https://github.com/pytorch/pytorch/pull/91605), but I don't understand how to apply it to my code.

Note: the code works when torch.compile() is not used. With torch.compile() enabled, it crashes near the end of the first epoch (the progress bar above stops at 195/196).
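
For reference, here is a stripped-down sketch of the structure that fails (paraphrased, not the tutorial's exact code; the torchvision ResNet50, batch size, and 224x224 input size are my assumptions based on the (N, 64, 56, 56) / (64, 64, 1, 1) shapes in the traceback and the CIFAR10 dataloader names):

import torch
import torchvision
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed model/optimizer; the tutorial's exact choices may differ.
model = torchvision.models.resnet50(num_classes=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

compiled_model = torch.compile(model)  # training runs fine if this line is removed

# One training step of the same shape as train_step() in the tutorial,
# using a hypothetical random batch instead of the CIFAR10 dataloader.
X = torch.randn(256, 3, 224, 224, device=device)
y = torch.randint(0, 10, (256,), device=device)

y_pred = compiled_model(X)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()    # <- the AssertionError above is raised here when the model is compiled
optimizer.step()

Is there a compile setting or a change to the model/dataloader setup (for example, forcing a consistent memory format before compiling) that would avoid this assertion?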
