-
Notifications
You must be signed in to change notification settings - Fork 1
/
datautils.py
67 lines (56 loc) · 2.31 KB
/
datautils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from paddlets.datasets.repository import get_dataset, dataset_list
def _get_time_features(dt):
    """Build a matrix of calendar features from a datetime index.

    Args:
        dt: A ``pandas.DatetimeIndex`` (or compatible object exposing
            ``minute``, ``hour``, ``dayofweek``, ``day``, ``dayofyear``,
            ``month`` and ``isocalendar()``).

    Returns:
        A ``numpy.ndarray`` of dtype ``float64`` with shape ``(len(dt), 7)``.
        Columns, in order: minute, hour, day of week, day of month,
        day of year, month, ISO week of year.
    """
    # ``DatetimeIndex.weekofyear`` was removed in pandas 2.0; the ISO week
    # from ``isocalendar()`` is its documented replacement. That column uses
    # the nullable UInt32 dtype, so convert explicitly to a plain float array
    # (missing timestamps become NaN, like the other accessors).
    iso_week = dt.isocalendar().week.to_numpy(dtype=np.float64, na_value=np.nan)

    features = [
        dt.minute.to_numpy(),
        dt.hour.to_numpy(),
        dt.dayofweek.to_numpy(),
        dt.day.to_numpy(),
        dt.dayofyear.to_numpy(),
        dt.month.to_numpy(),
        iso_week,
    ]
    # ``np.float`` was only an alias of the builtin ``float`` and was removed
    # in NumPy 1.24; ``np.float64`` is the equivalent explicit dtype.
    return np.stack(features, axis=1).astype(np.float64)
def load_forecast_csv(name, univar=False):
    """Load a forecasting dataset and prepare it for training/evaluation.

    The dataset is fetched through ``get_dataset`` and converted to a
    DataFrame; calendar features derived from its index are standardized and
    prepended to the (standardized) series values.

    Args:
        name: Dataset identifier passed to ``get_dataset``. The names
            ``'ETTh1'``, ``'ETTh2'``, ``'ETTm1'``, ``'ETTm2'`` and
            ``'electricity'`` receive dataset-specific handling.
        univar: If true, keep a single target column: ``'OT'`` for the ETT
            datasets, ``'MT_001'`` for ``'electricity'``, otherwise the last
            column.

    Returns:
        A tuple ``(data, train_slice, valid_slice, test_slice, scaler,
        pred_lens, n_covariate_cols)`` where

        * ``data`` is a 3-D array. For ``'electricity'`` every variable
          becomes its own instance (first axis); otherwise there is a single
          instance and the variables are features (last axis). The time
          covariates occupy the first ``n_covariate_cols`` entries of the
          last axis.
        * ``train_slice`` / ``valid_slice`` / ``test_slice`` are ``slice``
          objects over the time axis.
        * ``scaler`` is the ``StandardScaler`` fitted on the training rows of
          the series values (not the covariates).
        * ``pred_lens`` is the list of forecast horizons for this dataset.
        * ``n_covariate_cols`` is the number of time-feature columns.
    """
    # The data comes from the dataset repository, not from a local
    # ``datasets/{name}.csv`` file.
    data = get_dataset(name).to_dataframe()

    dt_embed = _get_time_features(data.index)
    n_covariate_cols = dt_embed.shape[-1]

    if univar:
        if name in ('ETTh1', 'ETTh2', 'ETTm1', 'ETTm2'):
            data = data[['OT']]
        elif name == 'electricity':
            data = data[['MT_001']]
        else:
            data = data.iloc[:, -1:]

    data = data.to_numpy()

    # Fixed row-count splits for the ETT datasets (presumably 12/4/4 "months"
    # of 30 days, hourly for ETTh* and four samples per hour for ETTm*);
    # every other dataset is split 60/20/20 by length.
    if name in ('ETTh1', 'ETTh2'):
        train_slice = slice(None, 12 * 30 * 24)
        valid_slice = slice(12 * 30 * 24, 16 * 30 * 24)
        test_slice = slice(16 * 30 * 24, 20 * 30 * 24)
    elif name in ('ETTm1', 'ETTm2'):
        train_slice = slice(None, 12 * 30 * 24 * 4)
        valid_slice = slice(12 * 30 * 24 * 4, 16 * 30 * 24 * 4)
        test_slice = slice(16 * 30 * 24 * 4, 20 * 30 * 24 * 4)
    else:
        train_end = int(0.6 * len(data))
        valid_end = int(0.8 * len(data))
        train_slice = slice(None, train_end)
        valid_slice = slice(train_end, valid_end)
        test_slice = slice(valid_end, None)

    # Fit the scaler on the training rows only so that validation/test
    # statistics do not leak into the normalization.
    scaler = StandardScaler().fit(data[train_slice])
    data = scaler.transform(data)

    # NOTE: the original condition was ``name in ('electricity')``, which is a
    # substring test against a plain string (no trailing comma, so no tuple)
    # and matched names such as '' or 'city'. Compare for equality instead.
    if name == 'electricity':
        # Each variable is an instance rather than a feature.
        data = np.expand_dims(data.T, -1)
    else:
        data = np.expand_dims(data, 0)

    if n_covariate_cols > 0:
        dt_scaler = StandardScaler().fit(dt_embed[train_slice])
        dt_embed = np.expand_dims(dt_scaler.transform(dt_embed), 0)
        # Repeat the covariates for every instance and put them in front of
        # the series values along the feature axis.
        covariates = np.repeat(dt_embed, data.shape[0], axis=0)
        data = np.concatenate([covariates, data], axis=-1)

    if name in ('ETTh1', 'ETTh2', 'electricity'):
        pred_lens = [24, 48, 168, 336, 720]
    else:
        pred_lens = [24, 48, 96, 288, 672]

    return data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols