RandomCrop とかを入力画像と正解画像で統一する方法

経緯

PyTorch で画像系の学習をするときに, RandomFlipRandomCrop をよく使う.最近研究関連で,オブジェクト位置,セグメンテーション,画像を順に生成していく既存モデルを構築しており,それらのデータでランダムな部分を統一する必要があった.なかなかいい感じの記事に出会えずどうやるか調べるのに苦労したので結論をまとめておく.

.

想定する読者

  • PyTorch で画像物体検知を実装しようとしている人
  • クラスの継承がわかる人

.

.

RandomCrop とかを入力画像と正解画像で統一する方法

RandomCrop とかを継承して,__call__ だけ上書きする.

from torchvision import transforms
from torchvision.transforms import functional as F


class UnitedRandomCrop(transforms.RandomCrop):

    def __call__(self, *args, **kwargs):
        img, segms, bboxes = args
        i, j, th, tw = self.get_params(img, self.size)
        cropped_img = F.crop(img, i, j, th, tw)
        cropped_segms = [F.crop(segm, i, j, th, tw) for segm in segms]
        cropped_bboxes = [self.crop_bbox(bbox, i, j, th, tw) for bbox in bboxes]
        return cropped_img, cropped_segms, cropped_bboxes

    @staticmethod
    def crop_bbox(bbox, i, j, th, tw):
        x, y, bw, bh = bbox
        start_x = max(0, x - j)
        start_y = max(0, y - i)
        stop_x = min(tw, x + bw - j)
        stop_y = min(th, y + bh - i)
        cropped_bbox = [start_x, start_y, stop_x - start_x, stop_y - start_y]
        return cropped_bbox

.

詳細

get_params でランダムなパラメータを取得している.このパラメータを使って,__call__ 内でデータに必要な変換を記述して返す.基本的には本家の記述に従えば OK.

get_paramsこの issue で導入が検討された関数.オブジェクト検知タスクにおいて,ランダム部分を入力画像と正解画像(セグメンテーション)で統一したいという欲求が発生し,様々な議論がなされていた所,fmassa さんがいい感じにまとめて問題提起を行ったのがこの issue .この議論の結論は RandomCrop とかのクラスで,パラメータ作成と実際の画像の変換を分け,また,画像変換部分は関数として提供することである.これによってユーザが自由にサブクラスや関数を作れるようになるとのこと.

全ての RandomHoge クラスに get_params が実装されているわけではないので,注意が必要.執筆時点(2019/04/08)で get_params メソッドを持つのは RandomCrop RandomResizedCrop ColorJitter RandomRotation RandomAffine.例えば RandomFlip には get_param メソッドがない.確かに単純すぎて分けるほどではない気がする.

.

あとがき

英語の issue ツリーを読むの疲れた.

もしもっといい方法があるならぜひ教えてください!