PyTorch トレーニングを torch-ort で高速化する
PyTorch トレーニング スクリプトにシンプルな変更を加えるだけで、自由に選択したターゲット ハードウェア上で torch_ort.ORTModule を実行し、大規模な言語モデルのトレーニングを高速化できるようになりました。
ディープ ラーニング モデルのトレーニングに必要なコンピューティングおよびメモリ リソースが増え続ける中、本日マイクロソフトは torch_ort.ORTModule をリリースしました。これによって PyTorch モデルの分散トレーニングはさらに加速し、トレーニングに要する時間とリソースを削減できるようになります。また torch-ort は NVIDIA と AMD 両方の GPU に対応し、開発者に柔軟性を提供します。さらに、torch-ort パッケージを DeepSpeed などその他のディープ ラーニング オプティマイザーと組み合わせて使用することで、トレーニング タスクのパフォーマンスはさらに高まります。
https://github.com/pytorch/ort の torch-ort パッケージから提供される ORTModule クラスは、torch.nn.Module のシンプルなラッパーです。ORTModule は GPT や BERT などのトランスフォーマー モデルをサポートしますが、その他のモダリティのサポートにも今後対応していく予定です。現時点では、ターゲット タスクのラベル付きデータセットで多くの人気言語モデルを微調整したり、モデルの自己教師あり学習に特定のコーパスを追加したり、新しいモデルの事前トレーニングをゼロから実験したりすることができます。
パフォーマンス
マイクロソフトでは、社内の複数の大規模ワークロードで既に torch-ort を使用しています。さらに、最も広く使用されている Hugging Face モデルのいくつかに対する微調整をベンチマーク調査したところ、ORTModule 単体で最大 37%、DeepSpeed との組み合わせで最大 86% のスループット向上が確認されました。
これらの実験は、Azure の世界最高水準のインフラストラクチャ Azure ND A100 v4 で実行されました。さらに、マシン内部およびマシン間では、GPU 間の帯域幅が最適化されています。
上記のグラフは、トレーニング サンプルの毎秒スループットを示しています。実際のトレーニング ジョブに要する時間は、トレーニング サンプルの数および使用する CPU/GPU の種類に応じて変化します。ORTModule は、トレーニング処理を開始する前にモデルの最適化を 1 回だけ実行します。これは回避できないコストですが、実行全体を通して相殺されます。
また AskHereFirst では、ORTModule と DeepSpeed の組み合わせを使用して、カスタム自然言語タスクの 27 億パラメーター GPT-Neo モデルをトレーニングすることに成功しています。以前は、これほどの規模のモデルを、現存するハードウェア上でトレーニングするのは不可能でした。コロンビア大学の外部組織である AskHereFirst では、構造化データ ストア向けの強力な AI ベース自然言語クエリ ソリューションを運用しています。このソリューションを活用することで、金融、メディア、マーケティング、スポーツといった幅広い業界の検索処理が大幅に効率化されます。
「ORTModule と DeepSpeed を採用するまでは、私たちのカスタム自然言語タスク用に GPT-NEO をトレーニングすることは不可能でした。現在では、微調整された 27 億パラメーター GPT-NEO モデルを生成し、自然言語入力を構造化クエリにマッピングして、さまざまな用途に利用しています」—Vishal Misra 氏、コロンビア大学教授兼 AskHereFirst 創設者
ハードウェアの移植性
分散トレーニング ワークロードを実行するハードウェア プラットフォームには、さまざまな選択肢があります。torch_ort.ORTModule は NVIDIA と AMD の GPU で動作します。
マイクロソフトでは、CUDA 10.2 または CUDA 11.1 を使用して、NVIDIA 用 torch-ort パッケージをリリースしています。これを使用することで、Azure の NVIDIA GPU またはユーザーのオンプレミス環境の両方で PyTorch トレーニングの実行を高速化できます。
また AMD GPU 向けには、ROCm 4.2 を使用した torch-ort のプレビュー パッケージをリリースしています。
シンプルな開発者エクスペリエンス
ORTModule は簡単に使い始められます。torch-ort パッケージをダウンロード、インストールして、以下に示したコード サンプルに従って ORTModule でモデルをラップするだけです。
ORTModule で torch.nn.Module をラップする部分を除き、PyTorch トレーニング ループは変更されていません。
PyTorch トレーニング ループは変更されないため、torch.autocast や NVIDIA apex など、PyTorch エコシステムのその他のライブラリに ORTModule をシームレスに統合できます。
動作のメカニズム
最初の forward 呼び出しで、前方予測パス用および後方勾配計算パス用の 2 つの最適化済み計算グラフが生成されます。トレーニング ループのその他の部分はすべてネイティブの PyTorch により実行されます。カーネルの最適化、サブグラフ演算の融合、CPU と GPU 間におけるメモリ コピーの削減といった最適化がこれらのグラフで行われることで、高速化が実現されます。
詳細情報
ORTModule の技術に関する詳細情報についてはこちらで、マイクロソフトと AMD のパートナーシップについてはこちらでご確認いただけます。その他のドキュメント、サンプル、チームへのご連絡については、こちらをご覧ください。