mmdetection-3.x学习笔记——测试时增强相关配置(rtmdet_tta.py)

2023-03-28 Views mmdetection-3.x561字3 min read

测试时增强(Test time augmentation,后文简称 TTA)是一种测试阶段的数据增强策略,旨在测试过程中,对同一张图片做翻转、缩放等各种数据增强,并在增强后的图像上进行测试,最后将增强后每张图片预测的结果还原到原始尺寸并做融合,以获得更加准确的预测结果。

使用测试时增强需要使用TTA对模型进行再次封装,为了让用户更加方便地使用 TTA,MMEngine 提供了 BaseTTAModel 类,用户只需按照任务需求,继承 BaseTTAModel 类,实现不同的 TTA 策略即可。TTA 的核心实现通常分为两个部分:

  1. 测试时的数据增强:测试时数据增强主要在 MMCV 中实现,可以参考 TestTimeAug 的 API 文档,本文档不再赘述。
  2. 模型推理以及结果融合:BaseTTAModel 的主要功能就是实现这一部分,BaseTTAModel.test_step 会解析测试时增强后的数据并进行推理。用户继承 BaseTTAModel 后只需实现相应的融合策略即可。
tta_model = dict(
    type='DetTTAModel', # TTA模型的类型
    tta_cfg=dict( # TTA的后处理配置,主要就是非极大值抑制
        nms=dict(
            type='nms',  # NMS 的类型
            iou_threshold=0.6),  # NMS 的阈值
            max_per_img=100)) # NMS 后要保留的 box 数量。

img_scales = [(640, 640), (320, 320), (960, 960)] # 用于定义下面tta_pipeline的图片缩放流程
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [   # 对测试图片进行缩放,由于上面指定了3个图片尺寸,所以缩放后得到3个不同大小的图片
                dict(type='Resize', scale=s, keep_ratio=True)
                for s in img_scales
            ],
            [   # 对图片进行随机反转,这里做了两次,一次概率为1,一次概率为0,即每张图片缩放(Resize)后都会进行翻转增强,变成两张图片。结合上面的3次缩放,所以是1张图片变6张图片
                # ``RandomFlip`` must be placed before ``Pad``, otherwise
                # bounding box coordinates after flipping cannot be
                # recovered correctly.
                dict(type='RandomFlip', prob=1.),
                dict(type='RandomFlip', prob=0.)
            ],
            [   # 把所有增强后的图片pad成指定大小
                dict(
                    type='Pad',
                    size=(960, 960),
                    pad_val=dict(img=(114, 114, 114))),
            ],
            [   # 最后捡出需要的数据
                dict(
                    type='PackDetInputs',
                    meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                               'scale_factor', 'flip', 'flip_direction'))
            ]
        ])
]

EOF