TFT:解釈可能な時系列予測用ディープラーニング(1/2)

学習手法

1.TFT:解釈可能な時系列予測用ディープラーニング(1/2)まとめ

・複数の時間間隔で関心のある変数の将来を予測することは時系列機械学習における重要な課題
・従来の時系列モデルはモデルがどのようにして予測に至ったかを説明することは困難
・TFTは優れた精度と解釈可能性の両方を実現するマルチホライズン予測用の新モデル

2.マルチホライズン予測とは?

以下、ai.googleblog.comより「Interpretable Deep Learning for Time Series Forecasting」の意訳です。元記事は2021年12月13日、O. ArikさんとTomas Pfisterさんによる投稿です。

アイキャッチ画像のクレジットはPhoto by Jon Tyson on Unsplash

マルチホライズン予測(Multi-horizon forecasting)、すなわち複数の時間間隔で関心のある変数の将来を予測することは、時系列機械学習における重要な課題です。

現実世界のほとんどのデータセットには時間の要素があり、未来を予測することで大きな価値を引き出すことができます。例えば、小売業者は将来の売上を利用してサプライチェーンやプロモーションを最適化し、投資マネージャーは金融資産の将来の価格を予測してパフォーマンスを最大化し、医療機関は将来の患者入院数を利用して十分な人員と設備を確保することができます。

ディープニューラルネットワーク(DNN:Deep neural networks)は、従来の時系列モデルに比べて強力な性能向上を示し、マルチホライズン予測に用いられることが多くなっています。多くのモデル(DeepAR、MQRNNなど)はリカレントニューラルネットワーク(RNN:Recurrent Neural Networks)の亜種に焦点を当ててきました。

しかし、Transformerをベースとするモデルを含む最近の改良では、RNNの帰納的バイアス(情報のを順次順番に処理する事)を超え、過去の時間ステップから関連するステップを選択する能力を強化するために、attentionに基づく層を使用しています。

しかし、これらの手法はしばしば、マルチホライズン予測実施時によく見られる様々な異なった入力を考慮していないことが多いのです。そして、 外部の入力データ源が全て将来にわたって既知であると仮定するか、重要な静的説明変数(static covariates)を無視します。


静的説明変数(店の場所、商品の情報)と様々な時間経過に依存する入力(休日、開店、閉店)におけるマルチホライズン予測

また、従来の時系列モデルは、多くのパラメータ間の複雑な非線形相互作用によって予測を行っており、そのようなモデルがどのようにして予測に至ったかを説明することは困難でした。

残念ながら、DNNの挙動を説明するための一般的な手法には限界があります。例えば、post-hoc手法(例えば、LIMEやSHAP)は入力特徴の順序を考慮しません。

Attentionベースのモデルの中には、主に言語や音声などの連続データに対して固有の解釈可能性を持つものが提案されていますが、マルチホライズン予測には言語や音声だけでなく様々な種類の入力が存在します。

Attentionに基づくモデルは関連する時間ステップについての洞察を与えることができますが、与えられた時間ステップにおける異なる特徴の重要性を区別することはできません。

マルチホライズン予測におけるデータの異類混交性に取り組んで高いパフォーマンスを発揮し、これらの予測を解釈可能なものにするための新しい手法が必要です。

そのために、私たちは、International Journal of Forecastingに掲載された「Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting」を発表し、マルチホライズン予測のためのAttentionベースのDNNモデル、時間的融合トランスフォーマー(TFT:Temporal Fusion Transformer)を提案します。

TFTは、優れた精度と解釈可能性の両方を実現するために、モデルを一般的なマルチホライズン予測タスクに明示的に合わせるように設計されており、様々な事例でそれを実証しています。

Temporal Fusion Transformer

高い予測性能を得るために、静的な入力、既知の入力、観測した入力といった入力タイプごとに特徴表現を効率的に構築するようTFTを設計しています。
TFTの主な構成要素は以下の通りです。

(1)ゲーティング機構(Gating mechanisms)

ゲーティング機構はデータから学習を行い、モデルの未使用部をスキップします。これにより、様々なデータセットに対応する深さとネットワークの複雑さを提供します。

(2)変数選択ネットワーク(Variable selection networks)

変数選択ネットワークは各タイムステップで関連する入力変数を選択します。従来のDNNは無関係な特徴に過剰適応してしまう可能性がありますが、Attentionベースの変数選択は、最も顕著な特徴に学習能力の大部分を固定するようモデルに促すことで、汎化能力を向上させることができます。

(3)静的説明変数エンコーダ

静的特徴を統合して、時間的変遷のモデル化を制御します。静的特徴は予測に重要な影響を与える可能性があります。例えば、店舗の場所によって売上げの時間的変遷が異なることがあります。
(例えば、田舎の店舗では週末の交通量が多くなるが、都心の店舗では就業時間後に毎日のピークが来ることがあるなど)。

(4)時間処理

「観測された時間経過と共に変化する入力」と「既知の時間経過と共に変化する入力」の両方から、長期観点と短期観点の双方から時間的関係を学習する時間処理。

局所的な処理には帰納的バイアスが有効なため、順序情報処理用にはsequence-to-sequenceレイヤーを採用しています。長期的な依存関係は、新しい解釈可能なマルチヘッド注目ブロックを用いて捉えます。

これにより、情報の有効経路パスを短縮することができます。つまり、関連する情報(例えば、昨年の売上)を持つ過去の時間ステップを、直接的に注目することができるのようになります。

(5)予測区間

予測区間は、マルチホライズン予測における目標値の範囲を決定するために分位値予測を示し、ユーザーが点予測だけでなく、出力の分布を理解するのに役立ちます。


TFTは、静的なメタデータ、時間的に変化する過去データ、時間的に変化する先験的な既知の未来データを入力とします。入力データに基づき、最も顕著な特徴を選択するために、変数選択機構が使用されます。ゲーテッド情報は残差入力として追加され、その後正則化されます。ゲーテッド残差ネットワーク(GRN:Gated Residual Network)ブロックは、スキップ接続とゲーティングレイヤーにより、効率的な情報の流れを可能にします。時間依存処理は、LSTMによる局所処理とマルチヘッドアテンションによる任意の時間ステップの情報統合を基本します。

3.TFT:解釈可能な時系列予測用ディープラーニング(1/2)関連リンク

1)ai.googleblog.com
Interpretable Deep Learning for Time Series Forecasting

2)www.sciencedirect.com
Temporal Fusion Transformers for interpretable multi-horizon time series forecasting

タイトルとURLをコピーしました