PyTorch で GPU 並列をちょっと詳しく
要約
PyTorch でマルチ GPU してみたけど,色々ハマったので記録に残しておく.データ並列もモデル並列(?)もやった.
メインターゲット
- PyTorch ユーザ
- GPU 並列したい人
前提知識
- 深層学習の基礎
- 分散深層学習の基礎 分散深層学習に関してはこちら↓ nomoto-eriko.hatenablog.com
.
並列化したコード
どのようなコードを並列化したのかの説明を軽くしておく.必要なければ読み飛ばしてもらって構わない.
この論文 の自前実装.テキストを入力に画像を生成するモデル.テキスト→オブジェクト位置→オブジェクトの形→画像 の順に予測・生成していく. このモデルを訓練するスクリプトを並列化した.
モデル
主なコンポーネント
Component | input | output |
---|---|---|
box_generator | テキスト | オブジェクトの位置とラベルのペア(以後オブジェクト)の列 |
shape_generator | テキスト,オブジェクト列 | オブジェクトの形の列 |
image_generator | テキスト,オブジェクトの形の列を一つの画像にまとめたやつ(セマンティックレイアウト) | 画像 |
補助的なコンポーネント
Component | input | output |
---|---|---|
text_encoder (pretrained) | テキスト(単語列) | テキストベクトル |
shape_discriminator | テキスト,オブジェクトの形の列 | そのオブジェクトが正例か shape_generator が生成したものかの 2 値 |
image_discriminator | テキスト,画像 | その画像が正例か image_generator が生成したものかの 2 値 |
vgg_model (pretrained) | オブジェクトの形 or 画像 | 入力画像の内部表現(本来は画像分類モデルだが,今回は内部表現を利用する) |
モデル図
特筆事項
- box_generator は RNN な構造を持つ.shape_generator は RNN な構造を持つし,CNN も持つ.
- vgg_model は vgg19 を利用.結構でかい.CNN.
.
並列化1:DataParallel
1 バッチを各 GPU に割り振る.例えば 4 GPU で DataParallel でバッチサイズ 256 のを回すと,一つの GPU がバッチサイズ 64 のバッチを処理する.
実装方法
PyTorch には DataParallel モジュールが用意されている.使い方はめっちゃ簡単. Document
device_ids = range(torch.cuda.device_vount())
model = DataParallel(model, device_ids=device_ids)
DataParallel が何をやってるかをちょっと詳しく
同期更新型のデータ並列.以降,勾配計算のみを行う複数個の GPU を子 GPU,勾配の集約を行う唯一の GPU を親 GPU と言う. ※親 GPU 自身も勾配計算を行う.
1. Forward
- ミニバッチを各 GPU に均等に割り当てる.
- 親 GPU に乗っているモデルのパラメータ情報等を各子 GPU にコピーする.
- 各 GPU にて Forward プロセスを行う.
- Forward 結果を親 GPU に集約する.
2. 親 GPU にてロスを計算
3. Backward
4. 親 GPU のモデルパラメータを更新
DataParallel に向かない設定
パラメータの多いモデル
イテレーションごとにモデルのコピーが発生するので,パラメータの多いモデルはコピーのオーバーヘッドが大きく,DataParallel の恩恵を受けられない.全結合層だけで構成されるモデルとかがその例.
基本的に畳み込み層で構成されていて,最終層だけ全結合層とかの場合は,全結合層を DataParallel の対象から外すというハックを行うことで大幅な高速化が期待できるそう.参考コード
自分でも全結合層を DataParallel の対象から外す処理を試してみた.イテレーションが 42秒から 40秒になるという,微妙な差.
小さいバッチサイズ
バッチサイズが小さくても,パラメータのコピーに対して Forward/Backward の計算が少なくなるので,DataParallel に向かない.十分なバッチサイズが確保できるよう,大きすぎないモデル設計を心がけるか,とてもメモリの大きい GPU を用意するかしよう.
RNN(特に LSTM)
DataParallel するには,batchFirst である必要がある.その他,色々とエラーが出て闇が深そう. さらに,LSTM はセルが複雑なことをするので速くならないらしい.説明をめちゃくちゃ端折ったので,興味がある人はこちらの議論を見てください.
.
並列化2:モデル並列 & multiprocessing
今回並列化したモデルは box_generator/shape_generator and shape_discriminator/image_generator and image_discriminator の 3 つに計算グラフを完全に分離できる.したがって,これらを別々の GPU に乗せたうえで各 GPU でイテレーションを並列実行することが割と簡単にできる. 計算グラフが切れないガチのモデル並列は一つ下のレイヤからいじる必要があって闇が深いので,この世のどんな GPU にも乗り切らないめっさでかいモデルを絶対に動かさなきゃいけない事情がない限り避けたほうがいいと思う.
実装方法
1. 各モデルを別々の GPU に載せる
以下のようにすれば良い.
device0 = torch.device("cuda:0") device1 = torch.device("cuda:1") device2 = torch.device("cuda:2") model.box_generator.to(device0) model.shape_generator.to(device1) model.image_generator.to(device2) box_input_tensor = box_input_tensor.to(device0) shape_input_tensor = shape_input_tensor.to(device1) image_input_tensor = image_input_tensor.to(device2)
ただし,全てのテンソルをちゃんと正しい GPU に割り当てる必要があるので,内部で .cuda()
とかで雑に初期化していると別々の GPU にあるテンソルで演算しようとしてんじゃねぇよエラーが出る.頑張って修正しよう.
2. multiprocessing
モデルを別々の GPU に乗せただけでは並列に実行されないので,並列に実行するために multiprocessing する必要がある.Python の multiprocessing の PyTroch ラッパーがあるので,それを使う.
Document にサンプル実装がある.訓練フェーズを関数化して,multiprocessing.Process に渡せば,その訓練が fork されて独立に動き出す.
CUDA initialization error が発生する場合は,mp.set_start_method('spawn')
を torch.manual_seed(args.seed)
の前で宣言すると回避できる.参考
spawn は子プロセスを開始する方式の一種で,指定しない場合は多分 fork だった気がする.デフォルトでは子プロセスの中でも CUDA の初期化を行ってしまうせいでエラーが発生するらしい.(CUDA の初期化は 1 回しかしちゃダメ)
spawn はプロセスを一度 Pickle 漬けにするので,lambda などのPickle 漬けできない書き方を解消する必要がある.また,Pickle 漬け処理が重たいので,遅い.悲しい.
.
2 種類の並列化を試した結果と感想
どちらも lab のサーバだとそんなに速くならなくて悲しかった.それぞれ,「batch_size が小さすぎる」「spwan が遅すぎる」というネックがあり,それを解消できなかった. ABCI で試したところ,DataParallel で 1.3 倍くらいには速くなったので,それでやっていくことにした.
PyTorch の DataParallel は基本的に CV 系のモデルを想定していて,NLP 系のモデルに向いていないのが悲しかった.使う分には楽なので,使えるところで局所的に使うのが賢そう. multiprocessing はそもそも PyTorch でそこまでサポートされていなくて,エラー回避が大変だったし,効果が薄かった. DataParallel を(上手く)使うことをオススメする.
.
リファレンス
Hong et al., CVPR2018: S. Hong, D. Yang, J. Choi and H. Lee,``Inferring Semantic Layout for Hierarchical Text-to-Image Synthesis,'' In proc. of CVPR2018. pdf