mmdetection-3.x学习笔记——配置程序运行环境

2023-02-24 | Views | mmdetection-3.x | 730字 | 4 min read

mmdetection会在Runner类初始化时调用setup_env(env_cfg)方法来配置程序的运行环境,包括cuDNN、多进程/多线程、分布式以及最大打开文件数等环境信息,代码如下:

def setup_env(self, env_cfg: Dict) -> None:
    """Prepare the process runtime environment from ``env_cfg``.

    Steps, in order: optional cuDNN benchmark mode, multi-processing
    settings, distributed initialisation, rank/world-size bookkeeping
    (``self._rank`` / ``self._world_size``), a run timestamp shared by all
    ranks (``self._timestamp``), and - on non-Windows systems - the soft
    limit on open file descriptors.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl'),
            resource_limit=4096
        )

    Args:
        env_cfg (dict): Config for setting environment. Recognised keys:
            ``cudnn_benchmark`` (truthy value enables it), ``mp_cfg``
            (keyword arguments for ``set_multi_processing``), ``dist_cfg``
            (keyword arguments for ``init_dist``) and ``resource_limit``
            (desired soft ``RLIMIT_NOFILE``, default 4096).
    """
    # cuDNN benchmark mode tunes low-level kernels (convolution etc.) for the
    # input configurations it sees. That helps when shapes stay fixed, but
    # constantly changing shapes trigger repeated tuning that can cost more
    # than it saves, so it is only enabled when the config asks for it.
    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    # Multi-processing / threading settings (start method, OpenCV threads...).
    mp_settings: dict = env_cfg.get('mp_cfg', {})
    set_multi_processing(**mp_settings, distributed=self.distributed)

    # Initialise the distributed environment first, since the logger depends
    # on the dist info. ``self.launcher`` selects how processes are launched:
    # 'pytorch', 'mpi', 'slurm', or 'none' (no distributed training).
    if self.distributed and not is_distributed():
        init_dist(self.launcher, **env_cfg.get('dist_cfg', {}))

    # rank: index of this process; world_size: total number of processes.
    self._rank, self._world_size = get_dist_info()

    # Broadcast process 0's wall-clock time so every process formats the
    # same run timestamp.
    start_time = torch.tensor(time.time(), dtype=torch.float64)
    broadcast(start_time)
    self._timestamp = time.strftime(
        '%Y%m%d_%H%M%S', time.localtime(start_time.item()))

    # The open-files limit below uses the Unix-only ``resource`` module, so
    # Windows stops here.
    if platform.system() == 'Windows':
        return

    # https://github.com/pytorch/pytorch/issues/973
    # Move the soft limit on open file descriptors towards
    # ``resource_limit``: never below the current soft limit and never
    # above the hard limit.
    import resource
    current_soft, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    wanted = env_cfg.get('resource_limit', 4096)
    new_soft = min(max(wanted, current_soft), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard_limit))

1、配置多进程环境

# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import warnings

import torch.multiprocessing as mp


def set_multi_processing(mp_start_method: str = 'fork',
                         opencv_num_threads: int = 0,
                         distributed: bool = False) -> None:
    """Set multi-processing related environment.

    Three things are configured:

    1. The multiprocessing start method (non-Windows only). It is always
       force-set to ``mp_start_method``. If a different method had already
       been chosen, a warning is emitted first.
    2. The number of OpenCV threads, when ``cv2`` is importable. If OpenCV
       is not installed this step is skipped.
    3. In a distributed environment only, ``OMP_NUM_THREADS`` and
       ``MKL_NUM_THREADS`` default to ``1`` unless they are already present
       in ``os.environ``. A warning is emitted for each variable that gets
       set.

    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'. Ignored on Windows.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
        distributed (bool): True if distributed environment.
            Defaults to False.
    """
    # Set the multi-process start method (`fork` by default, to speed up the
    # training). The start method is left untouched on Windows.
    if platform.system() != 'Windows':
        current_method = mp.get_start_method(allow_none=True)
        # Only warn when a *different* method was already set. An unset
        # (None) method is simply initialised below.
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can '
                'change this behavior by changing `mp_start_method` in '
                'your config.')
        # force=True overrides a start method that may already be fixed.
        mp.set_start_method(mp_start_method, force=True)

    # OpenCV is optional: only configure it when it is installed.
    try:
        import cv2
    except ImportError:
        pass
    else:
        # Disable (with the default of 0) or limit opencv multithreading to
        # avoid the system being overloaded.
        cv2.setNumThreads(opencv_num_threads)

    # Setup OMP and MKL threads.
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    # A value the user already exported always wins. Otherwise each
    # distributed process is pinned to one thread so that several processes
    # on one host do not overload the system.
    if distributed:
        for env_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
            if env_var in os.environ:
                continue
            num_threads = 1
            warnings.warn(
                f'Setting {env_var} environment variable for each process'
                f' to be {num_threads} in default, to avoid your system '
                'being overloaded, please further tune the variable for '
                'optimal performance in your application as needed.')
            os.environ[env_var] = str(num_threads)

EOF