# Simple Baseline for Image Restoration (NAFNet)
# Config file: NAFNet deblurring on the GoPro dataset.
# ---------------------------------------------------------------------------
# Runtime settings: registry scope, hooks, environment, logging,
# visualization and the experiment's output locations.
# ---------------------------------------------------------------------------
default_scope = 'mmedit'

# Root directory for all experiment outputs. The checkpoint hook and
# `work_dir` below are derived from it so the paths cannot drift apart.
save_dir = './work_dirs/'

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,  # counted in iterations because by_epoch=False
        out_dir=save_dir,
        by_epoch=False,
        max_keep_ckpts=10,
        # Additionally track the checkpoint with the highest PSNR.
        save_best='PSNR',
        rule='greater',
        save_optimizer=True),
    sampler_seed=dict(type='DistSamplerSeedHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4),
    dist_cfg=dict(backend='nccl'))

log_level = 'INFO'
log_processor = dict(type='EditLogProcessor', window_size=100, by_epoch=False)

load_from = None
resume = False

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='ConcatImageVisualizer',
    vis_backends=vis_backends,
    fn_key='gt_path',
    img_keys=['gt_img', 'input', 'pred_img'],
    bgr2rgb=False)
custom_hooks = [dict(type='BasicVisualizationHook', interval=1)]

# The working directory is named after the experiment; derive it instead of
# repeating the name so renaming the experiment also moves its outputs.
experiment_name = 'nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro'
work_dir = f'{save_dir}{experiment_name}'
# ---------------------------------------------------------------------------
# Model: a NAFNetLocal generator inside BaseEditModel, optimised with
# PSNRLoss.
# ---------------------------------------------------------------------------
model = dict(
    type='BaseEditModel',
    # Block counts per stage: encoder 1/1/1/28, one middle block, decoder
    # 1/1/1/1 -- the "eb11128mb1db1111" tag in the experiment name.
    generator=dict(
        type='NAFNetLocal',
        img_channels=3,
        mid_channels=64,
        enc_blk_nums=[1] * 3 + [28],
        middle_blk_num=1,
        dec_blk_nums=[1] * 4,
    ),
    pixel_loss=dict(type='PSNRLoss'),
    train_cfg=dict(),
    test_cfg=dict(),
    # Normalisation (x - 0) / 255 maps [0, 255] pixel values onto [0, 1].
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0.0] * 3,
        std=[255.0] * 3,
    ),
)
# ---------------------------------------------------------------------------
# Data pipelines.
#
# Every image is loaded explicitly as RGB so that training sees the same
# channel order as validation/testing (and as the visualizer, which is
# configured with bgr2rgb=False). Relying on LoadImageFromFile's default
# channel order (BGR in mmedit 1.x) for training only would evaluate the
# network on a channel order it was never trained on.
# ---------------------------------------------------------------------------
train_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    # scale=1: blurred input and sharp GT share one resolution (presumably
    # read by PairedRandomCrop -- confirm against its implementation).
    dict(type='SetValues', dictionary=dict(scale=1)),
    # Augmentations list both keys so the blur/sharp pair stays aligned.
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal'),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='vertical'),
    dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(type='PackEditInputs')
]
val_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='PackEditInputs')
]

# ---------------------------------------------------------------------------
# Dataloaders. They reference the pipelines and dataset type defined above
# rather than carrying private copies that could silently diverge.
# ---------------------------------------------------------------------------
dataset_type = 'BasicImageDataset'

train_dataloader = dict(
    num_workers=8,
    batch_size=8,
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='gopro', task_name='deblur'),
        data_root='./data/gopro/train',
        data_prefix=dict(gt='sharp', img='blur'),
        ann_file='meta_info_gopro_train.txt',
        pipeline=train_pipeline))

val_dataloader = dict(
    num_workers=4,
    # Same as torch's DataLoader default; stated explicitly because the
    # full-size test images are not cropped to a common shape.
    batch_size=1,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='gopro', task_name='deblur'),
        data_root='./data/gopro/test',
        ann_file='meta_info_gopro_test.txt',
        data_prefix=dict(gt='sharp', img='blur'),
        pipeline=val_pipeline))

# Testing uses the same GoPro test split and pipeline as validation.
test_dataloader = dict(
    num_workers=4,
    batch_size=1,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='gopro', task_name='deblur'),
        data_root='./data/gopro/test',
        ann_file='meta_info_gopro_test.txt',
        data_prefix=dict(gt='sharp', img='blur'),
        pipeline=val_pipeline))
# ---------------------------------------------------------------------------
# Evaluation metrics, train/val/test loops and the optimisation schedule.
# ---------------------------------------------------------------------------
val_evaluator = [dict(type='MAE'), dict(type='PSNR'), dict(type='SSIM')]
test_evaluator = [dict(type='MAE'), dict(type='PSNR'), dict(type='SSIM')]

# 400k iterations in total, validating every 20k iterations.
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=400000, val_interval=20000)
val_cfg = dict(type='EditValLoop')
test_cfg = dict(type='EditTestLoop')

optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.001, weight_decay=0.001, betas=(0.9, 0.9)))

# Cosine decay down to eta_min over the whole run. T_max is read from
# train_cfg so editing max_iters in this file keeps the schedule aligned
# with the run length.
param_scheduler = dict(
    type='CosineAnnealingLR',
    by_epoch=False,
    T_max=train_cfg['max_iters'],
    eta_min=1e-07)

randomness = dict(seed=10, diff_rank_seed=True)