Description
- PyTorch-Forecasting version: 0.10.3
- PyTorch version: 1.12.1
- Python version: 3.10.6
- Operating System: Amazon Linux 2
Expected behavior
Running the TFT tutorial with max_prediction_length changed to 1 raises an error, while max_prediction_length > 1 works fine. Version 0.10.1 works fine with max_prediction_length = 1. A minimal reproduction sketch follows.
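This is not the tutorial's Stallion setup: the synthetic DataFrame, series names, and trimmed hyperparameters below are stand-ins, and only `max_prediction_length = 1` is the detail that matters.

```python
# Hypothetical minimal reproduction: synthetic data stands in for the
# tutorial's Stallion dataset; only max_prediction_length = 1 matters.
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

data = pd.DataFrame(
    {
        "time_idx": np.tile(np.arange(60), 2),
        "series": np.repeat(["a", "b"], 60),
        "value": np.random.rand(120),
    }
)

max_prediction_length = 1  # > 1 trains fine; == 1 fails on 0.10.3

training = TimeSeriesDataSet(
    data[data.time_idx <= 50],
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=24,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["value"],
)
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# log_interval > 0 enables the prediction plots in which the error occurs
tft = TemporalFusionTransformer.from_dataset(training, log_interval=1)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
trainer.fit(
    tft,
    train_dataloaders=training.to_dataloader(train=True, batch_size=16),
    val_dataloaders=validation.to_dataloader(train=False, batch_size=16),
)
```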
Getting the following error:
ValueError Traceback (most recent call last)
Input In [18], in <cell line: 2>()
1 # fit network
----> 2 trainer.fit(
3 tft,
4 train_dataloaders=train_dataloader,
5 val_dataloaders=val_dataloader,
6 )
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:696, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
677 r"""
678 Runs the full optimization routine.
679
(...)
693 datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
694 """
695 self.strategy.model = model
--> 696 self._call_and_handle_interrupt(
697 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
698 )
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
648 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
649 else:
--> 650 return trainer_fn(*args, **kwargs)
651 # TODO(awaelchli): Unify both exceptions below, where KeyboardError doesn't re-raise
652 except KeyboardInterrupt as exception:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:735, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
731 ckpt_path = ckpt_path or self.resume_from_checkpoint
732 self._ckpt_path = self.__set_ckpt_path(
733 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
734 )
--> 735 results = self._run(model, ckpt_path=self.ckpt_path)
737 assert self.state.stopped
738 self.training = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1166, in Trainer._run(self, model, ckpt_path)
1162 self._checkpoint_connector.restore_training_state()
1164 self._checkpoint_connector.resume_end()
-> 1166 results = self._run_stage()
1168 log.detail(f"{self.__class__.__name__}: trainer tearing down")
1169 self._teardown()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1252, in Trainer._run_stage(self)
1250 if self.predicting:
1251 return self._run_predict()
-> 1252 return self._run_train()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1283, in Trainer._run_train(self)
1280 self.fit_loop.trainer = self
1282 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1283 self.fit_loop.run()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:271, in FitLoop.advance(self)
267 self._data_fetcher.setup(
268 dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
269 )
270 with self.trainer.profiler.profile("run_training_epoch"):
--> 271 self._outputs = self.epoch_loop.run(self._data_fetcher)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:201, in Loop.run(self, *args, **kwargs)
199 self.on_advance_start(*args, **kwargs)
200 self.advance(*args, **kwargs)
--> 201 self.on_advance_end()
202 self._restarting = False
203 except StopIteration:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:241, in TrainingEpochLoop.on_advance_end(self)
239 if should_check_val:
240 self.trainer.validating = True
--> 241 self._run_validation()
242 self.trainer.training = True
244 # update plateau LR scheduler after metrics are logged
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:299, in TrainingEpochLoop._run_validation(self)
296 self.val_loop._reload_evaluation_dataloaders()
298 with torch.no_grad():
--> 299 self.val_loop.run()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:155, in EvaluationLoop.advance(self, *args, **kwargs)
153 if self.num_dataloaders > 1:
154 kwargs["dataloader_idx"] = dataloader_idx
--> 155 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
157 # store batch level output per dataloader
158 self._outputs.append(dl_outputs)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:143, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
140 self.batch_progress.increment_started()
142 # lightning module methods
--> 143 output = self._evaluation_step(**kwargs)
144 output = self._evaluation_step_end(output)
146 self.batch_progress.increment_processed()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:240, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
229 """The evaluation step (validation_step or test_step depending on the trainer's state).
230
231 Args:
(...)
237 the outputs of the step
238 """
239 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 240 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
242 return output
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1704, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
1701 return
1703 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1704 output = fn(*args, **kwargs)
1706 # restore current_fx when nested context
1707 pl_module._current_fx_name = prev_fx_name
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:370, in Strategy.validation_step(self, *args, **kwargs)
368 with self.precision_plugin.val_step_context():
369 assert isinstance(self.model, ValidationStep)
--> 370 return self.model.validation_step(*args, **kwargs)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:420, in BaseModel.validation_step(self, batch, batch_idx)
418 x, y = batch
419 log, out = self.step(x, y, batch_idx)
--> 420 log.update(self.create_log(x, y, out, batch_idx))
421 return log
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:520, in TemporalFusionTransformer.create_log(self, x, y, out, batch_idx, **kwargs)
519 def create_log(self, x, y, out, batch_idx, **kwargs):
--> 520 log = super().create_log(x, y, out, batch_idx, **kwargs)
521 if self.log_interval > 0:
522 log["interpretation"] = self._log_interpretation(out)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:469, in BaseModel.create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs)
467 self.log_metrics(x, y, out, prediction_kwargs=prediction_kwargs)
468 if self.log_interval > 0:
--> 469 self.log_prediction(
470 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs
471 )
472 return {}
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:717, in BaseModel.log_prediction(self, x, out, batch_idx, **kwargs)
715 log_indices = [0]
716 for idx in log_indices:
--> 717 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
718 tag = f"{self.current_stage} prediction"
719 if self.training:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:711, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
694 """
695 Plot actuals vs prediction and attention
696
(...)
707 plt.Figure: matplotlib figure
708 """
710 # plot prediction as normal
--> 711 fig = super().plot_prediction(
712 x,
713 out,
714 idx=idx,
715 add_loss_to_title=add_loss_to_title,
716 show_future_observed=show_future_observed,
717 ax=ax,
718 **kwargs,
719 )
721 # add attention on secondary axis
722 if plot_attention:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:832, in BaseModel.plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
830 else:
831 quantiles = torch.tensor([[y_quantile[0, i]], [y_quantile[0, -i - 1]]])
--> 832 ax.errorbar(
833 x_pred,
834 y[[-n_pred]],
835 yerr=quantiles - y[-n_pred],
836 c=pred_color,
837 capsize=1.0,
838 )
840 if add_loss_to_title is not False:
841 if isinstance(add_loss_to_title, bool):
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
1420 @functools.wraps(func)
1421 def inner(ax, *args, data=None, **kwargs):
1422 if data is None:
-> 1423 return func(ax, *map(sanitize_sequence, args), **kwargs)
1425 bound = new_sig.bind(ax, *args, **kwargs)
1426 auto_label = (bound.arguments.get(label_namer)
1427 or bound.kwargs.get(label_namer))
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/axes/_axes.py:3587, in Axes.errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs)
3584 res = np.zeros(err.shape, dtype=bool) # Default in case of nan
3585 if np.any(np.less(err, -err, out=res, where=(err == err))):
3586 # like err<0, but also works for timedelta and nan.
-> 3587 raise ValueError(
3588 f"'{dep_axis}err' must not contain negative values")
3589 # This is like
3590 # elow, ehigh = np.broadcast_to(...)
3591 # return dep - elow * ~lolims, dep + ehigh * ~uplims
3592 # except that broadcast_to would strip units.
3593 low, high = dep + np.row_stack([-(1 - lolims), 1 - uplims]) * err
ValueError: 'yerr' must not contain negative values
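The failure comes from the frame at base_model.py:832 above: with a single prediction step, `plot_prediction` draws the quantile band with `ax.errorbar` and passes `yerr=quantiles - y[-n_pred]`, a signed difference whose lower-quantile row is negative. matplotlib requires `yerr` to contain non-negative error magnitudes, which would explain why only `max_prediction_length = 1` is affected (longer horizons appear to take a different plotting path). A minimal demonstration of the matplotlib constraint, with made-up values:

```python
import matplotlib.pyplot as plt
import numpy as np

y_pred = np.array([10.0])
quantiles = np.array([[8.0], [12.0]])  # [lower, upper] quantile values

fig, ax = plt.subplots()

# Signed differences: the lower-quantile row (8 - 10 = -2) is negative, so
# this raises ValueError: 'yerr' must not contain negative values
# ax.errorbar([0], y_pred, yerr=quantiles - y_pred, capsize=1.0)

# matplotlib expects non-negative error magnitudes instead:
ax.errorbar([0], y_pred, yerr=np.abs(quantiles - y_pred), capsize=1.0)
```

If that reading is right, taking `(quantiles - y[-n_pred]).abs()` before the `errorbar` call would satisfy matplotlib. As a workaround until a fix lands, constructing the model with `log_interval=0` (or leaving it at its default) should skip the plotting path entirely, since `create_log` only calls `log_prediction` when `log_interval > 0` (see the frame at base_model.py:469 above).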