plot_prediction errorbar error #1145

Description

@holymonkey808
  • PyTorch-Forecasting version: 0.10.3
  • PyTorch version: 1.12.1
  • Python version: 3.10.6
  • Operating System: Amazon Linux 2

Expected behavior

Running the TFT tutorial with max_prediction_length changed to 1 raises an error, while max_prediction_length > 1 works fine.
Version 0.10.1 works fine with max_prediction_length = 1.
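
For reference, a minimal sketch of the reproduction — synthetic data stands in for the tutorial's Stallion dataset, so the column names and sizes below are illustrative, not taken from the tutorial:

```python
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# toy panel data: three series, 60 time steps each (illustrative stand-in)
data = pd.DataFrame(
    {
        "time_idx": np.tile(np.arange(60), 3),
        "series": np.repeat(["a", "b", "c"], 60),
        "value": np.random.default_rng(0).normal(size=180).astype("float32"),
    }
)

max_prediction_length = 1  # <- 1 fails on 0.10.3; > 1 (or 0.10.1) is fine
training = TimeSeriesDataSet(
    data[data.time_idx <= 48],
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=24,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["value"],
)
validation = TimeSeriesDataSet.from_dataset(
    training, data, predict=True, stop_randomization=True
)
train_dataloader = training.to_dataloader(train=True, batch_size=32)
val_dataloader = validation.to_dataloader(train=False, batch_size=32)

# log_interval > 0 is what routes validation through plot_prediction,
# as in the tutorial (which sets log_interval=10)
tft = TemporalFusionTransformer.from_dataset(training, log_interval=1)
trainer = pl.Trainer(max_epochs=1, limit_train_batches=5)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
```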

I get the following error:

ValueError Traceback (most recent call last)
Input In [18], in <cell line: 2>()
1 # fit network
----> 2 trainer.fit(
3 tft,
4 train_dataloaders=train_dataloader,
5 val_dataloaders=val_dataloader,
6 )

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:696, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
677 r"""
678 Runs the full optimization routine.
679
(...)
693 datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
694 """
695 self.strategy.model = model
--> 696 self._call_and_handle_interrupt(
697 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
698 )

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
648 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
649 else:
--> 650 return trainer_fn(*args, **kwargs)
651 # TODO(awaelchli): Unify both exceptions below, where KeyboardError doesn't re-raise
652 except KeyboardInterrupt as exception:

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:735, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
731 ckpt_path = ckpt_path or self.resume_from_checkpoint
732 self._ckpt_path = self.__set_ckpt_path(
733 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
734 )
--> 735 results = self._run(model, ckpt_path=self.ckpt_path)
737 assert self.state.stopped
738 self.training = False

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1166, in Trainer._run(self, model, ckpt_path)
1162 self._checkpoint_connector.restore_training_state()
1164 self._checkpoint_connector.resume_end()
-> 1166 results = self._run_stage()
1168 log.detail(f"{self.__class__.__name__}: trainer tearing down")
1169 self._teardown()

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1252, in Trainer._run_stage(self)
1250 if self.predicting:
1251 return self._run_predict()
-> 1252 return self._run_train()

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1283, in Trainer._run_train(self)
1280 self.fit_loop.trainer = self
1282 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1283 self.fit_loop.run()

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:271, in FitLoop.advance(self)
267 self._data_fetcher.setup(
268 dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
269 )
270 with self.trainer.profiler.profile("run_training_epoch"):
--> 271 self._outputs = self.epoch_loop.run(self._data_fetcher)

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:201, in Loop.run(self, *args, **kwargs)
199 self.on_advance_start(*args, **kwargs)
200 self.advance(*args, **kwargs)
--> 201 self.on_advance_end()
202 self._restarting = False
203 except StopIteration:

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:241, in TrainingEpochLoop.on_advance_end(self)
239 if should_check_val:
240 self.trainer.validating = True
--> 241 self._run_validation()
242 self.trainer.training = True
244 # update plateau LR scheduler after metrics are logged

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:299, in TrainingEpochLoop._run_validation(self)
296 self.val_loop._reload_evaluation_dataloaders()
298 with torch.no_grad():
--> 299 self.val_loop.run()

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:155, in EvaluationLoop.advance(self, *args, **kwargs)
153 if self.num_dataloaders > 1:
154 kwargs["dataloader_idx"] = dataloader_idx
--> 155 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
157 # store batch level output per dataloader
158 self._outputs.append(dl_outputs)

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
198 try:
199 self.on_advance_start(*args, **kwargs)
--> 200 self.advance(*args, **kwargs)
201 self.on_advance_end()
202 self._restarting = False

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:143, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
140 self.batch_progress.increment_started()
142 # lightning module methods
--> 143 output = self._evaluation_step(**kwargs)
144 output = self._evaluation_step_end(output)
146 self.batch_progress.increment_processed()

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:240, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
229 """The evaluation step (validation_step or test_step depending on the trainer's state).
230
231 Args:
(...)
237 the outputs of the step
238 """
239 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 240 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
242 return output

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1704, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
1701 return
1703 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1704 output = fn(*args, **kwargs)
1706 # restore current_fx when nested context
1707 pl_module._current_fx_name = prev_fx_name

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:370, in Strategy.validation_step(self, *args, **kwargs)
368 with self.precision_plugin.val_step_context():
369 assert isinstance(self.model, ValidationStep)
--> 370 return self.model.validation_step(*args, **kwargs)

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:420, in BaseModel.validation_step(self, batch, batch_idx)
418 x, y = batch
419 log, out = self.step(x, y, batch_idx)
--> 420 log.update(self.create_log(x, y, out, batch_idx))
421 return log

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:520, in TemporalFusionTransformer.create_log(self, x, y, out, batch_idx, **kwargs)
519 def create_log(self, x, y, out, batch_idx, **kwargs):
--> 520 log = super().create_log(x, y, out, batch_idx, **kwargs)
521 if self.log_interval > 0:
522 log["interpretation"] = self._log_interpretation(out)

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:469, in BaseModel.create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs)
467 self.log_metrics(x, y, out, prediction_kwargs=prediction_kwargs)
468 if self.log_interval > 0:
--> 469 self.log_prediction(
470 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs
471 )
472 return {}

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:717, in BaseModel.log_prediction(self, x, out, batch_idx, **kwargs)
715 log_indices = [0]
716 for idx in log_indices:
--> 717 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
718 tag = f"{self.current_stage} prediction"
719 if self.training:

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:711, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
694 """
695 Plot actuals vs prediction and attention
696
(...)
707 plt.Figure: matplotlib figure
708 """
710 # plot prediction as normal
--> 711 fig = super().plot_prediction(
712 x,
713 out,
714 idx=idx,
715 add_loss_to_title=add_loss_to_title,
716 show_future_observed=show_future_observed,
717 ax=ax,
718 **kwargs,
719 )
721 # add attention on secondary axis
722 if plot_attention:

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:832, in BaseModel.plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
830 else:
831 quantiles = torch.tensor([[y_quantile[0, i]], [y_quantile[0, -i - 1]]])
--> 832 ax.errorbar(
833 x_pred,
834 y[[-n_pred]],
835 yerr=quantiles - y[-n_pred],
836 c=pred_color,
837 capsize=1.0,
838 )
840 if add_loss_to_title is not False:
841 if isinstance(add_loss_to_title, bool):

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
1420 @functools.wraps(func)
1421 def inner(ax, *args, data=None, **kwargs):
1422 if data is None:
-> 1423 return func(ax, *map(sanitize_sequence, args), **kwargs)
1425 bound = new_sig.bind(ax, *args, **kwargs)
1426 auto_label = (bound.arguments.get(label_namer)
1427 or bound.kwargs.get(label_namer))

File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/axes/_axes.py:3587, in Axes.errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs)
3584 res = np.zeros(err.shape, dtype=bool) # Default in case of nan
3585 if np.any(np.less(err, -err, out=res, where=(err == err))):
3586 # like err<0, but also works for timedelta and nan.
-> 3587 raise ValueError(
3588 f"'{dep_axis}err' must not contain negative values")
3589 # This is like
3590 # elow, ehigh = np.broadcast_to(...)
3591 # return dep - elow * ~lolims, dep + ehigh * ~uplims
3592 # except that broadcast_to would strip units.
3593 low, high = dep + np.row_stack([-(1 - lolims), 1 - uplims]) * err

ValueError: 'yerr' must not contain negative values
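
The failing call is the single-point errorbar branch in BaseModel.plot_prediction (base_model.py:832 above), which is only reached when there is one prediction step. As far as I can tell, yerr=quantiles - y[-n_pred] passes signed offsets, and the lower-quantile offset is negative, which matplotlib rejects: Axes.errorbar expects non-negative error magnitudes. A minimal sketch of the constraint — the tensor values below are made up for illustration, not taken from the model:

```python
import matplotlib.pyplot as plt
import torch

y_pred = torch.tensor([10.0])              # point prediction
quantiles = torch.tensor([[8.0], [12.0]])  # lower / upper prediction quantiles

fig, ax = plt.subplots()

# Mirrors the failing call: quantiles - y_pred == [[-2.], [2.]], and the
# negative lower offset raises "ValueError: 'yerr' must not contain
# negative values".
# ax.errorbar([0], y_pred, yerr=quantiles - y_pred, capsize=1.0)

# Passing the offsets as non-negative distances from the prediction plots
# the intended interval without the error:
ax.errorbar([0], y_pred, yerr=(quantiles - y_pred).abs(), capsize=1.0)
```

If that reading is right, taking absolute offsets in plot_prediction would avoid the error. As a workaround in the meantime, leaving log_interval at its default of -1 when constructing the model skips plot_prediction entirely, since create_log only plots when log_interval > 0 (base_model.py:469 above).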
