awslabs/gluonts

MeanEstimator should be able to ignore NaNs in training

Open

#2175 opened on Jul 21, 2022

Labels: enhancement, good first issue

Description

I created a simple random dataset, added a single NaN value to the first instance near the start of the last window before the forecast, and compared the evaluation metrics of three trivial predictors: Mean, Constant and Identity (code below). The ConstantValuePredictor is not affected at all: all metrics are normal. IdentityPredictor ends up with NaN metrics for the first instance and normal values elsewhere. This is not surprising, though one could debate whether that is the best way to handle the missing value. However, with MeanEstimator the metrics for every instance end up as NaN, which I found surprising. I think it would make more sense for MeanEstimator to ignore NaNs in the training data rather than let a single NaN break the predictor for all series.
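
A minimal numpy sketch of how one NaN could spoil every series. It assumes (from a reading of the GluonTS 0.9.x source, not confirmed in this issue) that MeanEstimator stacks the last prediction_length training values of each series and averages them element-wise across series, then uses that single path as the forecast for every series. The array sizes and values are made up for illustration.

```python
import numpy as np

# Toy stand-in for the training targets: 3 series with 6 time steps each.
# Assumption: MeanEstimator averages the last `prediction_length` values of
# all series element-wise (axis=0) and uses that one path for every series.
prediction_length = 3
targets = np.arange(18, dtype=float).reshape(3, 6)
targets[0, 4] = np.nan  # one NaN, at step 1 of the first series' last window

last_windows = targets[:, -prediction_length:]  # shape (3, prediction_length)

plain_mean = last_windows.mean(axis=0)  # the NaN propagates into step 1
nan_aware_mean = np.nanmean(last_windows, axis=0)  # the NaN is skipped

print(plain_mean)
print(nan_aware_mean)
```

Under that assumption, the NaN at index 121 in the repro falls at step 1 of the 24-step training window (indices 120 to 143), so the shared forecast contains a NaN for every item, which would explain the all-NaN item metrics. Averaging with np.nanmean is one way the estimator could skip it.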

Another thought: it would be interesting to have a variant that computes and applies the mean separately for each instance (time series), rather than over the whole training data.
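
A rough sketch of what such a per-instance variant could compute. The helper per_series_mean_forecast and its context_length parameter are hypothetical (not a GluonTS API); it returns a point forecast per series rather than the sample paths a real GluonTS predictor would produce.

```python
import numpy as np


def per_series_mean_forecast(targets, prediction_length, context_length=None):
    """Hypothetical helper (not a GluonTS API): forecast every series with
    the NaN-ignoring mean of its own recent history.

    targets: iterable of 1-D float arrays; series may differ in length.
    context_length: number of trailing values to average (whole series if None).
    Returns an array of shape (num_series, prediction_length).
    """
    forecasts = []
    for target in targets:
        target = np.asarray(target, dtype=float)
        window = target if context_length is None else target[-context_length:]
        valid = window[~np.isnan(window)]  # drop missing values
        # Only a window with no observed values at all yields a NaN forecast.
        level = valid.mean() if valid.size else np.nan
        forecasts.append(np.full(prediction_length, level))
    return np.stack(forecasts)


series = [np.array([1.0, np.nan, 3.0]), np.array([10.0, 20.0, 30.0, 40.0])]
print(per_series_mean_forecast(series, prediction_length=2))
```

Filtering with ~np.isnan instead of calling np.nanmean avoids the RuntimeWarning that np.nanmean emits for an all-NaN window, and a NaN in one series no longer affects any other series.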

To Reproduce

import numpy as np
import pandas as pd

from gluonts.dataset.common import ListDataset
from gluonts.evaluation import Evaluator, backtest
from gluonts.model.trivial.constant import ConstantValuePredictor
from gluonts.model.trivial.identity import IdentityPredictor
from gluonts.model.trivial.mean import MeanEstimator

# Following the GluonTS tutorial: https://ts.gluon.ai/stable/tutorials/forecasting/quick_start_tutorial.html#Custom-datasets
N = 10  # number of time series
T = 24 * 7  # number of time steps per series
prediction_length = 24
freq = "1H"
custom_dataset = np.random.normal(size=(N, T))

start = pd.Timestamp("01-01-2019", freq=freq)  # can be different for each time series

# Add a single NaN to the first series, near the start of the last training window
# (the training part has T - prediction_length = 144 points, so that window starts at index 120).
custom_dataset[0, 121] = np.nan

# Training data: everything except the last prediction_length points.
train_ds = ListDataset(
    [{'item_id': i, 'target': x, 'start': start} for i, x in enumerate(custom_dataset[:, :-prediction_length])],
    freq=freq,
)
# Test data: the full series; backtest holds out the last prediction_length points.
test_ds = ListDataset(
    [{'item_id': i, 'target': x, 'start': start} for i, x in enumerate(custom_dataset)],
    freq=freq,
)

# MeanEstimator has to be trained; the other two predictors need no training.
predictorM = MeanEstimator(prediction_length=prediction_length, freq=freq, num_samples=100).train(train_ds)
predictorC = ConstantValuePredictor(prediction_length=prediction_length, freq=freq, value=0, num_samples=100)
predictorI = IdentityPredictor(prediction_length=prediction_length, freq=freq, num_samples=100)

for predictor in [predictorC, predictorI, predictorM]:
    print(type(predictor))
    evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9], num_workers=None)
    agg_metrics, item_metrics = backtest.backtest_metrics(test_ds, predictor, evaluator)
    print([agg_metrics[x] for x in ['MSE', 'RMSE', 'NRMSE', 'ND']])
    print(item_metrics[['item_id', 'MSE', 'abs_error']])
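
As a stopgap until MeanEstimator handles NaNs itself, one could impute missing values before building the datasets. The helper fill_nans_with_series_mean below is hypothetical (not part of GluonTS) and fills each NaN with the mean of its own series; it is a workaround sketch, not a fix for the estimator.

```python
import numpy as np


def fill_nans_with_series_mean(data):
    """Hypothetical pre-processing helper (not part of GluonTS).

    Replaces every NaN in a (num_series, num_steps) array with the mean of
    the non-missing values in the same row, i.e. the same time series.
    A row that is entirely NaN stays NaN (np.nanmean warns in that case).
    """
    filled = np.array(data, dtype=float, copy=True)  # never modify the input
    row_means = np.nanmean(filled, axis=1)  # one NaN-ignoring mean per series
    rows, cols = np.nonzero(np.isnan(filled))  # positions of the missing values
    filled[rows, cols] = row_means[rows]
    return filled


example = np.array([[1.0, np.nan, 3.0],
                    [4.0, 5.0, 6.0]])
print(fill_nans_with_series_mean(example))
```

In the repro it would be applied to custom_dataset before train_ds and test_ds are built. Note that the row mean then also includes the held-out last 24 points; for a leakage-free comparison one would compute it on the training part only.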

Error message or code output


<class 'gluonts.model.trivial.constant.ConstantValuePredictor'>
[1.1721360819194713, 1.082652336588007, 1.2891094480276746, 1.0000000025297278]
   item_id       MSE  abs_error
0      NaN  1.646284  25.712500
1      NaN  1.810231  25.060539
2      NaN  0.934014  17.689958
3      NaN  1.138792  18.325633
4      NaN  0.868155  18.164898
5      NaN  1.123610  17.956955
6      NaN  0.759920  16.268135
7      NaN  1.408551  21.661034
8      NaN  1.134947  22.545605
9      NaN  0.896856  18.177583

<class 'gluonts.model.trivial.identity.IdentityPredictor'>
[2.263645772580747, 1.5045417151347937, 1.7914513037901383, 1.2935801589632892]
   item_id       MSE  abs_error
0      0.0       NaN        NaN
1      1.0  2.201055  28.548775
2      2.0  1.534184  25.064253
3      3.0  3.204268  32.571938
4      4.0  1.617391  25.096828
5      5.0  1.972552  26.592113
6      6.0  2.252131  31.927937
7      7.0  2.406640  28.842175
8      8.0  2.678541  31.139656
9      9.0  2.506050  30.954016

<class 'gluonts.model.trivial.constant.ConstantPredictor'>  (MeanEstimator produces a ConstantPredictor, hence this class name)
[nan, nan, nan, 0.0]
   item_id  MSE  abs_error
0      0.0  NaN        NaN
1      1.0  NaN        NaN
2      2.0  NaN        NaN
3      3.0  NaN        NaN
4      4.0  NaN        NaN
5      5.0  NaN        NaN
6      6.0  NaN        NaN
7      7.0  NaN        NaN
8      8.0  NaN        NaN
9      9.0  NaN        NaN

Environment

  • Operating system: linux
  • Python version: 3.7.10
  • GluonTS version: 0.9.4
  • MXNet version: 1.7.0
