Simple Baseline for Image Restoration

2023-04-22 · Views · 512字 · 4 min read

配置文件

# Default registry scope and the directory used for saved outputs.
default_scope = "mmedit"
save_dir = "./work_dirs/"

# Hook table: each entry maps a role name to the hook class (``type``) and its
# keyword arguments.
default_hooks = {
    "timer": {"type": "IterTimerHook"},
    "logger": {"type": "LoggerHook", "interval": 100},
    "param_scheduler": {"type": "ParamSchedulerHook"},
    # Checkpointing is configured per iteration (by_epoch=False) with an
    # interval of 5000; the "best" checkpoint is tracked on PSNR with the
    # 'greater' rule, and out_dir repeats the save_dir path above.
    "checkpoint": {
        "type": "CheckpointHook",
        "interval": 5000,
        "out_dir": "./work_dirs/",
        "by_epoch": False,
        "max_keep_ckpts": 10,
        "save_best": "PSNR",
        "rule": "greater",
        "save_optimizer": True,
    },
    "sampler_seed": {"type": "DistSamplerSeedHook"},
}
# Runtime environment: cuDNN benchmark flag, multiprocessing options and the
# distributed backend name.
env_cfg = {
    "cudnn_benchmark": False,
    "mp_cfg": {"mp_start_method": "fork", "opencv_num_threads": 4},
    "dist_cfg": {"backend": "nccl"},
}

# Logging: level plus a log processor using a window of 100 and
# by_epoch=False.
log_level = "INFO"
log_processor = {
    "type": "EditLogProcessor",
    "window_size": 100,
    "by_epoch": False,
}

# No checkpoint to load and no resume requested.
load_from = None
resume = False
# A single local visualisation backend. The visualizer below carries its own,
# separately constructed backend list with the same contents.
vis_backends = [{"type": "LocalVisBackend"}]

visualizer = {
    "type": "ConcatImageVisualizer",
    "vis_backends": [{"type": "LocalVisBackend"}],
    "fn_key": "gt_path",
    "img_keys": ["gt_img", "input", "pred_img"],
    "bgr2rgb": False,
}

# Visualisation hook configured with interval=1.
custom_hooks = [{"type": "BasicVisualizationHook", "interval": 1}]
# NOTE(review): the name presumably encodes the generator settings below
# (c64 -> mid_channels, eb11128 -> enc_blk_nums, mb1 -> middle_blk_num,
# db1111 -> dec_blk_nums) and the schedule (lr 1e-3, 400k iterations) -
# confirm against the naming convention in use.
experiment_name = "nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro"
work_dir = f"./work_dirs/{experiment_name}"

model = {
    "type": "BaseEditModel",
    # Generator network and its block counts.
    "generator": {
        "type": "NAFNetLocal",
        "img_channels": 3,
        "mid_channels": 64,
        "enc_blk_nums": [1, 1, 1, 28],
        "middle_blk_num": 1,
        "dec_blk_nums": [1, 1, 1, 1],
    },
    "pixel_loss": {"type": "PSNRLoss"},
    # Model-level train/test options are left empty.
    "train_cfg": {},
    "test_cfg": {},
    # Per-channel mean 0 and std 255 are handed to the data preprocessor.
    "data_preprocessor": {
        "type": "EditDataPreprocessor",
        "mean": [0.0, 0.0, 0.0],
        "std": [255.0, 255.0, 255.0],
    },
}
# Training pipeline: load the 'img'/'gt' pair, set scale=1, apply random
# horizontal flip, vertical flip and H/W transpose (each with ratio 0.5),
# take a paired crop with gt_patch_size=256, then pack the inputs.
# NOTE(review): these loaders do not pass channel_order, whereas the
# validation loaders below request 'rgb' - confirm the difference is intended.
train_pipeline = [
    {"type": "LoadImageFromFile", "key": "img"},
    {"type": "LoadImageFromFile", "key": "gt"},
    {"type": "SetValues", "dictionary": {"scale": 1}},
    {
        "type": "Flip",
        "keys": ["img", "gt"],
        "flip_ratio": 0.5,
        "direction": "horizontal",
    },
    {
        "type": "Flip",
        "keys": ["img", "gt"],
        "flip_ratio": 0.5,
        "direction": "vertical",
    },
    {
        "type": "RandomTransposeHW",
        "keys": ["img", "gt"],
        "transpose_ratio": 0.5,
    },
    {"type": "PairedRandomCrop", "gt_patch_size": 256},
    {"type": "PackEditInputs"},
]

# Validation pipeline: load both images with channel_order='rgb' and pack
# them; no augmentation or cropping steps are listed.
val_pipeline = [
    {"type": "LoadImageFromFile", "key": "img", "channel_order": "rgb"},
    {"type": "LoadImageFromFile", "key": "gt", "channel_order": "rgb"},
    {"type": "PackEditInputs"},
]
dataset_type = "BasicImageDataset"

# Training loader: 8 workers, batch size 8, shuffled infinite sampler.
train_dataloader = {
    "num_workers": 8,
    "batch_size": 8,
    "persistent_workers": False,
    "sampler": {"type": "InfiniteSampler", "shuffle": True},
    "dataset": {
        "type": dataset_type,
        "metainfo": {"dataset_type": "gopro", "task_name": "deblur"},
        # Under data_root, 'sharp' holds the gt images and 'blur' the inputs.
        "data_root": "./data/gopro/train",
        "data_prefix": {"gt": "sharp", "img": "blur"},
        "ann_file": "meta_info_gopro_train.txt",
        # Inline copy of the training transforms: load pair, set scale=1,
        # random flips/transpose (ratio 0.5), 256 paired crop, pack.
        "pipeline": [
            {"type": "LoadImageFromFile", "key": "img"},
            {"type": "LoadImageFromFile", "key": "gt"},
            {"type": "SetValues", "dictionary": {"scale": 1}},
            {
                "type": "Flip",
                "keys": ["img", "gt"],
                "flip_ratio": 0.5,
                "direction": "horizontal",
            },
            {
                "type": "Flip",
                "keys": ["img", "gt"],
                "flip_ratio": 0.5,
                "direction": "vertical",
            },
            {
                "type": "RandomTransposeHW",
                "keys": ["img", "gt"],
                "transpose_ratio": 0.5,
            },
            {"type": "PairedRandomCrop", "gt_patch_size": 256},
            {"type": "PackEditInputs"},
        ],
    },
}
# Validation loader: 4 workers, unshuffled default sampler, drop_last=False.
# No batch_size is given here.
val_dataloader = {
    "num_workers": 4,
    "persistent_workers": False,
    "drop_last": False,
    "sampler": {"type": "DefaultSampler", "shuffle": False},
    "dataset": {
        "type": "BasicImageDataset",
        "metainfo": {"dataset_type": "gopro", "task_name": "deblur"},
        "data_root": "./data/gopro/test",
        "ann_file": "meta_info_gopro_test.txt",
        "data_prefix": {"gt": "sharp", "img": "blur"},
        # Load both images with channel_order='rgb', then pack them.
        "pipeline": [
            {
                "type": "LoadImageFromFile",
                "key": "img",
                "channel_order": "rgb",
            },
            {
                "type": "LoadImageFromFile",
                "key": "gt",
                "channel_order": "rgb",
            },
            {"type": "PackEditInputs"},
        ],
    },
}
# Test loader: 4 workers, unshuffled default sampler, drop_last=False, reading
# from ./data/gopro/test with the meta_info_gopro_test.txt annotation file.
test_dataloader = {
    "num_workers": 4,
    "persistent_workers": False,
    "drop_last": False,
    "sampler": {"type": "DefaultSampler", "shuffle": False},
    "dataset": {
        "type": "BasicImageDataset",
        "metainfo": {"dataset_type": "gopro", "task_name": "deblur"},
        "data_root": "./data/gopro/test",
        "ann_file": "meta_info_gopro_test.txt",
        "data_prefix": {"gt": "sharp", "img": "blur"},
        # Load both images with channel_order='rgb', then pack them.
        "pipeline": [
            {
                "type": "LoadImageFromFile",
                "key": "img",
                "channel_order": "rgb",
            },
            {
                "type": "LoadImageFromFile",
                "key": "gt",
                "channel_order": "rgb",
            },
            {"type": "PackEditInputs"},
        ],
    },
}
# Validation and test use the same three metrics, listed separately.
val_evaluator = [{"type": "MAE"}, {"type": "PSNR"}, {"type": "SSIM"}]
test_evaluator = [{"type": "MAE"}, {"type": "PSNR"}, {"type": "SSIM"}]

# Iteration-based training loop: max_iters=400000, val_interval=20000.
train_cfg = {
    "type": "IterBasedTrainLoop",
    "max_iters": 400000,
    "val_interval": 20000,
}
val_cfg = {"type": "EditValLoop"}
test_cfg = {"type": "EditTestLoop"}
# Optimizer wrapper around AdamW (lr 0.001, weight decay 0.001, both betas
# set to 0.9).
optim_wrapper = {
    "constructor": "DefaultOptimWrapperConstructor",
    "type": "OptimWrapper",
    "optimizer": {
        "type": "AdamW",
        "lr": 0.001,
        "weight_decay": 0.001,
        "betas": (0.9, 0.9),
    },
}

# Cosine-annealing schedule configured per iteration (by_epoch=False); T_max
# is the same number as max_iters in train_cfg (400000).
param_scheduler = {
    "type": "CosineAnnealingLR",
    "by_epoch": False,
    "T_max": 400000,
    "eta_min": 1e-07,
}

# Seed settings for the run.
randomness = {"seed": 10, "diff_rank_seed": True}
EOF