RandomCrop とかを入力画像と正解画像で統一する方法
経緯
PyTorch で画像系の学習をするときに, RandomFlip
や RandomCrop
をよく使う.最近研究関連で,オブジェクト位置,セグメンテーション,画像を順に生成していく既存モデルを構築しており,それらのデータでランダムな部分を統一する必要があった.なかなかいい感じの記事に出会えずどうやるか調べるのに苦労したので結論をまとめておく.
.
想定する読者
- PyTorch で画像物体検知を実装しようとしている人
- クラスの継承がわかる人
.
.
RandomCrop とかを入力画像と正解画像で統一する方法
RandomCrop
とかを継承して,__call__
だけ上書きする.
from torchvision import transforms from torchvision.transforms import functional as F class UnitedRandomCrop(transforms.RandomCrop): def __call__(self, *args, **kwargs): img, segms, bboxes = args i, j, th, tw = self.get_params(img, self.size) cropped_img = F.crop(img, i, j, th, tw) cropped_segms = [F.crop(segm, i, j, th, tw) for segm in segms] cropped_bboxes = [self.crop_bbox(bbox, i, j, th, tw) for bbox in bboxes] return cropped_img, cropped_segms, cropped_bboxes @staticmethod def crop_bbox(bbox, i, j, th, tw): x, y, bw, bh = bbox start_x = max(0, x - j) start_y = max(0, y - i) stop_x = min(tw, x + bw - j) stop_y = min(th, y + bh - i) cropped_bbox = [start_x, start_y, stop_x - start_x, stop_y - start_y] return cropped_bbox
.
詳細
get_params
でランダムなパラメータを取得している.このパラメータを使って,__call__
内でデータに必要な変換を記述して返す.基本的には本家の記述に従えば OK.
get_params
はこの issue で導入が検討された関数.オブジェクト検知タスクにおいて,ランダム部分を入力画像と正解画像(セグメンテーション)で統一したいという欲求が発生し,様々な議論がなされていた所,fmassa さんがいい感じにまとめて問題提起を行ったのがこの issue .この議論の結論は RandomCrop
とかのクラスで,パラメータ作成と実際の画像の変換を分け,また,画像変換部分は関数として提供することである.これによってユーザが自由にサブクラスや関数を作れるようになるとのこと.
全ての RandomHoge クラスに get_params
が実装されているわけではないので,注意が必要.執筆時点(2019/04/08)で get_params
メソッドを持つのは RandomCrop
RandomResizedCrop
ColorJitter
RandomRotation
RandomAffine
.例えば RandomFlip
には get_param
メソッドがない.確かに単純すぎて分けるほどではない気がする.
.
あとがき
英語の issue ツリーを読むの疲れた.
もしもっといい方法があるならぜひ教えてください!