GroupedProphet example notebook

This notebook provides an example of the GroupedProphet API, complete with a self-contained data generator.

[1]:
import os
import sys
import logging
import itertools
import pandas as pd
import numpy as np
import string
import random
from datetime import timedelta, datetime
from collections import namedtuple
from diviner import GroupedProphet

Importing plotly failed. Interactive plots will not work.

Create a synthetic grouped data generator for time series

This generator creates seasonal data for multiple series. Each series in the returned DataFrame is tagged with identifying key columns, and the names of those key columns are returned alongside the data.

[ ]:


def _generate_time_series(series_size: int):
    # Log-normal noise with randomly drawn location and scale
    residuals = np.random.lognormal(
        mean=np.random.uniform(low=0.5, high=3.0),
        sigma=np.random.uniform(low=0.6, high=0.98),
        size=series_size,
    )
    # Quadratic trend evaluated over a randomly chosen range
    trend = [
        np.polyval([23.0, 1.0, 5], x)
        for x in np.linspace(
            start=0, stop=np.random.randint(low=0, high=4), num=series_size
        )
    ]
    # Sinusoidal seasonality with a fixed amplitude and offset
    seasonality = [
        90 * np.sin(2 * np.pi * 1000 * (i / (series_size * 200))) + 40
        for i in np.arange(0, series_size)
    ]
    return residuals + trend + seasonality + np.random.uniform(low=20.0, high=1000.0)


def _generate_grouping_columns(column_count: int, series_count: int):
    candidate_list = list(string.ascii_uppercase)
    # Each series gets a unique ordered combination of letters as its key values
    candidates = random.sample(
        list(itertools.permutations(candidate_list, column_count)), series_count
    )
    column_names = sorted([f"key{x}" for x in range(column_count)], reverse=True)
    return [dict(zip(column_names, entries)) for entries in candidates]


def _generate_raw_df(
    column_count: int,
    series_count: int,
    series_size: int,
    start_dt: str,
    days_period: int,
):
    candidates = _generate_grouping_columns(column_count, series_count)
    # NOTE: "%M" is the minutes directive ("%m" is months); kept as-is to match
    # the recorded outputs in this notebook.
    start_date = datetime.strptime(start_dt, "%Y-%M-%d")
    dates = np.arange(
        start_date,
        start_date + timedelta(days=series_size * days_period),
        timedelta(days=days_period),
    )
    df_collection = []
    for entry in candidates:
        generated_series = _generate_time_series(series_size)
        series_dict = {"ds": dates, "y": generated_series}
        series_df = pd.DataFrame.from_dict(series_dict)
        # Tag every row of the series with its grouping key values
        for column, value in entry.items():
            series_df[column] = value
        df_collection.append(series_df)
    return pd.concat(df_collection)


def generate_example_data(
    column_count: int,
    series_count: int,
    series_size: int,
    start_dt: str,
    days_period: int = 1,
):
    Structure = namedtuple("Structure", "df key_columns")
    data = _generate_raw_df(
        column_count, series_count, series_size, start_dt, days_period
    )
    # Every column other than "ds" and "y" is a grouping key column
    key_columns = list(data.columns)
    for key in ["ds", "y"]:
        key_columns.remove(key)
    return Structure(data, key_columns)
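The seasonality term in the generator simplifies to 90 * sin(2π * 5 * i / series_size) + 40, so every series contains exactly five full cycles regardless of its length. The sketch below reproduces only that term with the standard library to make the period visible; the helper name `seasonality` and the choice of `series_size` (taken from the data generation call later in this notebook) are illustrative assumptions, not part of the notebook's code.

```python
import math


def seasonality(i: int, series_size: int) -> float:
    # Same expression as the generator's seasonality term, using math instead of numpy
    return 90 * math.sin(2 * math.pi * 1000 * (i / (series_size * 200))) + 40


series_size = 365 * 4

# 1000 / 200 = 5 cycles across the whole series, so one cycle spans series_size / 5 points
period = series_size / 5

values = [seasonality(i, series_size) for i in range(series_size)]

print(period)
print(min(values), max(values))
```

With four years of daily data, one cycle therefore lasts 292 days, which does not line up with a calendar year.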

Suppress optimizer stdout messaging

[2]:
class suppress_stdout_stderr():
    """
    Context manager to prevent the PyStan solver from filling stdout with a wall of text
    """

    def __init__(self):
        self.devnull_stdout = os.open(os.devnull, os.O_RDWR)
        self.devnull_stderr = os.open(os.devnull, os.O_RDWR)
        self.stdout = os.dup(1)
        self.stderr = os.dup(2)

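    def __enter__(self):
        os.dup2(self.devnull_stdout, 1)
        os.dup2(self.devnull_stderr, 2)

    def __exit__(self, *_):
        # Restore the original descriptors, then release every descriptor opened here
        os.dup2(self.stdout, 1)
        os.dup2(self.stderr, 2)

        os.close(self.devnull_stdout)
        os.close(self.devnull_stderr)
        os.close(self.stdout)
        os.close(self.stderr)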
[3]:
logging.getLogger().setLevel(logging.CRITICAL)
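The context manager above redirects at the file-descriptor level rather than by swapping `sys.stdout`, which is what allows it to silence output written by compiled code. The following is a minimal sketch of the same save, redirect, and restore sequence, demonstrated on a pipe so that the effect can be inspected without touching the real stdout. The class name `suppress_fd` is hypothetical and is not part of this notebook or of diviner.

```python
import os


class suppress_fd:
    """Hypothetical helper: silence one file descriptor and restore it on exit."""

    def __init__(self, fd: int):
        self.fd = fd

    def __enter__(self):
        # Keep a duplicate of the original target so it can be restored later
        self.saved = os.dup(self.fd)
        self.devnull = os.open(os.devnull, os.O_RDWR)
        os.dup2(self.devnull, self.fd)
        return self

    def __exit__(self, *_):
        # Point the descriptor back at its original target and release the extras
        os.dup2(self.saved, self.fd)
        os.close(self.devnull)
        os.close(self.saved)


read_end, write_end = os.pipe()

with suppress_fd(write_end):
    os.write(write_end, b"hidden")  # lands in os.devnull

os.write(write_end, b"visible")  # lands in the pipe again
os.close(write_end)

captured = os.read(read_end, 1024)
os.close(read_end)
print(captured)
```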

Generate the stacked example data

[4]:
generated = generate_example_data(
    column_count=3,
    series_count=40,
    series_size=365*4,
    start_dt="2018-02-02",
    days_period=1
)
train = generated.df
key_columns = generated.key_columns

View the data structure

As can be seen, the elements that are required to define the stacked series are here:

* column ‘y’ - the series data that we’re going to use for training of the models
* column ‘ds’ - the datetime values that correspond to each ‘y’ entry
* columns [‘key2’, ‘key1’, ‘key0’], the combination of which defines a unique series.
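One way to confirm that a stacked DataFrame has this shape is to group by the key columns and check that each group holds one row per date. The sketch below uses a small hand-built frame (`toy` and `toy_keys` are illustrative names with made-up values, not the generated data) so that it runs on its own.

```python
import pandas as pd

# A tiny stacked frame: two series, two dates each
toy = pd.DataFrame(
    {
        "ds": pd.to_datetime(["2018-01-02", "2018-01-03"] * 2),
        "y": [1.0, 2.0, 3.0, 4.0],
        "key2": ["H", "H", "Q", "Q"],
        "key1": ["D", "D", "Z", "Z"],
        "key0": ["L", "L", "O", "O"],
    }
)
toy_keys = ["key2", "key1", "key0"]

# Number of rows in each distinct series
rows_per_series = toy.groupby(toy_keys).size()

# A well-formed stacked frame has no repeated date within a series
has_duplicate_dates = bool(toy.duplicated(subset=toy_keys + ["ds"]).any())

print(rows_per_series)
print(has_duplicate_dates)
```

Applied to the generated data, the same check would be expected to report 40 series of 1460 rows each, which matches the 58400 rows in the frame.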

[5]:
train
[5]:
ds y key2 key1 key0
0 2018-01-02 00:02:00 139.540404 H D L
1 2018-01-03 00:02:00 145.638506 H D L
2 2018-01-04 00:02:00 144.510473 H D L
3 2018-01-05 00:02:00 145.096257 H D L
4 2018-01-06 00:02:00 145.901135 H D L
... ... ... ... ... ...
1455 2021-12-27 00:02:00 81.341661 Q Z O
1456 2021-12-28 00:02:00 82.385532 Q Z O
1457 2021-12-29 00:02:00 85.898131 Q Z O
1458 2021-12-30 00:02:00 88.452676 Q Z O
1459 2021-12-31 00:02:00 88.664839 Q Z O

58400 rows × 5 columns
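The `ds` values above begin at 2018-01-02 00:02:00 even though `start_dt` was "2018-02-02". This follows from the format string in the generator, which uses `%M` (minutes) where `%m` (month) would normally be expected. The sketch below shows the difference using only the standard library; the variable names are illustrative.

```python
from datetime import datetime

start_dt = "2018-02-02"

# "%M" is the minutes directive, so the middle field is read as minutes
# and the month falls back to its default of January.
parsed_with_minutes = datetime.strptime(start_dt, "%Y-%M-%d")

# "%m" is the month directive and yields the date the string suggests.
parsed_with_month = datetime.strptime(start_dt, "%Y-%m-%d")

print(parsed_with_minutes)  # 2018-01-02 00:02:00
print(parsed_with_month)    # 2018-02-02 00:00:00
```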

Fit the GroupedProphet models on each of the distinct groups defined by the ‘key_columns’ argument.

[6]:
with suppress_stdout_stderr():  # Suppress stdout from PyStan
    logging.getLogger("prophet").setLevel(logging.CRITICAL)  # Suppress INFO warnings
    grouped_prophet_model = GroupedProphet(n_changepoints=30, uncertainty_samples=0).fit(
        train, key_columns
    )
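Conceptually, fitting on grouped data means splitting the stacked DataFrame by its key columns and training one model per resulting series. The sketch below illustrates that idea with a trivial stand-in model; it is not diviner's implementation, and `MeanModel`, `toy`, and `models` are hypothetical names introduced only for this illustration.

```python
import pandas as pd

toy = pd.DataFrame(
    {
        "ds": pd.to_datetime(["2018-01-02", "2018-01-03"] * 2),
        "y": [10.0, 20.0, 100.0, 300.0],
        "key1": ["A", "A", "B", "B"],
        "key0": ["X", "X", "Y", "Y"],
    }
)


class MeanModel:
    """Hypothetical stand-in for a per-series forecaster."""

    def fit(self, frame: pd.DataFrame) -> "MeanModel":
        # "Training" here is just remembering the average of the series
        self.level = frame["y"].mean()
        return self


# One fitted model per distinct combination of key values
models = {
    group_key: MeanModel().fit(group_df)
    for group_key, group_df in toy.groupby(["key1", "key0"])
}

for group_key, model in models.items():
    print(group_key, model.level)
```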

Extract the parameters from the trained models

[7]:
params = grouped_prophet_model.extract_model_params()
params
[7]:
grouping_key_columns key2 key1 key0 changepoint_prior_scale changepoint_range component_modes country_holidays daily_seasonality extra_regressors ... seasonality_prior_scale specified_changepoints stan_backend start t_scale train_holiday_names uncertainty_samples weekly_seasonality y_scale yearly_seasonality
0 (key2, key1, key0) A I S 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 981.029852 auto
1 (key2, key1, key0) A Q J 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 616.079172 auto
2 (key2, key1, key0) B O Q 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 903.068555 auto
3 (key2, key1, key0) C O Y 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1031.305845 auto
4 (key2, key1, key0) C T D 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1043.772833 auto
5 (key2, key1, key0) D A C 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1195.890246 auto
6 (key2, key1, key0) D M U 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 628.177374 auto
7 (key2, key1, key0) F O H 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1247.246604 auto
8 (key2, key1, key0) G U L 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 279.778234 auto
9 (key2, key1, key0) H D L 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 313.89009 auto
10 (key2, key1, key0) H R X 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 310.707764 auto
11 (key2, key1, key0) I F C 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1126.160902 auto
12 (key2, key1, key0) I U H 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1287.238874 auto
13 (key2, key1, key0) I W H 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 606.188878 auto
14 (key2, key1, key0) I Y P 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 301.869476 auto
15 (key2, key1, key0) J A S 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1364.239088 auto
16 (key2, key1, key0) J T G 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 980.384209 auto
17 (key2, key1, key0) K H Y 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 917.979213 auto
18 (key2, key1, key0) L D E 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1079.828476 auto
19 (key2, key1, key0) L S E 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 915.850019 auto
20 (key2, key1, key0) L Z S 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 828.615341 auto
21 (key2, key1, key0) M D T 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 336.260444 auto
22 (key2, key1, key0) M O X 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 842.892671 auto
23 (key2, key1, key0) M W L 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 242.552077 auto
24 (key2, key1, key0) N Q R 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 475.901684 auto
25 (key2, key1, key0) N W R 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1078.139666 auto
26 (key2, key1, key0) P F H 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 818.370678 auto
27 (key2, key1, key0) Q X Y 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1065.844067 auto
28 (key2, key1, key0) Q Z O 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 182.534667 auto
29 (key2, key1, key0) R I J 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 592.192065 auto
30 (key2, key1, key0) S B T 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 630.32091 auto
31 (key2, key1, key0) S C Z 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1013.812286 auto
32 (key2, key1, key0) T A Y 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 785.534176 auto
33 (key2, key1, key0) T B P 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 804.37043 auto
34 (key2, key1, key0) T Z F 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 937.109632 auto
35 (key2, key1, key0) U T S 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 744.092648 auto
36 (key2, key1, key0) W G R 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1011.435103 auto
37 (key2, key1, key0) X I A 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 772.332066 auto
38 (key2, key1, key0) X Z N 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1300.590936 auto
39 (key2, key1, key0) Y J F 0.05 0.8 {'additive': ['yearly', 'weekly', 'additive_te... None auto {} ... 10.0 False <prophet.models.PyStanBackend object at 0x7fd8... 2018-01-02 00:02:00 1459 days None 0 auto 1074.841913 auto

40 rows × 29 columns

Cross validate each model to return the scoring metrics for each group

[8]:
with suppress_stdout_stderr():
    metrics = grouped_prophet_model.cross_validate_and_score(
        horizon="30 days",
        period="365 days",
        initial="730 days",
        parallel="threads",
        rolling_window=0.05,
        monthly=False
    )
[9]:
metrics
[9]:
grouping_key_columns key2 key1 key0 mse rmse mae mape mdape smape
0 (key2, key1, key0) A I S 262.540068 15.062691 13.915674 0.016464 0.015734 0.016646
1 (key2, key1, key0) A Q J 469.382836 20.242075 19.657388 0.041805 0.042604 0.042899
2 (key2, key1, key0) B O Q 2817.155728 50.511383 42.431716 0.055081 0.048438 0.057438
3 (key2, key1, key0) C O Y 585.259877 22.515334 21.410191 0.024674 0.023593 0.025063
4 (key2, key1, key0) C T D 757.678167 24.916772 23.975279 0.026917 0.025814 0.027394
5 (key2, key1, key0) D A C 1798.440133 39.779557 37.261880 0.037357 0.034757 0.038267
6 (key2, key1, key0) D M U 2947.507801 51.200966 44.124248 0.163164 0.143661 0.148713
7 (key2, key1, key0) F O H 2275.917905 45.236682 38.896113 0.035109 0.032875 0.036025
8 (key2, key1, key0) G U L 11708.036260 107.953013 86.074602 1.332295 1.592902 0.637987
9 (key2, key1, key0) H D L 10671.352270 103.017873 81.313511 0.704081 0.857213 0.439851
10 (key2, key1, key0) H R X 12140.010983 109.888034 88.612042 3.650613 4.432163 0.900864
11 (key2, key1, key0) I F C 882.160772 27.834954 26.533773 0.027547 0.027449 0.028020
12 (key2, key1, key0) I U H 342.033309 17.099936 16.129785 0.013927 0.013389 0.014053
13 (key2, key1, key0) I W H 664.345357 21.451403 20.532673 0.047146 0.050368 0.048850
14 (key2, key1, key0) I Y P 10648.205364 102.705291 85.009991 1.322913 1.517187 0.629965
15 (key2, key1, key0) J A S 3753.890398 59.210502 50.237830 0.049434 0.044973 0.051284
16 (key2, key1, key0) J T G 712.125767 25.003669 22.405640 0.026714 0.023557 0.027216
17 (key2, key1, key0) K H Y 443.912007 19.960495 17.500978 0.021829 0.020405 0.022165
18 (key2, key1, key0) L D E 1467.502779 36.169521 34.984144 0.043103 0.042812 0.044233
19 (key2, key1, key0) L S E 260.607376 15.042063 14.491292 0.018544 0.018222 0.018755
20 (key2, key1, key0) L Z S 540.464149 21.388109 20.170990 0.029981 0.029070 0.030580
21 (key2, key1, key0) M D T 11475.321478 106.907205 85.239846 0.719624 0.872762 0.455814
22 (key2, key1, key0) M O X 980.723118 29.443091 27.917302 0.040580 0.037341 0.041620
23 (key2, key1, key0) M W L 11471.704236 106.876550 85.262401 6.551329 6.945873 0.957528
24 (key2, key1, key0) N Q R 12126.909866 109.861268 87.720958 0.428517 0.512430 0.317728
25 (key2, key1, key0) N W R 831.001845 26.453888 25.005874 0.027805 0.028192 0.028326
26 (key2, key1, key0) P F H 3385.137185 55.836240 47.652862 0.094587 0.090453 0.101384
27 (key2, key1, key0) Q X Y 696.874649 24.708330 22.547868 0.024816 0.021517 0.025236
28 (key2, key1, key0) Q Z O 10837.921043 103.902697 83.108939 22.734702 17.173468 1.136795
29 (key2, key1, key0) R I J 1338.534546 34.192363 33.305858 0.074509 0.076687 0.078025
30 (key2, key1, key0) S B T 1228.447613 33.062395 32.205870 0.068078 0.068130 0.070888
31 (key2, key1, key0) S C Z 2721.175850 49.565001 43.222989 0.051009 0.045872 0.052886
32 (key2, key1, key0) T A Y 580.708758 22.265963 21.240699 0.032808 0.031591 0.033493
33 (key2, key1, key0) T B P 1559.715623 36.489070 33.756588 0.053180 0.050840 0.055136
34 (key2, key1, key0) T Z F 257.901818 14.821516 14.286622 0.017740 0.017585 0.017936
35 (key2, key1, key0) U T S 1707.310374 38.692001 37.233078 0.063790 0.064129 0.066334
36 (key2, key1, key0) W G R 443.086490 19.538874 18.915095 0.022123 0.021703 0.022428
37 (key2, key1, key0) X I A 2638.423929 48.920116 46.225857 0.084699 0.086252 0.089271
38 (key2, key1, key0) X Z N 1164.674674 31.577529 30.914779 0.029353 0.029181 0.029889
39 (key2, key1, key0) Y J F 238.461604 13.881420 13.325958 0.014331 0.013884 0.014468
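Several groups above report a mape well above 1 while their smape stays below 2. One plausible explanation is that those series pass close to zero in the validation windows, which inflates any error measured relative to the actual value. The sketch below computes the metrics from their textbook definitions on three made-up points to show the effect; the library's exact computation (for example its rolling-window averaging) may differ, and all names and values here are illustrative.

```python
import math

# Made-up actuals and predictions; the last actual value is close to zero
actual = [100.0, 200.0, 0.5]
predicted = [110.0, 190.0, 10.5]

n = len(actual)
abs_errors = [abs(y - yhat) for y, yhat in zip(actual, predicted)]

mse = sum(e ** 2 for e in abs_errors) / n
rmse = math.sqrt(mse)
mae = sum(abs_errors) / n

# Relative to the actual value: unbounded as the actual value approaches zero
mape = sum(e / abs(y) for e, y in zip(abs_errors, actual)) / n

# Relative to the average magnitude of actual and predicted: never exceeds 2
smape = sum(
    2 * e / (abs(y) + abs(yhat))
    for e, y, yhat in zip(abs_errors, actual, predicted)
) / n

print(mse, rmse, mae)
print(mape, smape)
```

Every point has the same absolute error of 10, yet the single near-zero actual value dominates mape while barely moving mae or rmse.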

Save the model

[10]:
save_path = "/tmp/model/grouped_prophet"
grouped_prophet_model.save(save_path)

Load the model

[11]:
loaded_model = GroupedProphet.load(save_path)

Generate forecasts for each group

Forecasts are not limited to the frequency of each group's training series. The training data here is daily, but weekly predictions can be generated by specifying a frequency of "W", as shown below.

[12]:
forecast = loaded_model.forecast(horizon=16, frequency="W")
[13]:
forecast
[13]:
grouping_key_columns key2 key1 key0 ds trend additive_terms weekly yearly multiplicative_terms yhat
0 (key2, key1, key0) A I S 2022-01-02 00:02:00 946.265699 -54.783190 -0.041724 -54.741466 0.0 891.482510
1 (key2, key1, key0) A I S 2022-01-09 00:02:00 949.014237 -45.960688 -0.041724 -45.918964 0.0 903.053549
2 (key2, key1, key0) A I S 2022-01-16 00:02:00 951.762775 -36.963503 -0.041724 -36.921779 0.0 914.799272
3 (key2, key1, key0) A I S 2022-01-23 00:02:00 954.511313 -26.911199 -0.041724 -26.869476 0.0 927.600114
4 (key2, key1, key0) A I S 2022-01-30 00:02:00 957.259851 -15.607711 -0.041724 -15.565987 0.0 941.652140
... ... ... ... ... ... ... ... ... ... ... ...
635 (key2, key1, key0) Y J F 2022-03-20 00:02:00 1063.621168 60.727302 0.220967 60.506335 0.0 1124.348470
636 (key2, key1, key0) Y J F 2022-03-27 00:02:00 1066.317513 71.955841 0.220967 71.734874 0.0 1138.273354
637 (key2, key1, key0) Y J F 2022-04-03 00:02:00 1069.013858 82.569743 0.220967 82.348776 0.0 1151.583600
638 (key2, key1, key0) Y J F 2022-04-10 00:02:00 1071.710203 91.797486 0.220967 91.576519 0.0 1163.507689
639 (key2, key1, key0) Y J F 2022-04-17 00:02:00 1074.406547 99.418800 0.220967 99.197833 0.0 1173.825347

640 rows × 11 columns
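The forecast dates above run from 2022-01-02 to 2022-04-17 in seven-day steps, and 40 groups with 16 dates each account for the 640 rows. The sketch below reproduces that date sequence with pandas alone, assuming the "W" alias is interpreted as pandas' weekly (Sunday-anchored) frequency and that generation starts from the last training timestamp. It illustrates the frequency alias only and does not describe how the library builds its future dates.

```python
import pandas as pd

# Last timestamp in the training data shown earlier in this notebook
last_training_date = pd.Timestamp("2021-12-31 00:02:00")

# "W" anchors on Sundays, so the first date rolls forward from Friday to Sunday
future_dates = pd.date_range(start=last_training_date, periods=16, freq="W")

print(future_dates[0])
print(future_dates[-1])
print(len(future_dates) * 40)  # rows expected for 40 groups
```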

[14]:
os.remove(save_path)