
Scikit-learn TimeSeriesSplit

Juheon Kwak 2024. 8. 5. 17:48

- Source: https://otexts.com/fpp3/tscv.html


* Time Series Cross Validation (TSCV)

 

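When cross-validating a time series forecasting model, the observations cannot be shuffled into random folds. Instead, each test set is paired with a training set made up only of observations that occurred before it, so the model is never fitted on future data. From one fold to the next, the forecast origin rolls forward in time and the training set grows (fpp3 calls this evaluation on a "rolling forecasting origin").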
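The rolling-origin idea can be sketched in a few lines of plain Python. `rolling_origin_splits` is a hypothetical helper written for illustration, not part of any library. Following the fpp3 description, each test set holds a single observation by default; the `horizon` parameter is an added assumption to show how multi-step test sets would look.

```python
def rolling_origin_splits(n_samples, min_train_size, horizon=1):
    """Yield (train_indices, test_indices) on a rolling forecasting origin.

    Hypothetical helper: the training set holds every observation before
    the origin, and the test set holds the next `horizon` observations.
    """
    for origin in range(min_train_size, n_samples - horizon + 1):
        train = list(range(0, origin))              # only the past
        test = list(range(origin, origin + horizon))  # the next step(s)
        yield train, test


splits = list(rolling_origin_splits(n_samples=6, min_train_size=3))
for train, test in splits:
    print(train, test)
# [0, 1, 2] [3]
# [0, 1, 2, 3] [4]
# [0, 1, 2, 3, 4] [5]
```

Every training list ends right before its test index, so no fold ever trains on observations that come after the point being forecast.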

 

- Python code

(Source: https://www.geeksforgeeks.org/time-series-cross-validation/)

 

--> Use scikit-learn's TimeSeriesSplit class (a cross-validation splitter in sklearn.model_selection).

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html
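One quick way to see what TimeSeriesSplit does is to print the indices it produces on a small toy array (the 12-sample array below is made up for illustration). With the default settings, each test fold has n_samples // (n_splits + 1) samples, and the training set expands from fold to fold:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(12).reshape(-1, 1)  # 12 time-ordered samples (toy data)
tscv = TimeSeriesSplit(n_splits=5)

# Collect the (train, test) index pairs as plain lists for printing
folds = [(train.tolist(), test.tolist()) for train, test in tscv.split(X)]
for train, test in folds:
    print("train:", train, "test:", test)
# train: [0, 1] test: [2, 3]
# train: [0, 1, 2, 3] test: [4, 5]
# train: [0, 1, 2, 3, 4, 5] test: [6, 7]
# train: [0, 1, 2, 3, 4, 5, 6, 7] test: [8, 9]
# train: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] test: [10, 11]
```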

 

[Usage]

from sklearn.model_selection import TimeSeriesSplit

# Default settings
tscv = TimeSeriesSplit()

# The same splitter with every default value written out
tscv = TimeSeriesSplit(n_splits=5, max_train_size=None, test_size=None, gap=0)
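The effect of the four parameters can be checked the same way on a made-up 10-sample array. In this sketch, test_size fixes the length of each test fold, gap drops that many samples from the end of each training set, and max_train_size caps the training window so that it slides forward instead of expanding:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(10).reshape(-1, 1)  # 10 time-ordered samples (toy data)

# 3 folds, 2 test samples per fold, 1 sample skipped before each test fold,
# and at most 3 training samples per fold (sliding window)
tscv = TimeSeriesSplit(n_splits=3, max_train_size=3, test_size=2, gap=1)

folds = [(train.tolist(), test.tolist()) for train, test in tscv.split(X)]
for train, test in folds:
    print("train:", train, "test:", test)
# train: [0, 1, 2] test: [4, 5]
# train: [2, 3, 4] test: [6, 7]
# train: [4, 5, 6] test: [8, 9]
```

Sample 3, 5 and 7 never appear in the fold that directly follows them, which is what gap is for: it keeps the last training observations from sitting right next to the test period.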

 

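import numpy as np

# Assumes df is a pandas DataFrame sorted by time whose last column
# (position num_features - 1) is the target, and seq_len is the window length.
for train_index, test_index in tscv.split(df):

    # Split data into train and test (rows stay in time order)
    train_data = df.iloc[train_index].values
    test_data = df.iloc[test_index].values

    # Arrange train data into sliding windows:
    # X_train[k] = the seq_len rows before row i, y_train[k] = target at row i
    X_train, y_train = [], []
    for i in range(seq_len, len(train_data)):
        X_train.append(train_data[i-seq_len:i])
        y_train.append(train_data[i, num_features-1])
    X_train, y_train = np.array(X_train), np.array(y_train)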
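The loop above relies on df, seq_len and num_features being defined elsewhere, and it only windows the training fold. Below is a self-contained sketch on random toy data; the column names, sizes and the make_windows helper are hypothetical. As an assumption beyond the original snippet, it also windows the test fold by prepending the seq_len rows that come just before it, so every test target gets a full input window while the inputs still come only from the past:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import TimeSeriesSplit

# Hypothetical toy data: two feature columns, target in the last column
n_rows, num_features, seq_len = 60, 3, 5
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(n_rows, num_features)),
                  columns=["f1", "f2", "target"])


def make_windows(data, seq_len, target_col):
    """Hypothetical helper: windows of seq_len rows -> target of the next row."""
    X, y = [], []
    for i in range(seq_len, len(data)):
        X.append(data[i - seq_len:i])
        y.append(data[i, target_col])
    return np.array(X), np.array(y)


values = df.values
tscv = TimeSeriesSplit(n_splits=5)
fold_shapes = []
for train_index, test_index in tscv.split(df):
    train_data = values[train_index]
    # Test rows plus the seq_len rows immediately before them (all past data)
    test_data = values[test_index[0] - seq_len : test_index[-1] + 1]

    X_train, y_train = make_windows(train_data, seq_len, num_features - 1)
    X_test, y_test = make_windows(test_data, seq_len, num_features - 1)
    fold_shapes.append((X_train.shape, X_test.shape))

for train_shape, test_shape in fold_shapes:
    print(train_shape, test_shape)
```

Each fold yields exactly one test window per index in test_index, so y_test lines up with the target values of the rows selected by test_index.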


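Usage is similar to k-Fold or Stratified k-Fold (KFold, StratifiedKFold): create the splitter, then loop over the (train_index, test_index) pairs returned by split(). The difference is that TimeSeriesSplit never shuffles, and in every fold all training indices come before the test indices.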
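Because the splitters share the same interface, a TimeSeriesSplit object can also be passed as the cv argument of cross_val_score, exactly where a KFold object would go. The sketch below uses a made-up noisy linear trend and LinearRegression purely for illustration; the index check at the end shows what actually changes between the two splitters:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, TimeSeriesSplit, cross_val_score

# Hypothetical toy data: y is a noisy linear trend over time t
rng = np.random.default_rng(42)
t = np.arange(100, dtype=float).reshape(-1, 1)
y = 0.5 * t.ravel() + rng.normal(scale=2.0, size=100)

model = LinearRegression()
# Same call pattern: only the cv splitter object changes
kfold_scores = cross_val_score(model, t, y, cv=KFold(n_splits=5),
                               scoring="neg_mean_absolute_error")
tscv_scores = cross_val_score(model, t, y, cv=TimeSeriesSplit(n_splits=5),
                              scoring="neg_mean_absolute_error")

# Does every training index precede every test index in each fold?
tscv_ok = all(tr.max() < te.min() for tr, te in TimeSeriesSplit(n_splits=5).split(t))
kfold_ok = all(tr.max() < te.min() for tr, te in KFold(n_splits=5).split(t))
print(tscv_ok, kfold_ok)  # True False
```

KFold (even without shuffling) trains its first fold on samples 20 to 99 and tests on 0 to 19, i.e. it forecasts the past from the future, which is exactly what time series cross-validation is meant to avoid.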