One Blog Plz go to study now ! https://yarkable.github.io/ Sat, 09 Sep 2023 03:28:00 +0000 Sat, 09 Sep 2023 03:28:00 +0000 Jekyll v3.9.3
MMCV Config: An Annotated Walkthrough
<h1 id="前言">Preface</h1>
<p>These notes record the structure of MMCV's Config class. The class turns the dicts defined in Python files, or in JSON/YAML files, into dict objects that are convenient to work with, and some of its details are very well designed. The MMCV version used here is 1.3.5.</p>
<h1 id="class-configdict">class ConfigDict</h1>
<p>This class inherits from the Dict class of addict, so values in the dict can be read through attribute access. It overrides two magic methods, <code class="language-plaintext highlighter-rouge">__missing__</code> and <code class="language-plaintext highlighter-rouge">__getattr__</code>. For addict's Dict, looking up a key that does not exist calls <code class="language-plaintext highlighter-rouge">__missing__</code>, which returns a new empty dict. ConfigDict instead raises an error right away (KeyError for item access, AttributeError for attribute access) rather than returning a default value.</p>
<pre><code class="language-Python">class ConfigDict(Dict):

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex
</code></pre>
<p>What addict adds on top of Python's built-in dict is mainly a pair of overridden methods, <code class="language-plaintext highlighter-rouge">__getattr__</code> and <code class="language-plaintext highlighter-rouge">__setattr__</code>, which let users read and write dict values through attribute access (that is, <code class="language-plaintext highlighter-rouge">a.b</code>). addict also supports arbitrarily deep nesting, which makes it quite powerful. Overriding these two methods ourselves gives a simple demo:</p>
<pre><code class="language-Python">class MyDict(dict):

    def __setattr__(self, __name: str, __value) -&gt; None:
        print('__setattr__')
        self[__name] = __value

    def __getattr__(self, item):
        print('__getattr__')
        return self[item]


md = MyDict()
md.a = 1
print(md.a)
# __setattr__
# __getattr__
# 1
</code></pre>
<p>The demo above cannot handle nested access, though. The minimal demo below, written by the official MMCV team, does support it; in essence it just walks the dict values recursively:</p>
<pre><code class="language-Python">class MiniDict(dict):

    def __init__(self, *args):
        super().__init__()
        for arg in args:
            for key, val in arg.items():
                # store every value, converting nested dicts recursively
                self[key] = self._hook(val)

    def _hook(self, item):
        if isinstance(item, dict):
            return MiniDict(item)  # recursion: nested dicts become MiniDicts too
        return item

    # called automatically on attribute access such as r.a
    # (r['a'] goes through dict.__getitem__ instead)
    def __getattr__(self, item):
        return self[item]


r = MiniDict(dict(a=dict(b=2)))
print(r.a.b)
# 2
</code></pre>
<h1 id="class-config">class Config</h1>
<h2 id="init">__init__</h2>
<p>The initializer. A Config object is normally not created directly: the dict and related information are read from a file and passed to the initializer, which returns a Config object.</p>
<pre><code class="language-Python">def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    if cfg_dict is None:
        cfg_dict = dict()
    elif not isinstance(cfg_dict, dict):
        raise TypeError('cfg_dict must be a dict, but '
                        f'got {type(cfg_dict)}')
    # the incoming dict must not contain any reserved key, otherwise raise
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f'{key} is reserved for config file')

    # Config has no explicit base class, so it inherits from object.
    # The parent's __setattr__ is used because Config overrides __setattr__
    # and __getattr__: self._cfg_dict does not exist yet, so going through
    # the overridden methods would recurse infinitely.
    super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
    super(Config, self).__setattr__('_filename', filename)
    if cfg_text:
        text = cfg_text
    elif filename:
        with open(filename, 'r') as f:
            text = f.read()
    else:
        text = ''
    super(Config, self).__setattr__('_text', text)
</code></pre>
<h2 id="fromfile">fromfile</h2>
<p>This is the most important method: it reads a config from a file and turns it into a Config object. Because it is a static method, it can be called on the class itself without an instance, i.e. <code class="language-plaintext highlighter-rouge">Config.fromfile</code>.</p>
<pre><code class="language-Python">@staticmethod
def fromfile(filename,
             use_predefined_variables=True,
             import_custom_modules=True):
    cfg_dict, cfg_text = Config._file2dict(filename,
                                           use_predefined_variables)
    # custom modules can be imported here on the fly
    if import_custom_modules and cfg_dict.get('custom_imports', None):
        import_modules_from_strings(**cfg_dict['custom_imports'])
    return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
</code></pre>
<h2 id="_file2dict">_file2dict</h2>
<p>The core logic behind fromfile, and a very important function.</p>
<pre><code class="language-Python">@staticmethod
def _file2dict(filename, use_predefined_variables=True):
    filename = osp.abspath(osp.expanduser(filename))
    check_file_exist(filename)
    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    # A temporary file holds a copy of the original config so that config
    # names such as `a.b.py` are supported. A .py config is read by importing
    # it, and `import a.b` would treat `a` as a package name and fail, even
    # though the module is really called `a.b`. Copying the file into a
    # temporary directory under a harmless random name neatly avoids this
    # import error.
    with tempfile.TemporaryDirectory() as temp_config_dir:
        temp_config_file = tempfile.NamedTemporaryFile(
            dir=temp_config_dir, suffix=fileExtname)
        if platform.system() == 'Windows':
            temp_config_file.close()
        temp_config_name = osp.basename(temp_config_file.name)
        # substitute the template variables predefined by mmcv (True by default)
        if use_predefined_variables:
            Config._substitute_predefined_vars(filename,
                                               temp_config_file.name)
        else:
            shutil.copyfile(filename, temp_config_file.name)

        if filename.endswith('.py'):
            temp_module_name = osp.splitext(temp_config_name)[0]
            # prepend temp_config_dir to sys.path so the module can be found
            sys.path.insert(0, temp_config_dir)
            # check the Python syntax by parsing it into an abstract syntax tree
            Config._validate_py_syntax(filename)
            # import the .py file that stores the config
            mod = import_module(temp_module_name)
            sys.path.pop(0)
            # every name that does not start with __ goes into cfg_dict
            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith('__')
            }
            # delete imported module
            # (once its values are stored, the module is no longer needed)
            del sys.modules[temp_module_name]
        # files with the other extensions are loaded into a dict with mmcv.load
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(temp_config_file.name)
        # close temp file
        temp_config_file.close()

    cfg_text = filename + '\n'
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        cfg_text += f.read()

    # BASE_KEY defaults to _base_ and lists the inherited configs
    if BASE_KEY in cfg_dict:
        cfg_dir = osp.dirname(filename)
        # collect the base file names in a list, since there may be several
        base_filename = cfg_dict.pop(BASE_KEY)
        base_filename = base_filename if isinstance(
            base_filename, list) else [base_filename]

        cfg_dict_list = list()
        cfg_text_list = list()
        for f in base_filename:
            # read each base config; this is recursive, so a base file may
            # itself contain a _base_ field
            _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
            cfg_dict_list.append(_cfg_dict)
            cfg_text_list.append(_cfg_text)

        base_cfg_dict = dict()
        for c in cfg_dict_list:
            # different base files must not define the same key
            if len(base_cfg_dict.keys() &amp; c.keys()) &gt; 0:
                raise KeyError('Duplicate key is not allowed among bases')
            base_cfg_dict.update(c)

        # merge this file's config into the merged base config;
        # values from this file override those from the bases
        base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
        cfg_dict = base_cfg_dict

        # merge cfg_text
        cfg_text_list.append(cfg_text)
        cfg_text = '\n'.join(cfg_text_list)

    # cfg_dict is now the fully merged and processed config, as a dict;
    # cfg_text is a string holding the contents of all base files followed
    # by this file's own contents
    return cfg_dict, cfg_text
</code></pre>
<h2 id="_substitute_predefined_vars">_substitute_predefined_vars</h2>
<p>This function defines a few template variables, written in a config file as placeholders such as <code class="language-plaintext highlighter-rouge">{{ fileDirname }}</code>, and replaces them with the actual values for the user's file when the config is loaded.</p>
<pre><code class="language-Python">@staticmethod
def _substitute_predefined_vars(filename, temp_config_name):
    # four properties of the file are extracted here
    file_dirname = osp.dirname(filename)
    file_basename = osp.basename(filename)
    file_basename_no_extension = osp.splitext(file_basename)[0]
    file_extname = osp.splitext(filename)[1]
    # the supported template variables
    support_templates = dict(
        fileDirname=file_dirname,
        fileBasename=file_basename,
        fileBasenameNoExtension=file_basename_no_extension,
        fileExtname=file_extname)
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        config_file = f.read()
    for key, value in support_templates.items():
        # replace the 4 templates above with their real values via a regex;
        # it matches {{key}} with zero or more whitespace characters on
        # either side of key
        regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
        value = value.replace('\\', '/')
        config_file = re.sub(regexp, value, config_file)
    # as analyzed above: the processed content is written to the temporary
    # file so that a .py config can be imported from there
    with open(temp_config_name, 'w') as tmp_config_file:
        tmp_config_file.write(config_file)
</code></pre>
<h2 id="import_modules_from_strings">import_modules_from_strings</h2>
<p>This is a function in <code class="language-plaintext highlighter-rouge">mmcv.utils.misc</code> that imports Python modules given their names as strings.</p>
<pre><code class="language-Python">def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raise. Default: False.

    Returns:
        list[module] | module | None: The imported modules.

    Examples:
        &gt;&gt;&gt; osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        &gt;&gt;&gt; import os.path as osp_
        &gt;&gt;&gt; import sys as sys_
        &gt;&gt;&gt; assert osp == osp_
        &gt;&gt;&gt; assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f'{imp} failed to import and is ignored.',
                              UserWarning)
                imported_tmp = None
            else:
                raise ImportError
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
</code></pre>
<h2 id="_merge_a_into_b">_merge_a_into_b</h2>
<p>In <code class="language-plaintext highlighter-rouge">_file2dict</code>, we called <code class="language-plaintext highlighter-rouge">Config._merge_a_into_b(cfg_dict, base_cfg_dict)</code> to merge the current file's config with the config from the base files. Let's look at how this is done.</p>
<pre><code class="language-Python">@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
    """merge dict ``a`` into dict ``b`` (non-inplace).

    Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
    in-place modifications.

    Args:
        a (dict): The source dict to be merged into ``b``.
        b (dict): The origin dict to be fetch keys from ``a``.
        allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
          are allowed in source ``a`` and will replace the element of the
          corresponding index in b if b is a list. Default: False.

    Returns:
        dict: The modified dict of ``b`` using ``a``.

    Examples:
        # Normally merge a into b.
        &gt;&gt;&gt; Config._merge_a_into_b(
        ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
        {'obj': {'a': 2}}

        # Delete b first and merge a into b.
        &gt;&gt;&gt; Config._merge_a_into_b(
        ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
        {'obj': {'a': 2}}

        # b is a list
        &gt;&gt;&gt; Config._merge_a_into_b(
        ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
        [{'a': 2}, {'b': 2}]
    """
    b = b.copy()
    for k, v in a.items():
        # with allow_list_keys, digit-string keys such as '0' index into b
        # when b is a list; this is rarely needed, so it is skipped here
        if allow_list_keys and k.isdigit() and isinstance(b, list):
            k = int(k)
            if len(b) &lt;= k:
                raise KeyError(f'Index {k} exceeds the length of list {b}')
            b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
        # if the key exists in both the base config (b) and the current file
        # (a), its value in a is a dict, and that dict does not set
        # _delete_=True, merge the two values recursively with _merge_a_into_b
        elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
            allowed_types = (dict, list) if allow_list_keys else dict
            if not isinstance(b[k], allowed_types):
                raise TypeError(
                    f'{k}={v} in child config cannot inherit from base '
                    f'because {k} is a dict in the child config but is of '
                    f'type {type(b[k])} in base config. You may set '
                    f'`{DELETE_KEY}=True` to ignore the base config')
            b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
        else:
            # neat: keys missing from b are simply added, non-dict values
            # overwrite b, and a dict that set _delete_=True replaces b's
            # value wholesale; the pop() in the elif condition has already
            # removed _delete_ from v, so it does not end up in b
            b[k] = v
    return b
</code></pre>
<h2 id="魔法函数">Magic methods</h2>
<p>Config overrides many magic methods so that attribute and item access are forwarded to its internal ConfigDict, which builds on addict to make the underlying Python dict more convenient to use.</p>
<pre><code class="language-Python">def __getattr__(self, name):
    return getattr(self._cfg_dict, name)

def __getitem__(self, name):
    return self._cfg_dict.__getitem__(name)

def __setattr__(self, name, value):
    if isinstance(value, dict):
        value = ConfigDict(value)
    self._cfg_dict.__setattr__(name, value)

def __setitem__(self, name, value):
    if isinstance(value, dict):
        value = ConfigDict(value)
    self._cfg_dict.__setitem__(name, value)
</code></pre>
<h2 id="reference">Reference</h2>
<p>https://zhuanlan.zhihu.com/p/346203167</p>
<p>https://blog.csdn.net/qq_38253797/article/details/121471389</p>
Fri, 05 May 2023 00:00:00 +0000 https://yarkable.github.io/2023/05/05/MMCV%E4%B9%8BConfig%E6%B3%A8%E9%87%8A%E8%AF%A6%E8%A7%A3/ linux object detection deep learning mmcv
MMDetection & pycocotools eval Explained
<h1 id="preface">Preface</h1>
<p>These notes record how mmdet evaluates a detector, using the COCO dataset as an example; the mmdet version used is 2.18.0. Under the hood it is essentially a wrapper around pycocotools, and I am writing it down here for later review.</p>
<h1 id="testpy">test.py</h1>
<p>First, <code class="language-plaintext highlighter-rouge">tools/test.py</code> runs inference to produce the results, and then each dataset's <code class="language-plaintext highlighter-rouge">evaluate</code> method is used to evaluate performance.</p>
<pre><code class="language-Python">if not distributed:
    model = MMDataParallel(model, device_ids=[0])
    # run the detector's forward inference to get the results
    outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                              args.show_score_thr)
else:
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
    outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                             args.gpu_collect)

# outputs is a list whose length equals the number of validation images;
# each element is a list whose length equals the number of classes;
# each inner element holds the boxes and confidences the detector produced
# for that class (for a box-only detector, an array of shape (n, 5):
# x1, y1, x2, y2, score), already filtered by the score threshold and NMS
# for COCO (val2017) the outer list has length 5000 and each element 80
# the score threshold and NMS IoU threshold come from the head's test_cfg;
# the cap on detections (max_per_img, typically 100) applies to each image
# across all classes, not to each class
rank, _ = get_dist_info()
if rank == 0:
    if args.out:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(outputs, args.out)
    kwargs = {} if args.eval_options is None else args.eval_options
    # this option calls _det2json to turn the big list into a list in the
    # standard COCO json format, typically for submitting detections to an
    # evaluation server:
    # {'image_id': x, 'bbox': x, 'score': x, 'category_id': x}
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    # the eval logic
    if args.eval:
        # for COCO the evaluation metric here is usually bbox
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        # call the dataset's evaluate method directly,
        # so every mmdet dataset must implement this method
        metric = dataset.evaluate(outputs, **eval_kwargs)
        print(metric)
        metric_dict = dict(config=args.config, metric=metric)
        if args.work_dir is not None and rank == 0:
            mmcv.dump(metric_dict, json_file)
</code></pre>
<h1 id="快速debug的脚本">A script for quick debugging</h1>
<p>With the official test script, every debugging run has to repeat inference over the whole dataset, which is tedious. Instead, use the <code class="language-plaintext highlighter-rouge">--out</code> argument to save the inference results in pkl format and load them with the script below, which is much faster. The script is a trimmed-down test.py that keeps only the dataset-related configuration; I call it <code class="language-plaintext highlighter-rouge">naive_test.py</code>.</p>
<pre><code class="language-Python">import pickle
import argparse

from mmcv import Config, DictAction
from mmdet.datasets import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('results', help='pkl format infer results path')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # without this line evaluation fails: outside test mode the dataset
    # filters out images without annotations (as in training), so the
    # number of images no longer matches the number of saved results
    cfg.data.test.test_mode = True
    dataset = build_dataset(cfg.data.test)
    with open(args.results, 'rb') as f:
        outputs = pickle.load(f)

    kwargs = {} if args.eval_options is None else args.eval_options
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        metric = dataset.evaluate(outputs, **eval_kwargs)
        print(metric)


if __name__ == '__main__':
    main()
</code></pre>
<h1 id="mmdetdatasetcocopy">mmdet/datasets/coco.py</h1>
<h2 id="evaluate函数">The evaluate function</h2>
<p>As shown above, the flow ends up in the dataset's <code class="language-plaintext highlighter-rouge">evaluate</code> method, so let's look at evaluate in detail.</p>
<pre><code class="language-Python">def evaluate(self,
             results,
             metric='bbox',
             logger=None,
             jsonfile_prefix=None,
             classwise=False,
             proposal_nums=(100, 300, 1000),
             iou_thrs=None,
             metric_items=None):
    """Evaluation in COCO protocol.

    Args:
        results (list[list | tuple]): Testing results of the dataset.
        metric (str | list[str]): Metrics to be evaluated. Options are
            'bbox', 'segm', 'proposal', 'proposal_fast'.
        logger (logging.Logger | str | None): Logger used for printing
            related information during evaluation. Default: None.
        jsonfile_prefix (str | None): The prefix of json files. It includes
            the file path and the prefix of filename, e.g., "a/b/prefix".
            If not specified, a temp file will be created. Default: None.
        classwise (bool): Whether to evaluating the AP for each class.
        proposal_nums (Sequence[int]): Proposal number used for evaluating
            recalls, such as recall@100, recall@1000.
            Default: (100, 300, 1000).
        iou_thrs (Sequence[float], optional): IoU threshold used for
            evaluating recalls/mAPs. If set to a list, the average of all
            IoUs will also be computed. If not specified, [0.50, 0.55,
            0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
            Default: None.
        metric_items (list[str] | str, optional): Metric items that will
            be returned. If not specified, ``['AR@100', 'AR@300',
            'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
            used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
            'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
            ``metric=='bbox' or metric=='segm'``.

    Returns:
        dict[str, float]: COCO style evaluation metric.
    """
    # wrap in a list so all metrics can be iterated uniformly
    # (a list is usually passed in anyway)
    metrics = metric if isinstance(metric, list) else [metric]
    # the metrics supported for COCO
    allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
    for metric in metrics:
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
    if iou_thrs is None:
        # iou_thrs is usually left unset, so by default AP is averaged over
        # IoU thresholds 0.50:0.05:0.95 (COCO mAP); pass this argument to
        # compute AP at a specific IoU threshold only
        iou_thrs = np.linspace(
            .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
    if metric_items is not None:
        # usually left unset as well, so every metric is returned;
        # set it to return only specific metrics
        if not isinstance(metric_items, list):
            metric_items = [metric_items]

    # convert the inference results (a list) into JSON in the standard COCO
    # format; choosing --format-only in test.py calls this same function
    # when jsonfile_prefix is None the files go to a temporary directory,
    # since we do not need to keep them
    # for bbox evaluation, result_files holds the saved paths under two keys:
    # {'bbox': '/tmp/tmpj126zjei/results.bbox.json',
    #  'proposal': '/tmp/tmpj126zjei/results.bbox.json'}
    result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

    # an extra dict added by mmdet to record the metrics; it is printed as
    # the last line of the evaluation output:
    # OrderedDict([('bbox_mAP', x), ('bbox_mAP_50', x), ('bbox_mAP_75', x),
    #              ('bbox_mAP_s', x), ('bbox_mAP_m', x), ('bbox_mAP_l', x),
    #              ('bbox_mAP_copypaste', x)])
    eval_results = OrderedDict()
    # cocoGt holds the ground-truth labels read from the validation set
    cocoGt = self.coco
    for metric in metrics:
        msg = f'Evaluating {metric}...'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)
        # TODO
        if metric == 'proposal_fast':
            ar = self.fast_eval_recall(
                results, proposal_nums, iou_thrs, logger='silent')
            log_msg = []
            for i, num in enumerate(proposal_nums):
                eval_results[f'AR@{num}'] = ar[i]
                log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
            log_msg = ''.join(log_msg)
            print_log(log_msg, logger=logger)
            continue

        iou_type = 'bbox' if metric == 'proposal' else metric
        if metric not in result_files:
            raise KeyError(f'{metric} is not in results')
        try:
            # load back the standard COCO results that were just saved
            predictions = mmcv.load(result_files[metric])
            if iou_type == 'segm':
                # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
                # When evaluating mask AP, if the results contain bbox,
                # cocoapi will use the box area instead of the mask area
                # for calculating the instance area. Though the overall AP
                # is not affected, this leads to different
                # small/medium/large mask AP results.
                for x in predictions:
                    x.pop('bbox')
                warnings.simplefilter('once')
                warnings.warn(
                    'The key "bbox" is deleted for more accurate mask AP '
                    'of small/medium/large instances since v2.12.0. This '
                    'does not change the overall mAP calculation.',
                    UserWarning)
            # loadRes loads the detections as cocoDt,
            # which is again an object of the COCO class
            cocoDt = cocoGt.loadRes(predictions)
        except IndexError:
            print_log(
                'The testing results of the whole dataset is empty.',
                logger=logger,
                level=logging.ERROR)
            break

        # call the API directly with cocoGt, cocoDt and the IoU type
        cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
        # equivalent to self.coco.get_cat_ids(cat_names=self.CLASSES)
        cocoEval.params.catIds = self.cat_ids
        # equivalent to self.coco.get_img_ids()
        cocoEval.params.imgIds = self.img_ids
        # the proposal_nums argument, converted to a list; this replaces
        # pycocotools' default maxDets of [1, 10, 100]
        cocoEval.params.maxDets = list(proposal_nums)
        # thresholds 0.5-0.95 by default; can be passed in as an argument
        cocoEval.params.iouThrs = iou_thrs
        # mapping of cocoEval.stats
        coco_metric_names = {
            'mAP': 0,
            'mAP_50': 1,
            'mAP_75': 2,
            'mAP_s': 3,
            'mAP_m': 4,
            'mAP_l': 5,
            'AR@100': 6,
            'AR@300': 7,
            'AR@1000': 8,
            'AR_s@1000': 9,
            'AR_m@1000': 10,
            'AR_l@1000': 11
        }
        if metric_items is not None:
            for metric_item in metric_items:
                if metric_item not in coco_metric_names:
                    raise KeyError(
                        f'metric item {metric_item} is not supported')

        if metric == 'proposal':
            # AR evaluation ignores category labels: with useCats=0 the
            # detections of all categories are pooled together
            # (COCOeval usage is covered later)
            cocoEval.params.useCats = 0
            # the three-step API call
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            if metric_items is None:
                metric_items = [
                    'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
                    'AR_m@1000', 'AR_l@1000'
                ]

            for item in metric_items:
                # after summarize(), cocoEval.stats holds the values of all
                # metrics; coco_metric_names maps metric names to indices
                # of cocoEval.stats for convenient lookup; cocoEval.stats
                # has 12 values: the first 6 are AP, the last 6 are AR
                val = float(
                    f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
                eval_results[item] = val
        else:
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            # compute the AP of each individual class
            if classwise:  # Compute per-category AP
                # Compute per-category AP
                # from https://github.com/facebookresearch/detectron2/
                precisions = cocoEval.eval['precision']
                # precision: (iou, recall, cls, area range, max dets)
                assert len(self.cat_ids) == precisions.shape[2]

                results_per_category = []
                for idx, catId in enumerate(self.cat_ids):
                    # area range index 0: all area ranges
                    # max dets index -1: typically 100 per image
                    nm = self.coco.loadCats(catId)[0]
                    precision = precisions[:, :, idx, 0, -1]
                    precision = precision[precision &gt; -1]
                    if precision.size:
                        ap = np.mean(precision)
                    else:
                        ap = float('nan')
                    results_per_category.append(
                        (f'{nm["name"]}', f'{float(ap):0.3f}'))

                num_columns = min(6, len(results_per_category) * 2)
                results_flatten = list(
                    itertools.chain(*results_per_category))
                headers = ['category', 'AP'] * (num_columns // 2)
                results_2d = itertools.zip_longest(*[
                    results_flatten[i::num_columns]
                    for i in range(num_columns)
                ])
                table_data = [headers]
                table_data += [result for result in results_2d]
                table = AsciiTable(table_data)
                print_log('\n' + table.table, logger=logger)

            if metric_items is None:
                metric_items = [
                    'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                ]

            for metric_item in metric_items:
                key = f'{metric}_{metric_item}'
                val = float(
                    f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
                )
                eval_results[key] = val
            # the first 6 entries record AP, the last 6 record AR
            ap = cocoEval.stats[:6]
            # multiply by 100 here to report AP as a percentage;
            # otherwise it is always a fraction
            eval_results[f'{metric}_mAP_copypaste'] = (
                f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                f'{ap[4]:.3f} {ap[5]:.3f}')
    # a nice detail: the temporary files are cleaned up too
    if tmp_dir is not None:
        tmp_dir.cleanup()
    return eval_results
</code></pre>
<h2 id="小结">Summary</h2>
<p>To sum up the overall logic of the code above: to evaluate a dataset in COCO format, first produce the inference results and format them into a COCO-style json list. Then read the ground-truth labels of the validation set and use pycocotools' <code class="language-plaintext highlighter-rouge">loadRes</code> to turn the formatted results into a standard COCO object. Next set the parameters (the IoU type, the IoU thresholds, the evaluation mode and so on) and call the <code class="language-plaintext highlighter-rouge">cocoEval</code> API to obtain all the metrics, which are stored in the <code class="language-plaintext highlighter-rouge">eval</code> and <code class="language-plaintext highlighter-rouge">stats</code> attributes of <code class="language-plaintext highlighter-rouge">cocoEval</code>. The whole flow can therefore be summarized with the code below.</p>
<p>The <code class="language-plaintext highlighter-rouge">loadRes</code> method of the COCO class (not of COCOeval) accepts either a file path string or an already loaded list of json results, so we can pass it the path of the saved json list directly. That list can be produced with the <code class="language-plaintext highlighter-rouge">--format-only</code> option of the mmdet test script (together with <code class="language-plaintext highlighter-rouge">--eval-options jsonfile_prefix=&lt;prefix&gt;</code>, otherwise the files are written to a temporary directory and discarded); the script then prints all the metrics.</p>
<pre><code class="language-Python">from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--gt", type=str, required=True,
                        help="Path to the ground-truth annotation json file.")
    parser.add_argument("-d", "--dt", type=str, required=True,
                        help="Path to the detection result json file.")
    args = parser.parse_args()

    cocoGt = COCO(args.gt)
    # loadRes accepts a path to a results json file or an already loaded list
    cocoDt = cocoGt.loadRes(args.dt)
    cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
</code></pre>
<h1 id="pycocotoolscocoevalpy">pycocotools/cocoeval.py</h1>
<p>The sections above only skim the surface of the evaluation process: mmdet adds some wrapping in its evaluation code, but in the end it calls pycocotools. Here we dig into the details of how mAP is actually computed.</p>
<h2 id="class-cocoeval">class COCOeval</h2>
<h3 id="init">__init__</h3>
<p>The comment block at the top of the class, right above the initializer, explains how to use COCOeval, which is very thoughtful.</p>
<pre><code class="language-Plain">The usage for CocoEval is as follows:
 cocoGt=..., cocoDt=...       # load dataset and results
 E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
 E.params.recThrs = ...;      # set parameters as desired
 E.evaluate();                # run per image evaluation
 E.accumulate();              # accumulate per image results
 E.summarize();               # display summary metrics of results
For example usage see evalDemo.m and http://mscoco.org/.

The evaluation parameters are as follows (defaults in brackets):
 imgIds     - [all] N img ids to use for evaluation
 catIds     - [all] K cat ids to use for evaluation
 iouThrs    - [.5:.05:.95] T=10 IoU thresholds for evaluation
 recThrs    - [0:.01:1] R=101 recall thresholds for evaluation
 areaRng    - [...] A=4 object area ranges for evaluation
 maxDets    - [1 10 100] M=3 thresholds on max detections per image
 iouType    - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
 iouType replaced the now DEPRECATED useSegm parameter.
 useCats    - [1] if true use category labels for evaluation
Note: if useCats=0 category labels are ignored as in proposal scoring.
Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.

evaluate(): evaluates detections on every image and every category and
concats the results into the "evalImgs" with fields:
 dtIds      - [1xD] id for each of the D detections (dt)
 gtIds      - [1xG] id for each of the G ground truths (gt)
 dtMatches  - [TxD] matching gt id at each IoU or 0
 gtMatches  - [TxG] matching dt id at each IoU or 0
 dtScores   - [1xD] confidence of each dt
 gtIgnore   - [1xG] ignore flag for each gt
 dtIgnore   - [TxD] ignore flag for each dt at each IoU

accumulate(): accumulates the per-image, per-category evaluation
results in "evalImgs" into the dictionary "eval" with fields:
 params     - parameters used for evaluation
 date       - date evaluation was performed
 counts     - [T,R,K,A,M] parameter dimensions (see above)
 precision  - [TxRxKxAxM] precision for every evaluation setting
 recall     - [TxKxAxM] max recall for every evaluation setting
Note: precision and recall==-1 for settings with no gt objects.
</code></pre>
<h3 id="evaluate">evaluate</h3>
<p>This is the first of the three evaluation steps; calling it runs the per-image evaluation.</p>
<pre><code class="language-Python">def evaluate(self):
    '''
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    '''
    tic = time.time()
    print('Running per image evaluation...')
    # p.__dict__ shows the members of the params object as a dict
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if not p.useSegm is None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    print('Evaluate annotation type *{}*'.format(p.iouType))
    # deduplicate and sort with np.unique, then convert back to a list
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        # likewise: deduplicate, sort and convert to a list
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params=p

    # some preparation work, see below
    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    # compute the IoUs for every category of every image
    # computeIoU returns the IoUs between the dts and gts of one image,
    # a numpy array of shape (#dt, #gt)
    # this is done once per image and per category, so self.ious has
    # #val_images * #classes keys, 5000 * 80 for COCO
    self.ious = {(imgId, catId): computeIoU(imgId, catId) \
                    for imgId in p.imgIds
                    for catId in catIds}

    evaluateImg = self.evaluateImg
    # the largest value (100 by default): at most this many detections per
    # image and category take part in matching (mmdet sets maxDets to
    # [100, 300, 1000], so it is 1000 there)
    maxDet = p.maxDets[-1]
    # evaluate each single image and single category
    # the result is a flat list of #classes * #areaRng * #images entries
    # (category outermost, image innermost); for COCO that is 80 * 4 * 5000,
    # with the area ranges all, small, medium and large
    # each entry stores the dt-gt matching result for one image and class
    # (None when the image has neither gts nor dts of that class)
    self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
             for catId in catIds
             for areaRng in p.areaRng
             for imgId in p.imgIds
         ]
    # keep a deep copy of the params and finish
    self._paramsEval = copy.deepcopy(self.params)
    toc = time.time()
    print('DONE (t={:0.2f}s).'.format(toc-tic))
</code></pre>
<h3 id="_prepare">_prepare</h3>
<p>This function does some preparation work and sets up several variables.</p>
<pre><code class="language-Python">def _prepare(self):
    '''
    Prepare ._gts and ._dts for evaluation based on params
    :return: None
    '''
    def _toMask(anns, coco):
        # modify ann['segmentation'] by reference
        for ann in anns:
            rle = coco.annToRLE(ann)
            ann['segmentation'] = rle
    p = self.params
    # with useCats, only annotations of the selected categories (catIds) are
    # loaded, so users can choose which categories to evaluate
    if p.useCats:
        # lists of annotation dicts, each recording the image id, box, score,
        # category and so on; for dts (after loadRes) the keys are
        # ['image_id', 'bbox', 'score', 'category_id', 'segmentation', 'area', 'id', 'iscrowd']
        gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
        dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
    # otherwise catIds is not passed, so annotations of all categories are used
    else:
        gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
        dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))

    # convert ground truth to mask if iouType == 'segm'
    if p.iouType == 'segm':
        _toMask(gts, self.cocoGt)
        _toMask(dts, self.cocoDt)
    # set ignore flag
    # (the second assignment overwrites the first, so for bbox/segm the
    # ignore flag ends up equal to iscrowd)
    for gt in gts:
        gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
        gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
        if p.iouType == 'keypoints':
            gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
    # defaultdict takes a factory function: a missing key does not raise but
    # yields the factory's default value, here an empty list (e.g. when an
    # image has no gt or dt); since each list starts out empty, append can be
    # used directly; the key is (gt['image_id'], gt['category_id'])
    self._gts = defaultdict(list) # gt for evaluation
    self._dts = defaultdict(list) # dt for evaluation
    for gt in gts:
        self._gts[gt['image_id'], gt['category_id']].append(gt)
    for dt in dts:
        self._dts[dt['image_id'], dt['category_id']].append(dt)
    # likewise keyed by (image id, category id) tuples
    self.evalImgs = defaultdict(list) # per-image per-category evaluation results
    # this variable holds the final results
    self.eval = {} # accumulated evaluation results
</code></pre>
<h3 id="computeiou">computeIoU</h3>
<p>Computes the pairwise IoU between all boxes of a single category in a single image. Detections are sorted by score and only the top <code class="language-plaintext highlighter-rouge">maxDets[-1]</code> are kept, which caps the number of detections per category; the function finally returns a numpy array of shape (#dt, #gt).</p>
<pre><code class="language-Python">def computeIoU(self, imgId, catId):
    p = self.params
    if p.useCats:
        gt = self._gts[imgId,catId]
        dt = self._dts[imgId,catId]
    else:
        gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
        dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
    if len(gt) == 0 and len(dt) ==0:
        return []
    inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
    dt = [dt[i] for i in inds]
    if len(dt) &gt; p.maxDets[-1]:
        dt=dt[0:p.maxDets[-1]]

    if p.iouType == 'segm':
        g = [g['segmentation'] for g in gt]
        d = [d['segmentation'] for d in dt]
    elif p.iouType == 'bbox':
        g = [g['bbox'] for g in gt]
        d = [d['bbox'] for d in dt]
    else:
        raise Exception('unknown iouType for iou computation')

    # compute iou between each dt and gt region
    # (for a crowd gt, maskUtils.iou divides by the dt area instead of the union)
    iscrowd = [int(o['iscrowd']) for o in gt]
    ious = maskUtils.iou(d,g,iscrowd)
    return ious
</code></pre>
<h3 id="evaluateimg">evaluateImg</h3>
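<p>To make the matching rule concrete before reading the real code, here is a minimal pure-Python sketch of the greedy dt-gt matching that <code class="language-plaintext highlighter-rouge">evaluateImg</code> performs on the IoU values that <code class="language-plaintext highlighter-rouge">computeIoU</code> provides. It is a simplification under stated assumptions, not pycocotools code: it handles a single IoU threshold and boxes in COCO <code class="language-plaintext highlighter-rouge">[x, y, w, h]</code> format, skips crowd gts and the area-range filtering of detections, and assumes non-zero annotation ids. <code class="language-plaintext highlighter-rouge">box_iou</code> and <code class="language-plaintext highlighter-rouge">greedy_match</code> are hypothetical helper names.</p>
<pre><code class="language-Python">def box_iou(a, b):
    """IoU of two boxes given in COCO [x, y, width, height] format."""
    ix = max(0.0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union else 0.0


def greedy_match(dts, gts, iou_thr=0.5):
    """Simplified sketch of the matching loop in COCOeval.evaluateImg.

    dts: dicts with 'id', 'bbox', 'score'; gts: dicts with 'id', 'bbox',
    'ignore' (0 or 1). Ids are assumed non-zero, because 0 means unmatched.
    Returns three dicts: dt id to matched gt id, gt id to matched dt id,
    and dt id to its ignore flag. Crowd gts and area ranges are NOT handled.
    """
    gts = sorted(gts, key=lambda g: g['ignore'])  # ignored gts last (stable)
    dts = sorted(dts, key=lambda d: -d['score'])  # highest score first
    gt_match = {g['id']: 0 for g in gts}
    dt_match, dt_ignore = {}, {}
    for d in dts:
        best_iou = min(iou_thr, 1 - 1e-10)  # a match needs IoU of at least this
        m = -1                              # index of the best gt so far
        for gi, g in enumerate(gts):
            if gt_match[g['id']]:
                continue  # gt already taken by a higher-scoring dt
            if m != -1 and not gts[m]['ignore'] and g['ignore']:
                break     # keep the regular gt; all remaining gts are ignored
            iou = box_iou(d['bbox'], g['bbox'])
            if iou >= best_iou:  # ties go to the later gt, as in pycocotools
                best_iou, m = iou, gi
        if m == -1:
            dt_match[d['id']], dt_ignore[d['id']] = 0, 0  # a false positive
            continue
        dt_match[d['id']] = gts[m]['id']
        dt_ignore[d['id']] = gts[m]['ignore']  # matching an ignored gt ignores the dt
        gt_match[gts[m]['id']] = d['id']
    return dt_match, gt_match, dt_ignore


gts = [{'id': 1, 'bbox': [0, 0, 10, 10], 'ignore': 0},
       {'id': 2, 'bbox': [20, 20, 10, 10], 'ignore': 0}]
dts = [{'id': 11, 'bbox': [0, 0, 10, 10], 'score': 0.9},
       {'id': 12, 'bbox': [1, 1, 10, 10], 'score': 0.8},  # near-duplicate of 11
       {'id': 13, 'bbox': [21, 21, 10, 10], 'score': 0.7}]
print(greedy_match(dts, gts, 0.5)[0])  # {11: 1, 12: 0, 13: 2}
</code></pre>
<p>In the example, detection 12 overlaps gt 1 well, but gt 1 has already been taken by the higher-scoring detection 11, so 12 stays unmatched. This is how duplicate boxes turn into false positives in COCO AP, and it is why the order of the dt loop (by descending score) matters.</p>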
<p>This function evaluates one category in one image. It matches detections to gt boxes under several conditions and records the matching information, which makes it one of the most important functions.</p> <pre><code class="language-Python">def evaluateImg(self, imgId, catId, aRng, maxDet):
    '''
    perform evaluation for single category and image
    :return: dict (single image results)
    '''
    p = self.params
    if p.useCats:
        gt = self._gts[imgId,catId]
        dt = self._dts[imgId,catId]
    else:
        gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
        dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
    if len(gt) == 0 and len(dt) ==0:
        return None

    for g in gt:
        # gts that do not take part in this round of matching get _ignore = 1:
        # when computing AP for an area range, only gts inside the range are matched against dts
        if g['ignore'] or (g['area']&lt;aRng[0] or g['area']&gt;aRng[1]):
            g['_ignore'] = 1
        else:
            g['_ignore'] = 0

    # sort dt highest score first, sort gt ignore last
    # ascending argsort returns indices, so the non-ignored gts come first
    gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
    gt = [gt[i] for i in gtind]
    # descending score order: high-scoring detections come first
    dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
    dt = [dt[i] for i in dtind[0:maxDet]]
    iscrowd = [int(o['iscrowd']) for o in gt]
    # load computed ious
    ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) &gt; 0 else self.ious[imgId, catId]

    # the key part: greedily match each dt (in score order) to the unmatched gt with the highest IoU,
    # storing for every dt the id of its matched gt and for every gt the id of its matched dt
    T = len(p.iouThrs)
    G = len(gt)
    D = len(dt)
    gtm  = np.zeros((T,G))
    dtm  = np.zeros((T,D))
    gtIg = np.array([g['_ignore'] for g in gt])
    dtIg = np.zeros((T,D))
    if not len(ious)==0:
        for tind, t in enumerate(p.iouThrs):
            for dind, d in enumerate(dt):
                # information about best match so far (m=-1 -&gt; unmatched)
                iou = min([t,1-1e-10])
                # no match yet, so m = -1
                m   = -1
                for gind, g in enumerate(gt):
                    # three filtering conditions
                    # if this gt already matched, and not a crowd, continue
                    # i.e. this gt is already taken, try the next one
                    if gtm[tind,gind]&gt;0 and not iscrowd[gind]:
                        continue
                    # if dt matched to reg gt, and on ignore gt, stop
                    # why break rather than continue? the gts were sorted above so that ignored ones come last:
                    # once the current gt is ignored, all remaining gts are ignored too, and a match to a
                    # regular gt is always preferred over one to an ignored gt, so the search can stop here
                    if m&gt;-1 and gtIg[m]==0 and gtIg[gind]==1:
                        break
                    # continue to next gt unless better match made
                    # i.e. look for the gt with the largest IoU to this dt
                    if ious[dind,gind] &lt; iou:
                        continue
                    # if match successful and best so far, store appropriately
                    # update the running best
                    iou=ious[dind,gind]
                    m=gind
                # if match made, store id of match for both dt and gt
                if m ==-1:
                    continue
                # on a match each side stores the other's id: dtm holds the matched gt id, gtm the matched dt id
                # so an unmatched dt either has IoU below the threshold with every gt, or its candidate gts were already taken
                # an unmatched dt that is not ignored is later counted as a false positive
                dtIg[tind,dind] = gtIg[m]
                dtm[tind,dind]  = gt[m]['id']
                gtm[tind,m]     = d['id']
    # set unmatched detections outside of area range to ignore
    # a marks the dts whose area lies outside the area range
    a = np.array([d['area']&lt;aRng[0] or d['area']&gt;aRng[1] for d in dt]).reshape((1, len(dt)))
    # dtIg: additionally ignore the unmatched dts (dtm==0) that fall outside the area range
    dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
    # store results for given image and category
    # the matching information of this category in this image, used by the later accumulation and analysis
    return {
            'image_id':     imgId,
            'category_id':  catId,
            'aRng':         aRng,
            'maxDet':       maxDet,
            'dtIds':        [d['id'] for d in dt],
            'gtIds':        [g['id'] for g in gt],
            'dtMatches':    dtm,
            'gtMatches':    gtm,
            'dtScores':     [d['score'] for d in dt],
            'gtIgnore':     gtIg,
            'dtIgnore':     dtIg,
        }
</code></pre> <h3 id="accumulate">accumulate</h3> <p>This function accumulates the intermediate per-image results produced by <code class="language-plaintext highlighter-rouge">evaluate</code> and computes the detailed evaluation metrics.</p> <pre><code class="language-Python">def accumulate(self, p = None):
    '''
    Accumulate per image evaluation results and store the result in self.eval
    :param p: input params for evaluation
    :return: None
    '''
    print('Accumulating evaluation results...')
    tic = time.time()
    if not self.evalImgs:
        print('Please run evaluate() first')
    # allows input customized parameters
    if p is None:
        p = self.params
    p.catIds = p.catIds if p.useCats == 1 else [-1]
    # initialise the outputs; every metric defaults to -1
    T           = len(p.iouThrs)
    R           = len(p.recThrs)
    K           = len(p.catIds) if p.useCats else 1
    A           = len(p.areaRng)
    M           = len(p.maxDets)
    # one entry per IoU threshold, recall threshold, category, area range and max-detection setting
    precision   = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
    recall      = -np.ones((T,K,A,M))
    scores      = -np.ones((T,R,K,A,M))

    # create dictionary for future indexing
    _pe = self._paramsEval
    catIds = _pe.catIds if _pe.useCats else [-1]
    setK = set(catIds)
    setA = set(map(tuple, _pe.areaRng))
    setM = set(_pe.maxDets)
    setI = set(_pe.imgIds)
    # get inds to evaluate
    k_list = [n for n, k in enumerate(p.catIds)  if k in setK]
    m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
    # indices of all area ranges
    a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
    # indices of all image ids
    i_list = [n for n, i in enumerate(p.imgIds)  if i in setI]
    I0 = len(_pe.imgIds)
    A0 = len(_pe.areaRng)
    # retrieve E at each category, area range, and max number of detections
    # loop over every combination and compute its metrics
    for k, k0 in enumerate(k_list):
        Nk = k0*A0*I0
        for a, a0 in enumerate(a_list):
            Na = a0*I0
            for m, maxDet in enumerate(m_list):
                # self.evalImgs is a flat list ordered by (category, area range, image), so index it with the offsets above
                # E usually has one entry per validation image
                E = [self.evalImgs[Nk + Na + i] for i in i_list]
                E = [e for e in E if not e is None]
                if len(E) == 0:
                    continue
                # gather the scores of all detections in the dataset, to be sorted in descending order; shape: (#all_dets_in_the_dataset,)
                dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])

                # different sorting method generates slightly different results.
                # mergesort is used to be consistent as Matlab implementation.
                inds = np.argsort(-dtScores, kind='mergesort')
                dtScoresSorted = dtScores[inds]

                # likewise, concatenate over the whole dataset; shape: (#iou_thres, #all_dets_in_the_dataset)
                dtm  = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
                dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet]  for e in E], axis=1)[:,inds]
                # likewise; shape: (#all_gts_in_the_dataset,)
                gtIg = np.concatenate([e['gtIgnore'] for e in E])
                # number of non-ignored gts; skip this setting if every gt is ignored
                npig = np.count_nonzero(gtIg==0 )
                if npig == 0:
                    continue
                # tp: matched and not ignored; fp: unmatched and not ignored
                tps = np.logical_and(               dtm,  np.logical_not(dtIg) )
                fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )

                # cumulative sums along each row: number of tp / fp among the top-n detections
                tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
                fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
                # compute the metrics for every IoU threshold
                for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
                    tp = np.array(tp)
                    fp = np.array(fp)
                    # number of detections over all images
                    nd = len(tp)
                    # recall and precision at every position of the ranked list
                    rc = tp / npig
                    pr = tp / (fp+tp+np.spacing(1))
                    # buffers for the interpolated results; their size is the number of recall thresholds,
                    # i.e. the 101 values 0:0.01:1
                    q  = np.zeros((R,))
                    ss = np.zeros((R,))

                    # recall
                    if nd:
                        recall[t,k,a,m] = rc[-1]
                    else:
                        recall[t,k,a,m] = 0

                    # numpy is slow without cython optimization for accessing elements
                    # use python array gets significant speed improvement
                    pr = pr.tolist(); q = q.tolist()

                    # make precision monotonically non-increasing: pr[i] becomes the maximum precision
                    # at any recall &gt;= rc[i] (the envelope of the PR curve)
                    for i in range(nd-1, 0, -1):
                        if pr[i] &gt; pr[i-1]:
                            pr[i-1] = pr[i]

                    # AP is then obtained by sampling this curve at the 101 recall thresholds 0:0.01:1;
                    # each sample stands for a strip of width about 1/100, so the mean precision over
                    # the thresholds approximates the area under the PR curve
                    # for every recall threshold, find the first index i with rc[i] &gt;= threshold
                    inds = np.searchsorted(rc, p.recThrs, side='left')
                    try:
                        # q stores the interpolated precision, ss the corresponding score
                        for ri, pi in enumerate(inds):
                            q[ri] = pr[pi]
                            ss[ri] = dtScoresSorted[pi]
                    # thresholds that are never reached (pi out of range) keep the default 0
                    except:
                        pass
                    # save the results, shape (T,R,K,A,M):
                    # IoU thresholds, recall thresholds, categories, area ranges, max-detection settings
                    precision[t,:,k,a,m] = np.array(q)
                    scores[t,:,k,a,m] = np.array(ss)
    self.eval = {
        'params': p,
        'counts': [T, R, K, A, M],
        'date':
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'precision': precision,
        'recall':   recall,
        'scores': scores,
    }
    toc = time.time()
    print('DONE (t={:0.2f}s).'.format( toc-tic))
</code></pre> <h3 id="summarize">summarize</h3> <p>This function summarizes the metrics computed in the previous step and prints them in a fixed format.</p> <pre><code class="language-Python">def summarize(self):
    '''
    Compute and display summary metrics for evaluation results.
    Note this function can *only* be applied on the default parameter setting
    '''
    def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
        p = self.params
        iStr = ' {:&lt;18} {} @[ IoU={:&lt;9} | area={:&gt;6s} | maxDets={:&gt;3d} ] = {:0.3f}'
        titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
        typeStr = '(AP)' if ap==1 else '(AR)'
        iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
            if iouThr is None else '{:0.2f}'.format(iouThr)

        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
        if ap == 1:
            # dimension of precision: [TxRxKxAxM]
            s = self.eval['precision']
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:,:,:,aind,mind]
        else:
            # dimension of recall: [TxKxAxM]
            s = self.eval['recall']
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:,:,aind,mind]
        if len(s[s&gt;-1])==0:
            mean_s = -1
        else:
            # the final AP is the mean precision over all recall thresholds (and over the selected IoU thresholds
            # and categories), skipping the entries still at -1
            mean_s = np.mean(s[s&gt;-1])
        print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
        return mean_s
    def _summarizeDets():
        # the first six entries are AP
        stats = np.zeros((12,))
        stats[0] = _summarize(1)
        stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
        stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
        stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
        stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
        stats[5] = _summarize(1, areaRng='large',
                              maxDets=self.params.maxDets[2])
        # the last six entries are AR (average recall)
        stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
        stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
        stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
        stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
        stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
        stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
        return stats
    def _summarizeKps():
        stats = np.zeros((10,))
        stats[0] = _summarize(1, maxDets=20)
        stats[1] = _summarize(1, maxDets=20, iouThr=.5)
        stats[2] = _summarize(1, maxDets=20, iouThr=.75)
        stats[3] = _summarize(1, maxDets=20, areaRng='medium')
        stats[4] = _summarize(1, maxDets=20, areaRng='large')
        stats[5] = _summarize(0, maxDets=20)
        stats[6] = _summarize(0, maxDets=20, iouThr=.5)
        stats[7] = _summarize(0, maxDets=20, iouThr=.75)
        stats[8] = _summarize(0, maxDets=20, areaRng='medium')
        stats[9] = _summarize(0, maxDets=20, areaRng='large')
        return stats
    if not self.eval:
        raise Exception('Please run accumulate() first')
    iouType = self.params.iouType
    if iouType == 'segm' or iouType == 'bbox':
        summarize = _summarizeDets
    elif iouType == 'keypoints':
        summarize = _summarizeKps
    self.stats = summarize()

def __str__(self):
    self.summarize()
</code></pre> <h2 id="class-params">class Params</h2> <p>This is the default parameter structure of COCOeval. Depending on the <code class="language-plaintext highlighter-rouge">iouType</code> argument passed in, the constructor calls the matching initialisation method.</p> <pre><code class="language-Python">class Params:
    '''
    Params for coco evaluation api
    '''
    def setDetParams(self):
        self.imgIds = []
        self.catIds = []
        # np.arange causes trouble.  the data point on arange is slightly larger than the true value
        self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
        # maximum number of detections per image; AR is reported at 1, 10 and 100, AP at 100
        self.maxDets = [1, 10, 100]
        # area ranges for all, small, medium and large objects
        self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
        self.areaRngLbl = ['all', 'small', 'medium', 'large']
        # whether category labels are used; set it to 0 for class-agnostic evaluation
        self.useCats = 1

    def setKpParams(self):
        self.imgIds = []
        self.catIds = []
        # np.arange causes trouble.  the data point on arange is slightly larger than the true value
        self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
        self.maxDets = [20]
        self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
        self.areaRngLbl = ['all', 'medium', 'large']
        self.useCats = 1
        self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0

    def __init__(self, iouType='segm'):
        if iouType == 'segm' or iouType == 'bbox':
            self.setDetParams()
        elif iouType == 'keypoints':
            self.setKpParams()
        else:
            raise Exception('iouType not supported')
        self.iouType = iouType
        # useSegm is deprecated
        self.useSegm = None
</code></pre> <h2 id="小结-1">Summary</h2> <p>That is how COCOeval evaluates detection results: three calls (<code class="language-plaintext highlighter-rouge">evaluate</code>, <code class="language-plaintext highlighter-rouge">accumulate</code>, <code class="language-plaintext highlighter-rouge">summarize</code>) produce the final result. Once the process above is fully understood, we can evaluate results however we want by setting the parameters (IoU thresholds, categories, area ranges, maximum number of detections and so on) before evaluation. The results mainly live in <code class="language-plaintext highlighter-rouge">stats</code> and <code class="language-plaintext highlighter-rouge">precision</code>: the former is an array of length 12 (10 for keypoints) whose entries have fixed meanings, and the latter is a 5-D array of shape (T, R, K, A, M) that we can index to get the results along any dimension we want.</p> <h1 id="自定义评估方式">Custom evaluation</h1> <p>For example, mmdet's COCO evaluation has a <code class="language-plaintext highlighter-rouge">classwise</code> option. When it is enabled, the AP of each category is shown, but only the default mAP (averaged over IoU 0.50:0.95). If we also want AP50 and the other variants, we can modify this part of the code. Below is my modified version, which computes all the AP values for each category.</p> <pre><code class="language-Python">if classwise:
    # Compute per-category AP
    # from https://github.com/facebookresearch/detectron2/
    precisions = cocoEval.eval['precision']
    # precision: (iou, recall, cls, area range, max dets)
    assert len(self.cat_ids) == precisions.shape[2]

    results_per_category = []
    for idx, catId in enumerate(self.cat_ids):
        t = []
        # area range index 0: all area ranges
        # max dets index -1: typically 100 per image
        nm = self.coco.loadCats(catId)[0]
        precision = precisions[:, :, idx, 0, -1]
        precision = precision[precision &gt; -1]
        if precision.size:
            ap = np.mean(precision)
        else:
            ap = float('nan')
        t.append(f'{nm["name"]}')
        t.append(f'{float(ap):0.3f}')

        # IoU threshold index 0 -&gt; IoU=0.50, index 5 -&gt; IoU=0.75
        for iou in [0, 5]:
            precision = precisions[iou, :, idx, 0, -1]
            precision = precision[precision &gt; -1]
            if precision.size:
                ap = np.mean(precision)
            else:
                ap = float('nan')
            t.append(f'{float(ap):0.3f}')

        # area range index 1/2/3 -&gt; small/medium/large
        for area in [1, 2, 3]:
            precision = precisions[:, :, idx, area, -1]
            precision = precision[precision &gt; -1]
            if precision.size:
                ap = np.mean(precision)
            else:
                ap = float('nan')
            t.append(f'{float(ap):0.3f}')
        results_per_category.append(tuple(t))

    num_columns = len(results_per_category[0])
    results_flatten = list(
        itertools.chain(*results_per_category))
    headers = ['category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l']
    results_2d = itertools.zip_longest(*[
        results_flatten[i::num_columns]
        for i in range(num_columns)
    ])
    table_data = [headers]
    table_data += [result for result in results_2d]
    table = AsciiTable(table_data)
    print_log('\n' + table.table, logger=logger)
</code></pre> <p>For comparison, this is the default table, which only shows the mAP of each category.</p> <p><img src="https://user-images.githubusercontent.com/33142987/228175045-233bdf3e-4449-40b3-988e-456895edff6f.png" alt="img" /></p> <p>And this is my modified version, which shows all the AP values for every category.</p> <p><img src="https://user-images.githubusercontent.com/33142987/228175111-b0676f02-30bc-4358-9ef7-fdcbcbd0adcf.png" alt="img" /></p> Mon, 20 Mar 2023 00:00:00 +0000 https://yarkable.github.io/2023/03/20/MMDetection-&-pycocotools-eval-%E8%AF%A6%E8%A7%A3/
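https://yarkable.github.io/2023/03/20/MMDetection-&-pycocotools-eval-%E8%AF%A6%E8%A7%A3/ linux object detection deep learning mmdetection Using Proxychain4 as a network proxy <h2 id="背景">Background</h2> <p>A personal campus account can only be authenticated for network access on one device at a time, but our GPU servers often need Internet access, and authenticating on a server kicks our own computer offline. The fix is a proxy: the server goes online through our own device.</p> <blockquote> <p>This post is adapted from a PDF tutorial written by a junior labmate, kept here for my own reference.</p> </blockquote> <h2 id="安装软件">Installation</h2> <ol> <li>Network requests are forwarded with proxychains-ng. You can get it with git or download the source archive directly.</li> </ol> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>git clone https://github.com/rofl0r/proxychains-ng
</code></pre></div></div> <ol start="2"> <li>Enter the source directory and run <code class="language-plaintext highlighter-rouge">pwd</code> to see its absolute path; it is needed in the next step.</li> <li>In that directory, run the command below. <code class="language-plaintext highlighter-rouge">$(pwd)</code> expands to the absolute path printed in the previous step (you can also paste that path in directly). <strong>The path must be absolute</strong>, otherwise the build will go wrong later.</li> </ol> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>./configure <span class="nt">--prefix</span><span class="o">=</span><span class="si">$(</span><span class="nb">pwd</span><span class="si">)</span> <span class="nt">--sysconfdir</span><span class="o">=</span><span class="si">$(</span><span class="nb">pwd</span><span class="si">)</span>
</code></pre></div></div> <ol start="4"> <li>Build and install the binaries (<code class="language-plaintext highlighter-rouge">make install-config</code> generates the configuration file proxychains.conf).</li> </ol> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>make <span class="nt">-j</span>
make <span class="nb">install
</span>make install-config
</code></pre></div></div> <h2 id="配置">Configuration</h2> <p>Go to the install directory, open the configuration file proxychains.conf, and add the IP and port of the device that will act as the proxy at the bottom. I use Clash, which provides a SOCKS proxy, so my entry is:</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>socks5 172.31.xx.xx 7879
</code></pre></div></div> <p>The proxy software must also be running on our own device for the server to get online: enable <code class="language-plaintext highlighter-rouge">Allow LAN</code> in Clash, or <code class="language-plaintext highlighter-rouge">允许局域网的连接</code> (allow LAN connections) in v2ray. After that, the server can reach everything our device can.</p> <p>We also need to add the binary directory to the bash configuration file, otherwise the command will not be found (skip this step if an administrator installed it system-wide):</p> <div class="language-bash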
highlighter-rouge"><div class="highlight"><pre class="highlight"><code>vim ~/.bashrc
<span class="nb">export </span><span class="nv">PATH</span><span class="o">=</span>/data/xxx/proxychains/bin:<span class="nv">$PATH</span>
<span class="nb">export </span><span class="nv">PROXYCHAINS_CONF_FILE</span><span class="o">=</span>/data/xxx/proxychains/proxychains.conf
</code></pre></div></div> <p>Open a new terminal afterwards for the change to take effect. Running <code class="language-plaintext highlighter-rouge">source ~/.bashrc</code> did not work when I tried it, so I recommend simply opening a new terminal.</p> <h2 id="使用">Usage</h2> <p>To proxy a command, just prefix it with <code class="language-plaintext highlighter-rouge">proxychains4</code>, for example:</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>proxychains4 curl cip.cc
proxychains4 python main.py
</code></pre></div></div> <h2 id="troubleshoot">troubleshoot</h2> <p>If you get an error saying proxychains.conf cannot be found, it is almost always because a relative path instead of an absolute one was given at configure time. Remove the generated files with <code class="language-plaintext highlighter-rouge">make uninstall</code> and <code class="language-plaintext highlighter-rouge">make clean</code>, then rerun the installation steps above, making sure to use absolute paths.</p> Sun, 19 Mar 2023 00:00:00 +0000 https://yarkable.github.io/2023/03/19/%E4%BD%BF%E7%94%A8Proxychain4%E8%BF%9B%E8%A1%8C%E7%BD%91%E7%BB%9C%E4%BB%A3%E7%90%86/ https://yarkable.github.io/2023/03/19/%E4%BD%BF%E7%94%A8Proxychain4%E8%BF%9B%E8%A1%8C%E7%BD%91%E7%BB%9C%E4%BB%A3%E7%90%86/ ssh linux tools MacFUSE + sshfs: managing remote files on a Mac <h2 id="背景">Background</h2> <p>I do more and more development on my MacBook and often need to look at images on a remote server, so mounting the remote directory is the most convenient approach. Windows has an excellent GUI for this, sshfs manager, but macOS has no comparable alternative (the ones that exist are too old to work), so the mount has to be done by hand on the command line. I record the steps here.</p> <h2 id="安装软件">Installation</h2> <p>Mounting a remote folder requires two tools, MacFUSE and sshfs. Both can be downloaded from the <a href="https://osxfuse.github.io/">osxfuse website</a> and installed with the default options. Install MacFUSE first (its icon then appears at the bottom of the System Settings window), restart the computer, and then install sshfs.</p> <h2 id="映射文件夹">Mounting a folder</h2> <p>It works the same way as on Windows, but add the <code class="language-plaintext highlighter-rouge">-ovolname</code> option to give the volume a recognizable name. Otherwise the default name is <code class="language-plaintext highlighter-rouge">macFUSE Volume 0
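(sshfs)</code>, which is hard to remember, and with several folders mounted it is hard to tell which volume belongs to which folder.</p> <div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>sshfs &lt;user&gt;@&lt;server&gt;:&lt;absolute path on the server&gt; &lt;local mount point&gt; <span class="nt">-ovolname</span><span class="o">=</span>&lt;volume name&gt;
</code></pre></div></div> <p>To unmount, right-click the volume in Finder and choose <code class="language-plaintext highlighter-rouge">Eject "xxx"</code>, or run <code class="language-plaintext highlighter-rouge">umount &lt;folder&gt;</code> in a terminal. If you mount many folders, a small script makes them easier to manage.</p> <h2 id="reference">reference</h2> <p>https://xmanyou.com/mac-mount-remote-folder/</p> Sun, 12 Mar 2023 00:00:00 +0000 https://yarkable.github.io/2023/03/12/MacFUSE+sshfs%E8%AE%A9Mac%E7%AE%A1%E7%90%86%E8%BF%9C%E7%A8%8B%E6%96%87%E4%BB%B6/ https://yarkable.github.io/2023/03/12/MacFUSE+sshfs%E8%AE%A9Mac%E7%AE%A1%E7%90%86%E8%BF%9C%E7%A8%8B%E6%96%87%E4%BB%B6/ ssh linux macOS tools Bangbangji's 2022 in review <p>– Written in early 2023</p> <p>2022 was an unforgettable year for many people in China. In the third year of a pandemic that kept flaring up again and again, the country finally reversed its earlier policy of strict containment and reopened, and everyone's health was back in their own hands. Against this backdrop, each of us was going through a Long March of our own.</p> <p>The lunar year 2022 was my zodiac year. For me it covered the second half of my second year and the first half of my third year of the master's program. Based on past experience, students of my cohort had to start looking for internships around the 2022 Lunar New Year, because a good internship at a big tech company counts for a lot in autumn recruiting. So, like everyone else, I spent the holiday grinding LeetCode problems on and off, topic by topic. I had already started some of this in the first half of my second year, so I was a bit ahead of my classmates. Later, to support and encourage each other, I pulled everyone in our lab's cohort who was job hunting into a group chat, where we shared recruiting news, solutions to algorithm problems, a daily problem and so on.</p> <p>For those of us applying for R&amp;D positions, LeetCode problems are mandatory: they come up in both written tests and interviews, even though that knowledge is rarely used on the job. Besides problem solving there are the interviews themselves, which require a résumé and notes on past interview questions. As I recall, I finished the first version of my résumé in mid-March and started sending it everywhere. Writing a résumé is a skill in itself: how precisely you highlight and present your projects and papers decides whether the interviewer gives you a chance. I think I did fairly well here, using an older labmate's template as a reference, so apart from companies that strictly filter by school, I usually got an interview. Every interview adds to your experience. I met many technically strong interviewers who pinpointed problems immediately and asked exactly the right questions. After every interview I also summarized the lessons and questions into notes to prepare for the next one.</p> <p>And so, in the weeks after the New Year, I spent almost all my time preparing for internship applications. I was supposed to go back to school, but Shenzhen was caught off guard by COVID brought in by cross-border truck drivers from Hong Kong, and a large number of people were infected. The whole city had to slow down: companies switched to working from home and the university would not let students return, so I had to prepare at home. Meanwhile, the result came out for a CCF-B paper I had written in collaboration with Qualcomm: rejected. That was a real nuisance, since it meant one fewer B-class paper on my résumé, and I had to discuss resubmitting to another conference with the Qualcomm side. In the end we made some light revisions and submitted it to IJCB2023; I just wanted to wrap that work up quickly, since my main focus was still the internship search. Also, after the Lantern Festival my mom went back to work in the county town, leaving just my grandma and me at home. Grandma's cooking wasn't great, so at first my mom came back in the evenings to cook for me, and later she bought several days' worth of groceries and I cooked for myself. So I was basically taking care of myself every day while keeping up with school work. Life at home went on like that; occasionally I went to the library with my girlfriend to study, or we went out for a walk. It was simple but quite fulfilling.</p>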
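<p>Around April my interviews were mostly done. I had only applied to mid-size and large companies in Shenzhen, or to companies based elsewhere that allowed remote work, and I received several offers. The one I liked most was from ByteDance: the office was right next to my university, the internship paid 400 RMB a day, and the work suited me well, so I was very happy when the offer letter arrived, and I turned down the offer from Shanghai AI Lab. But then came the hard part. ByteDance wanted me to start as soon as possible, but because there was COVID in my hometown the university would not let me come back, and even if I did go back I would be under closed management and unable to leave campus for the internship (working remotely was possible, but doing so without ever having met the team did not seem like a good idea). Fortunately, an older labmate worked in the same department, in a different team, and after asking him about the situation inside I decided I must not give up this offer. So I tried to get back to Shenzhen as fast as I could, and after some twists and turns and a bit of luck, on May 9 I arrived at the Shenzhen Bay Science and Technology Innovation Center, ByteDance's Shenzhen base, for my first day of work. By then almost half a year had passed since the start of my winter break; I had never had such a long holiday. At night I stayed in an apartment nearby, renting a bed for 75 RMB a night in a room with six beds, which was usually not full. The environment and facilities were quite good, and as long as you didn't end up with weird roommates it was comfortable. The subway ride to the office was only two stops. At first I usually left at ten in the morning and reached the office before 10:30, then left work at ten in the evening. Since I wasn't yet familiar with the company's GPU platform, I often did a bit of extra work after getting back.</p> <p>After about two weeks in the apartment, my hometown was lifted from the COVID control zone and I could move back to campus, which was even closer to the office: 3 minutes by shared bike if you rode fast, so I could sleep a little later in the morning. In the evening most people left around ten, because a taxi at that hour could be expensed. ByteDance was the first real internship of my life, and at a big tech company too, and I learned a lot there. The team assigned me a great, very responsible mentor. When I joined, the team leader was in Beijing, so the Shenzhen base never had a leader on site. I worked on a project with a colleague, jt, and on a research task with my mentor. I have to say the people I worked with were all very nice and took good care of me, so working at ByteDance with people guiding me was a good experience. Later, for various reasons, my mentor left the company and I got a new mentor who continued the research task with me. Even after leaving, my former mentor still held a weekly meeting with me to go over progress and discuss, which is truly rare, and I am very grateful.</p> <p>There was no fixed end date for my internship at ByteDance; I took it one step at a time, working while preparing for autumn recruiting, pursuing outside opportunities while also trying for a return offer. Campus recruiting for the class of 2023 was indeed the hardest ever: many companies cut hiring or even laid people off, while our cohort of graduate students had been enlarged, so there were far too many candidates for too few jobs. Our lab started preparing for written tests and interviews very early, but because our university is not a 985/211 school, some companies still would not interview us. I was relatively lucky: my first offer came from DJI. At the time my signature on Nowcoder was “DJI Dream Job”, so this was literally a dream come true. With one offer in hand the pressure was much lower, and I took the remaining interviews in a relaxed way. The feedback on my ByteDance conversion review was fairly good, but later the situation there did not look promising: there were rumors of layoffs, constant talk of cutting costs and improving efficiency, and the offer kept not coming, so I decided not to wait. Once DJI announced its offer package, my autumn recruiting officially ended. My first offer turned out to be my best offer, and I will be heading to Sky City!</p> <p>After the ByteDance review, my mentor and I turned the research work into a paper and submitted it to CVPR. I went to bed late every day during that period, and on the university side my advisor and jh were helping revise the paper. A small side story: I had kept the internship secret from my advisor the whole time, but once he was revising the paper there was no hiding it, so I found a good moment and simply came clean. He didn't say anything, which is funny considering he used to disapprove of us doing internships, haha. After the paper I didn't have much left to do. I couldn't stay for several more months, so the leader didn't give me any core project and instead asked me to apply my research to the business datasets. After a while I found it rather dull and decided to resign, go back to write my thesis &amp; enjoy student life. Before leaving I prepared a gift for everyone. I had wanted to treat the colleagues I had worked with to a meal, but after the reopening lots of people caught COVID, so we dropped the idea, and I said a simple goodbye to the place where I had spent more than half a year.</p> <blockquote> <p>When I handed in my resignation, the team lead had a talk with me, acknowledged my contributions, and gave me some very sound advice on career planning. I am grateful that a senior was willing to share this with a newcomer, so I note it down here:</p> <ol> <li>Don't see your leader as a superior; treat them as a colleague or a friend. Don't be afraid to talk, and proactively report your progress.</li> <li>Pursue depth in technology, don't be lazy, and strive for excellence; that is how you go further.</li> <li>Being able to present your work is as important as doing it. Learn to show what you have done so that everyone sees your contribution.</li>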
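</ol> </blockquote> <p>Around then my IJCB submission was rejected again, so I resubmitted it to MMM2023, and luckily it was accepted. Together with the work at ByteDance, that made two pieces of work, enough for my thesis. But just as I was about to start writing in earnest, I caught COVID too. The symptoms were mild, but for the first few days I felt dizzy and listless, so I played PUBG with my middle-school friends to relax. The university had started sending students home, and only some were approved to stay, so there were few people on campus. My girlfriend jennie spent every day with me, eating, studying and watching TV while she prepared for the teacher certification exam. During this time she also landed a job as a primary school teacher in Shenzhen, ending her rather long autumn job search.</p> <p>The rest is just day-to-day stuff. On January 10 we went back to our hometown. As in previous years, the first day back home always feels like a letdown, this year especially; I'd rather not dwell on the details, but I know that when I leave I will miss the time at home again! In 2022 I rushed to Shenzhen for the ByteDance internship, and after it ended I came back to my hometown; time really flies! All in all, this year felt a lot like my third year of undergrad, when I joined RobotPilots, which changed the course of my university life. I hope to remain a kind and upright person, whatever I achieve. Oh, and my zodiac year is over now, so the days ahead should be luckier.</p> Tue, 24 Jan 2023 00:00:00 +0000 https://yarkable.github.io/2023/01/24/%E6%A3%92%E6%A3%92%E9%B8%A1%E7%9A%842022%E5%B9%B4%E5%B0%8F%E7%BB%93/ https://yarkable.github.io/2023/01/24/%E6%A3%92%E6%A3%92%E9%B8%A1%E7%9A%842022%E5%B9%B4%E5%B0%8F%E7%BB%93/ say something Compiling LaTeX locally on a MacBook with VSCode and MacTeX <h3 id="preface">preface</h3> <p>Writing LaTeX online in Overleaf is too cumbersome: every recompile means a long wait, while compiling locally is fast. On Windows I had tried compiling locally with a VSCode extension plus TeX Live; now I mainly use a MacBook, so these are my notes on compiling locally on a MacBook with a VSCode extension and a LaTeX distribution.</p> <h3 id="需要的东西">What you need</h3> <h4 id="vscode-插件-latex-workshop">VSCode extension: LaTeX Workshop</h4> <p>Just search for it in the extension marketplace and install it.</p> <h4 id="latex-编译器">LaTeX compiler</h4> <p>The usual choice is MacTeX. There are two ways to install it: <a href="https://tug.org/mactex/">download the pkg from the official website</a>, or install it with brew.</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>brew <span class="nb">install</span> <span class="nt">--cask</span> mactex-no-gui
</code></pre></div></div> <p>After installing, add the executables to $PATH, otherwise they will not be found (with zsh, the default shell on recent macOS, put the export line in ~/.zshrc instead):</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>vim ~/.bash_profile
<span class="nb">export </span><span class="nv">PATH</span><span class="o">=</span>/Library/TeX/texbin:<span class="nv">$PATH</span>
<span class="nb">source</span> ~/.bash_profile
</code></pre></div></div> <p>That's it for this part; next, configure VSCode.</p> <h3 id="vscode-配置">VSCode configuration</h3> <p>Add the following to your settings JSON (open the command palette with command + shift + p and run "Preferences: Open User Settings (JSON)"):</p> <div class="language-json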
highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nl">"latex-workshop.latex.tools"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"latexmk"</span><span class="p">,</span><span class="w"> </span><span class="nl">"command"</span><span class="p">:</span><span class="w"> </span><span class="s2">"latexmk"</span><span class="p">,</span><span class="w"> </span><span class="nl">"args"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="s2">"-synctex=1"</span><span class="p">,</span><span class="w"> </span><span class="s2">"-interaction=nonstopmode"</span><span class="p">,</span><span class="w"> </span><span class="s2">"-file-line-error"</span><span class="p">,</span><span class="w"> </span><span class="s2">"-pdf"</span><span class="p">,</span><span class="w"> </span><span class="s2">"%DOC%"</span><span class="w"> </span><span class="p">]</span><span class="w"> </span><span class="p">},</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"cd"</span><span class="p">,</span><span class="w"> </span><span class="nl">"command"</span><span class="p">:</span><span class="w"> </span><span class="s2">"cd"</span><span class="p">,</span><span class="w"> </span><span class="nl">"args"</span><span class="w"> </span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="s2">"%DIR%"</span><span class="p">]</span><span class="w"> </span><span class="p">},</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"pdflatex"</span><span 
class="p">,</span><span class="w"> </span><span class="nl">"command"</span><span class="p">:</span><span class="w"> </span><span class="s2">"pdflatex"</span><span class="p">,</span><span class="w"> </span><span class="nl">"args"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="s2">"-synctex=1"</span><span class="p">,</span><span class="w"> </span><span class="s2">"-interaction=nonstopmode"</span><span class="p">,</span><span class="w"> </span><span class="s2">"-file-line-error"</span><span class="p">,</span><span class="w"> </span><span class="s2">"%DOC%"</span><span class="w"> </span><span class="p">]</span><span class="w"> </span><span class="p">},</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"bibtex"</span><span class="p">,</span><span class="w"> </span><span class="nl">"command"</span><span class="p">:</span><span class="w"> </span><span class="s2">"bibtex"</span><span class="p">,</span><span class="w"> </span><span class="nl">"args"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="s2">"%DOCFILE%"</span><span class="w"> </span><span class="p">]</span><span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="p">]</span><span class="err">,</span><span class="w"> </span><span class="nl">"latex-workshop.latex.recipes"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"latexmk"</span><span class="p">,</span><span class="w"> </span><span class="nl">"tools"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="s2">"latexmk"</span><span 
class="w"> </span><span class="p">]</span><span class="w"> </span><span class="p">},</span><span class="w"> </span><span class="p">{</span><span class="w"> </span><span class="nl">"name"</span><span class="p">:</span><span class="w"> </span><span class="s2">"pdflatex -&gt; bibtex -&gt; pdflatex*2"</span><span class="p">,</span><span class="w"> </span><span class="nl">"tools"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w"> </span><span class="s2">"cd"</span><span class="p">,</span><span class="w"> </span><span class="s2">"pdflatex"</span><span class="p">,</span><span class="w"> </span><span class="s2">"bibtex"</span><span class="p">,</span><span class="w"> </span><span class="s2">"pdflatex"</span><span class="p">,</span><span class="w"> </span><span class="s2">"pdflatex"</span><span class="w"> </span><span class="p">]</span><span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="p">]</span><span class="err">,</span><span class="w"> </span></code></pre></div></div> <p>Done. Press command + option + b to build and command + option + v to view the PDF.</p> <h3 id="遇到的坑">Pitfalls</h3> <p>I keep the whole project in iCloud so that it syncs, but that produces the following error:</p> <pre><code class="language-txt">Latexmk: Filename '/Users/bytedance/Library/Mobile Documents/com~apple~CloudDocs/my_work/CrossDataset/cvpr2023-author_kit-v1_1-1/latex/PaperForReview' contains character not allowed for TeX file.
Latexmk: Stopping because of bad filename(s).
Rc files read:
  NONE
Latexmk: This is Latexmk, John Collins, 17 Mar. 2022. Version 4.77, version: 4.77.
</code></pre> <p>According to an issue on GitHub, the cause is the path Apple uses for the iCloud folder, which contains characters that TeX does not accept in file names (presumably the <code class="language-plaintext highlighter-rouge">~</code> in <code class="language-plaintext highlighter-rouge">com~apple~CloudDocs</code>). A workaround is to create a symbolic link to iCloud and open the project through it (you must open it via the path that starts at the link, not from inside the original subfolder, otherwise the error persists):</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">ln</span> <span class="nt">-s</span> /Users/xxx/Library/Mobile<span class="se">\ </span>Documents/com~apple~CloudDocs/ iCloud
code iCloud/my_work/xxx
</code></pre></div></div> <h3 id="其他">Other</h3> <p>I recommend also installing a Grammarly extension in VSCode; it makes writing noticeably easier.</p> <h3 id="reference">reference</h3> <p>https://github.com/James-Yu/LaTeX-Workshop/issues/234</p> <p>https://blog.csdn.net/qq_31460257/article/details/81592812</p> <p>https://zhuanlan.zhihu.com/p/102823687</p> Fri, 30 Sep 2022 00:00:00 +0000 https://yarkable.github.io/2022/09/30/Macbook%E9%85%8D%E5%90%88VSCode%E5%92%8CMacTeX%E6%9C%AC%E5%9C%B0%E7%BC%96%E8%AF%91LaTeX/ https://yarkable.github.io/2022/09/30/Macbook%E9%85%8D%E5%90%88VSCode%E5%92%8CMacTeX%E6%9C%AC%E5%9C%B0%E7%BC%96%E8%AF%91LaTeX/ troubleshoot Common Python interview questions <h2 id="1-list-set-dict-的查询效率">1.
Lookup efficiency of list, set and dict</h2> <blockquote> <p>Given one million (<code class="language-plaintext highlighter-rouge">1000000</code>) elements, which data structure gives the fastest check of whether an element is present?</p> </blockquote> <p>Answer: set and dict are both fast and roughly on par, while list is by far the slowest.</p> <p>In CPython a set is implemented as a hash table (which is also why it removes duplicates), not a red-black tree: a membership test hashes the element and probes the table, which is O(1) on average rather than O(log n).</p> <p>A dict is also a hash table keyed by the hash of the key. <code class="language-plaintext highlighter-rouge">key in d</code> hashes the key and probes the table directly, with no binary search involved, so it is likewise O(1) on average, whether you only test for the key or also fetch its value. Any timing difference between set and dict is therefore small.</p> <p>A list has no index at all, so <code class="language-plaintext highlighter-rouge">x in lst</code> is a linear scan, O(n).</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">time</span>

<span class="c1"># generate one million distinct random integers
</span><span class="n">nums</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">sample</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1000000000</span><span class="p">),</span> <span class="mi">1000000</span><span class="p">)</span>
<span class="c1"># print(nums)
</span>
<span class="n">my_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">my_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
<span class="n">my_dict</span> <span class="o">=</span> <span class="p">{}</span>

<span class="c1"># fill the three containers with the same data
</span><span class="n">my_list</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">nums</span><span class="p">)</span>
<span class="n">my_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">nums</span><span class="p">)</span>
<span class="n">my_dict</span> <span class="o">=</span> <span class="p">{</span> <span class="n">i</span> <span class="p">:</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">nums</span><span class="p">}</span>

<span class="c1"># each test runs 1000 queries: every list query scans up to one million elements,
# so with millions of queries the list test would practically never finish
</span>
<span class="c1"># set lookup
</span><span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
    <span class="n">flag</span> <span class="o">=</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">my_set</span>
<span class="k">print</span><span class="p">(</span><span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start_time</span><span class="p">)</span>

<span class="c1"># dict lookup
</span><span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
    <span class="n">flag</span> <span class="o">=</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">my_dict</span>
<span class="k">print</span><span class="p">(</span><span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start_time</span><span class="p">)</span>

<span class="c1"># list lookup
</span><span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
    <span class="n">flag</span> <span class="o">=</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">my_list</span>
<span class="k">print</span><span class="p">(</span><span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start_time</span><span class="p">)</span>
</code></pre></div></div> <blockquote> <p>reference: https://blog.csdn.net/weixin_48629601/article/details/107532754</p> </blockquote> Sat, 16 Jul 2022 00:00:00 +0000 https://yarkable.github.io/2022/07/16/python-%E5%B8%B8%E8%A7%81%E9%97%AE%E7%AD%94%E9%A2%98/ https://yarkable.github.io/2022/07/16/python-%E5%B8%B8%E8%A7%81%E9%97%AE%E7%AD%94%E9%A2%98/ python ByteTrack source code annotated <h2 id="preface">preface</h2> <p>I have recently been working with Multi Object Tracking. After reading the source code of the classic DeepSort, I found tracking quite interesting and challenging. ByteTrack is a relatively new tracker in multi-object tracking (ECCV 2022) and fairly simple, so here I annotate its source code, purely for my own future review.</p> <p>ByteTrack is a TBD (Tracking By Detection) method: detection runs on every frame, a Kalman filter predicts where each track should be in the current frame, all tracks are associated with the detection boxes to find the matching box for each track, and the matched detections are then used to correct the mean and covariance of each track's Kalman filter.</p> <p>The Kalman filter predicts where a track is likely to appear in the current frame. The code usually has two functions for it, <code class="language-plaintext highlighter-rouge">predict</code> and <code class="language-plaintext highlighter-rouge">update</code>: predict forecasts the position, and update corrects the Kalman parameters using the predicted position and the matched detection. The Tracker also has an update function, which plays the role of the detect function in object detection: calling it returns the tracks for each frame (tracked, dropped, and disappeared).</p> <p><a href="https://github.com/ifzhang/ByteTrack">ByteTrack</a> has only 4 main files, the most important being <code class="language-plaintext highlighter-rouge">byte_tracker.py</code>, which contains the Tracker logic and the member data of each tracklet. Some background first: a track, or tracklet, is a temporal sequence of boxes, namely the positions where one id appears in the frames over time. Also, ByteTrack itself needs no training; only the detector has to be trained on the dataset. A TBD tracker essentially applies some logic on top of the detection results.</p> <h2 id="basetrackpy">basetrack.py</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span
class="n">OrderedDict</span> <span class="c1"># tracker 的 4 种状态 </span><span class="k">class</span> <span class="nc">TrackState</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="n">New</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">Tracked</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">Lost</span> <span class="o">=</span> <span class="mi">2</span> <span class="n">Removed</span> <span class="o">=</span> <span class="mi">3</span> <span class="c1"># tracklet 的基类,拥有一段轨迹的各种属性,包括 id,当前出现的 frame_id 等等 </span><span class="k">class</span> <span class="nc">BaseTrack</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">track_id</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">is_activated</span> <span class="o">=</span> <span class="bp">False</span> <span class="n">state</span> <span class="o">=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">New</span> <span class="n">history</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span> <span class="n">features</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">curr_feature</span> <span class="o">=</span> <span class="bp">None</span> <span class="n">score</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">start_frame</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">frame_id</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">time_since_update</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> </code></pre></div></div> <h2 id="kalman_filterpy">kalman_filter.py</h2> <p>这个函数就是关于卡尔曼滤波器的一些函数,我们将物体的运动假设为匀速运动,运用卡尔曼滤波器对物体在下一帧图像中出现的位置进行一个预测。他包含 8 个状态量(x, y, a, h, vx, vy, va, vh),分别是 bbox 的中心点坐标、 
bbox 宽高比例、bbox 的高,以及对应的速度,这里只简单罗列一下,想了解更多的话建议去看这个<a href="https://zhuanlan.zhihu.com/p/90835266">知乎回答</a></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">class</span> <span class="nc">KalmanFilter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="s">""" A simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space x, y, a, h, vx, vy, va, vh contains the bounding box center position (x, y), aspect ratio a, height h, and their respective velocities. Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct observation of the state space (linear observation model). """</span> <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">ndim</span><span class="p">,</span> <span class="n">dt</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span> <span class="mf">1.</span> <span class="c1"># Create Kalman filter model matrices. 
</span> <span class="bp">self</span><span class="p">.</span><span class="n">_motion_mat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">):</span> <span class="bp">self</span><span class="p">.</span><span class="n">_motion_mat</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">ndim</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">dt</span> <span class="bp">self</span><span class="p">.</span><span class="n">_update_mat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">ndim</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">)</span> <span class="c1"># Motion and observation uncertainty are chosen relative to the current </span> <span class="c1"># state estimate. These weights control the amount of uncertainty in </span> <span class="c1"># the model. This is a bit hacky. 
</span> <span class="bp">self</span><span class="p">.</span><span class="n">_std_weight_position</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="mi">20</span> <span class="bp">self</span><span class="p">.</span><span class="n">_std_weight_velocity</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="mi">160</span> <span class="k">def</span> <span class="nf">initiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">measurement</span><span class="p">):</span> <span class="s">"""Create track from unassociated measurement. Parameters ---------- measurement : ndarray Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, and height h. Returns ------- (ndarray, ndarray) Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean. """</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span><span class="p">):</span> <span class="s">"""Run Kalman filter prediction step. Parameters ---------- mean : ndarray The 8 dimensional mean vector of the object state at the previous time step. covariance : ndarray The 8x8 dimensional covariance matrix of the object state at the previous time step. Returns ------- (ndarray, ndarray) Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are initialized to 0 mean. 
"""</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span> <span class="k">def</span> <span class="nf">project</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span><span class="p">):</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span> <span class="o">+</span> <span class="n">innovation_cov</span> <span class="k">def</span> <span class="nf">multi_predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span><span class="p">):</span> <span class="s">"""Run Kalman filter prediction step (Vectorized version). Parameters ---------- mean : ndarray The Nx8 dimensional mean matrix of the object states at the previous time step. covariance : ndarray The Nx8x8 dimensional covariance matrics of the object states at the previous time step. Returns ------- (ndarray, ndarray) Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are initialized to 0 mean. """</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span><span class="p">,</span> <span class="n">measurement</span><span class="p">):</span> <span class="s">"""Run Kalman filter correction step. Parameters ---------- mean : ndarray The predicted state's mean vector (8 dimensional). covariance : ndarray The state's covariance matrix (8x8 dimensional). 
measurement : ndarray The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect ratio, and h the height of the bounding box. Returns ------- (ndarray, ndarray) Returns the measurement-corrected state distribution. """</span> <span class="k">return</span> <span class="n">new_mean</span><span class="p">,</span> <span class="n">new_covariance</span> <span class="k">def</span> <span class="nf">gating_distance</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">covariance</span><span class="p">,</span> <span class="n">measurements</span><span class="p">,</span> <span class="n">only_position</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">metric</span><span class="o">=</span><span class="s">'maha'</span><span class="p">):</span> <span class="p">...</span> </code></pre></div></div> <h2 id="byte_trackerpy">byte_tracker.py</h2> <p>The main per-frame logic (a classic pipeline, well worth memorizing):</p> <blockquote> <p>detector outputs bboxes → Kalman filter predicts each track → the Hungarian algorithm matches the predicted tracks with the detections in the current frame (IoU matching) → Kalman filter state update</p> </blockquote> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span> <span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">deque</span> <span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">os.path</span> <span class="k">as</span> <span class="n">osp</span> <span class="kn">import</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span> <span class="kn">from</span> <span class="nn">.kalman_filter</span> <span
class="kn">import</span> <span class="n">KalmanFilter</span> <span class="kn">from</span> <span class="nn">yolox.tracker</span> <span class="kn">import</span> <span class="n">matching</span> <span class="kn">from</span> <span class="nn">.basetrack</span> <span class="kn">import</span> <span class="n">BaseTrack</span><span class="p">,</span> <span class="n">TrackState</span> <span class="c1"># 继承 BaseTrack 的单个 track 类 </span><span class="k">class</span> <span class="nc">STrack</span><span class="p">(</span><span class="n">BaseTrack</span><span class="p">):</span> <span class="n">shared_kalman</span> <span class="o">=</span> <span class="n">KalmanFilter</span><span class="p">()</span> <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tlwh</span><span class="p">,</span> <span class="n">score</span><span class="p">):</span> <span class="c1"># wait activate </span> <span class="c1"># 初始化 track 全部都是 False 的状态 </span> <span class="c1"># 一般是第一次出现某个 track 的情景 </span> <span class="bp">self</span><span class="p">.</span><span class="n">_tlwh</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">tlwh</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="nb">float</span><span class="p">)</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span> <span class="o">=</span> <span class="bp">None</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span> <span class="bp">self</span><span class="p">.</span><span 
class="n">is_activated</span> <span class="o">=</span> <span class="bp">False</span> <span class="bp">self</span><span class="p">.</span><span class="n">score</span> <span class="o">=</span> <span class="n">score</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracklet_len</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># 预测这个 track 下一次的位置,其实就是调用自身卡尔曼的 predict 函数更新均值和方差 </span> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">mean_state</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">state</span> <span class="o">!=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">:</span> <span class="n">mean_state</span><span class="p">[</span><span class="mi">7</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">mean_state</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span><span class="p">)</span> <span class="o">@</span><span class="nb">staticmethod</span> <span class="c1"># 这个就是 predict 函数的矩阵版本,做的事情是一样的 </span> <span class="k">def</span> <span class="nf">multi_predict</span><span class="p">(</span><span class="n">stracks</span><span class="p">):</span> 
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">stracks</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="n">multi_mean</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">st</span><span class="p">.</span><span class="n">mean</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">st</span> <span class="ow">in</span> <span class="n">stracks</span><span class="p">])</span> <span class="n">multi_covariance</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">st</span><span class="p">.</span><span class="n">covariance</span> <span class="k">for</span> <span class="n">st</span> <span class="ow">in</span> <span class="n">stracks</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">st</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">stracks</span><span class="p">):</span> <span class="k">if</span> <span class="n">st</span><span class="p">.</span><span class="n">state</span> <span class="o">!=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">:</span> <span class="n">multi_mean</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">7</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">multi_mean</span><span class="p">,</span> <span class="n">multi_covariance</span> <span class="o">=</span> <span class="n">STrack</span><span class="p">.</span><span class="n">shared_kalman</span><span class="p">.</span><span 
class="n">multi_predict</span><span class="p">(</span><span class="n">multi_mean</span><span class="p">,</span> <span class="n">multi_covariance</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">cov</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">multi_mean</span><span class="p">,</span> <span class="n">multi_covariance</span><span class="p">)):</span> <span class="n">stracks</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span> <span class="n">stracks</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">covariance</span> <span class="o">=</span> <span class="n">cov</span> <span class="c1"># 新激活一个轨迹,用轨迹的初始框来初始化对应的卡尔曼滤波器的参数,并且记录下 track 的 id </span> <span class="c1"># 这个是新建一个 track 调用的函数,并且如果是视频刚开始的话,直接会将 track 的状态变成激活态 </span> <span class="c1"># 不是在视频刚开始激活的框的状态为未激活,需要下一帧还有检测框与其进行匹配才会变成激活状态 </span> <span class="k">def</span> <span class="nf">activate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kalman_filter</span><span class="p">,</span> <span class="n">frame_id</span><span class="p">):</span> <span class="s">"""Start a new tracklet"""</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span> <span class="o">=</span> <span class="n">kalman_filter</span> <span class="bp">self</span><span class="p">.</span><span class="n">track_id</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">next_id</span><span class="p">()</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span 
class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span><span class="p">.</span><span class="n">initiate</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">tlwh_to_xyah</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">_tlwh</span><span class="p">))</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracklet_len</span> <span class="o">=</span> <span class="mi">0</span> <span class="bp">self</span><span class="p">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span> <span class="k">if</span> <span class="n">frame_id</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="bp">self</span><span class="p">.</span><span class="n">is_activated</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># self.is_activated = True </span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">=</span> <span class="n">frame_id</span> <span class="bp">self</span><span class="p">.</span><span class="n">start_frame</span> <span class="o">=</span> <span class="n">frame_id</span> <span class="c1"># 这个应该是轨迹被遮挡或者消失之后重新激活轨迹调用的函数 </span> <span class="k">def</span> <span class="nf">re_activate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_track</span><span class="p">,</span> <span class="n">frame_id</span><span class="p">,</span> <span class="n">new_id</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span 
class="bp">self</span><span class="p">.</span><span class="n">covariance</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">tlwh_to_xyah</span><span class="p">(</span><span class="n">new_track</span><span class="p">.</span><span class="n">tlwh</span><span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracklet_len</span> <span class="o">=</span> <span class="mi">0</span> <span class="bp">self</span><span class="p">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span> <span class="bp">self</span><span class="p">.</span><span class="n">is_activated</span> <span class="o">=</span> <span class="bp">True</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">=</span> <span class="n">frame_id</span> <span class="k">if</span> <span class="n">new_id</span><span class="p">:</span> <span class="bp">self</span><span class="p">.</span><span class="n">track_id</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">next_id</span><span class="p">()</span> <span class="bp">self</span><span class="p">.</span><span class="n">score</span> <span class="o">=</span> <span class="n">new_track</span><span class="p">.</span><span class="n">score</span> <span class="c1"># 更新轨迹的位置 </span> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span 
class="p">,</span> <span class="n">new_track</span><span class="p">,</span> <span class="n">frame_id</span><span class="p">):</span> <span class="s">""" Update a matched track :type new_track: STrack :type frame_id: int :type update_feature: bool :return: """</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">=</span> <span class="n">frame_id</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracklet_len</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">new_tlwh</span> <span class="o">=</span> <span class="n">new_track</span><span class="p">.</span><span class="n">tlwh</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">covariance</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">tlwh_to_xyah</span><span class="p">(</span><span class="n">new_tlwh</span><span class="p">))</span> <span class="bp">self</span><span class="p">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span> <span class="bp">self</span><span class="p">.</span><span class="n">is_activated</span> <span class="o">=</span> <span class="bp">True</span> <span class="bp">self</span><span class="p">.</span><span class="n">score</span> <span class="o">=</span> <span class="n">new_track</span><span class="p">.</span><span class="n">score</span> <span 
class="o">@</span><span class="nb">property</span> <span class="c1"># @jit(nopython=True) </span> <span class="c1"># 这个函数很重要,在进行匹配的时候会调用到他,指的是 track 在经过卡尔曼预测之后在当前帧的位置 </span> <span class="c1"># 所以这里用了 mean,因为卡尔曼经过 predict 之后会更新 mean 和 covariance 的状态,mean 是 </span> <span class="c1"># [cx, cy, a, h, vx, vy, va, vh],所以 self.mean[:4] 指的就是预测框的坐标信息 </span> <span class="k">def</span> <span class="nf">tlwh</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="s">"""Get current position in bounding box format `(top left x, top left y, width, height)`. """</span> <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span> <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">_tlwh</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span> <span class="n">ret</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mean</span><span class="p">[:</span><span class="mi">4</span><span class="p">].</span><span class="n">copy</span><span class="p">()</span> <span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">*=</span> <span class="n">ret</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="n">ret</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-=</span> <span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">return</span> <span class="n">ret</span> <span class="o">@</span><span class="nb">property</span> <span class="c1"># @jit(nopython=True) </span> <span class="k">def</span> <span class="nf">tlbr</span><span class="p">(</span><span 
class="bp">self</span><span class="p">):</span> <span class="s">"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e., `(top left, bottom right)`. """</span> <span class="n">ret</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">tlwh</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span> <span class="n">ret</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span> <span class="o">+=</span> <span class="n">ret</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="k">return</span> <span class="n">ret</span> <span class="k">class</span> <span class="nc">BYTETracker</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">frame_rate</span><span class="o">=</span><span class="mi">30</span><span class="p">):</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># type: list[STrack] </span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># type: list[STrack] </span> <span class="bp">self</span><span class="p">.</span><span class="n">removed_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># type: list[STrack] </span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">=</span> <span class="mi">0</span> <span class="bp">self</span><span class="p">.</span><span class="n">args</span> <span class="o">=</span> <span class="n">args</span> <span class="c1">#self.det_thresh = args.track_thresh </span> 
<span class="bp">self</span><span class="p">.</span><span class="n">det_thresh</span> <span class="o">=</span> <span class="n">args</span><span class="p">.</span><span class="n">track_thresh</span> <span class="o">+</span> <span class="mf">0.1</span> <span class="c1"># Buffer size in frames: a target only counts as truly lost after it has been missing for more than this many frames </span> <span class="bp">self</span><span class="p">.</span><span class="n">buffer_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">frame_rate</span> <span class="o">/</span> <span class="mf">30.0</span> <span class="o">*</span> <span class="n">args</span><span class="p">.</span><span class="n">track_buffer</span><span class="p">)</span> <span class="bp">self</span><span class="p">.</span><span class="n">max_time_lost</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">buffer_size</span> <span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span> <span class="o">=</span> <span class="n">KalmanFilter</span><span class="p">()</span> <span class="c1"># Main tracking logic, called once per frame </span> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_results</span><span class="p">,</span> <span class="n">img_info</span><span class="p">,</span> <span class="n">img_size</span><span class="p">):</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">activated_starcks</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">refind_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">lost_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">removed_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># output_results is either [xyxy, score] or [xyxy, score, conf]; in the second case the two confidences are multiplied </span> <span 
class="k">if</span> <span class="n">output_results</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">5</span><span class="p">:</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">output_results</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="n">bboxes</span> <span class="o">=</span> <span class="n">output_results</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="k">else</span><span class="p">:</span> <span class="n">output_results</span> <span class="o">=</span> <span class="n">output_results</span><span class="p">.</span><span class="n">cpu</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">output_results</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="n">output_results</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span> <span class="n">bboxes</span> <span class="o">=</span> <span class="n">output_results</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1"># x1y1x2y2 </span> <span class="n">img_h</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">=</span> <span class="n">img_info</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">img_info</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">scale</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span 
class="nb">float</span><span class="p">(</span><span class="n">img_h</span><span class="p">),</span> <span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">img_w</span><span class="p">))</span> <span class="n">bboxes</span> <span class="o">/=</span> <span class="n">scale</span> <span class="c1"># Boxes scoring above track_thresh are the high-confidence boxes used in the first association </span> <span class="n">remain_inds</span> <span class="o">=</span> <span class="n">scores</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">args</span><span class="p">.</span><span class="n">track_thresh</span> <span class="n">inds_low</span> <span class="o">=</span> <span class="n">scores</span> <span class="o">&gt;</span> <span class="mf">0.1</span> <span class="n">inds_high</span> <span class="o">=</span> <span class="n">scores</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">args</span><span class="p">.</span><span class="n">track_thresh</span> <span class="c1"># Boxes scoring between 0.1 and track_thresh are the low-confidence boxes used in the second association </span> <span class="n">inds_second</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">inds_low</span><span class="p">,</span> <span class="n">inds_high</span><span class="p">)</span> <span class="n">dets_second</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[</span><span class="n">inds_second</span><span class="p">]</span> <span class="n">dets</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[</span><span class="n">remain_inds</span><span class="p">]</span> <span class="n">scores_keep</span> <span class="o">=</span> <span class="n">scores</span><span class="p">[</span><span class="n">remain_inds</span><span class="p">]</span> <span class="n">scores_second</span> 
<span class="o">=</span> <span class="n">scores</span><span class="p">[</span><span class="n">inds_second</span><span class="p">]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="s">'''Detections'''</span> <span class="c1"># Wrap the raw boxes into STrack objects </span> <span class="n">detections</span> <span class="o">=</span> <span class="p">[</span><span class="n">STrack</span><span class="p">(</span><span class="n">STrack</span><span class="p">.</span><span class="n">tlbr_to_tlwh</span><span class="p">(</span><span class="n">tlbr</span><span class="p">),</span> <span class="n">s</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">tlbr</span><span class="p">,</span> <span class="n">s</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">dets</span><span class="p">,</span> <span class="n">scores_keep</span><span class="p">)]</span> <span class="k">else</span><span class="p">:</span> <span class="n">detections</span> <span class="o">=</span> <span class="p">[]</span> <span class="s">''' Add newly detected tracklets to tracked_stracks'''</span> <span class="n">unconfirmed</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">tracked_stracks</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># type: list[STrack] </span> <span class="k">for</span> <span class="n">track</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">:</span> <span class="c1"># A track recorded in only one frame so far has is_activated=False </span> <span class="k">if</span> <span class="ow">not</span> <span class="n">track</span><span class="p">.</span><span class="n">is_activated</span><span class="p">:</span> <span 
class="n">unconfirmed</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">tracked_stracks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="s">''' Step 2: First association, with high score detection boxes'''</span> <span class="c1"># Merge the tracks currently being tracked with the lost tracks </span> <span class="c1"># A lost track missed its target in some recent frame but is still within the buffer window, so it can still be matched </span> <span class="n">strack_pool</span> <span class="o">=</span> <span class="n">joint_stracks</span><span class="p">(</span><span class="n">tracked_stracks</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span class="p">)</span> <span class="c1"># Predict the current location with KF </span> <span class="c1"># i.e. run the Kalman filter predict step for every track in strack_pool to get its position in the current frame </span> <span class="n">STrack</span><span class="p">.</span><span class="n">multi_predict</span><span class="p">(</span><span class="n">strack_pool</span><span class="p">)</span> <span class="c1"># Build the cost_matrix between the predicted tracks and this frame's detections, using IoU for association </span> <span class="c1"># iou_distance reads track.tlbr, which returns the track coordinates after prediction </span> <span class="n">dists</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">iou_distance</span><span class="p">(</span><span class="n">strack_pool</span><span class="p">,</span> <span class="n">detections</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">.</span><span class="n">args</span><span class="p">.</span><span class="n">mot20</span><span class="p">:</span> <span class="n">dists</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">fuse_score</span><span 
class="p">(</span><span class="n">dists</span><span class="p">,</span> <span class="n">detections</span><span class="p">)</span> <span class="c1"># Solve the linear assignment (Hungarian-style matching) to get the indices of matched track/detection pairs, plus the indices of unmatched tracks and unmatched detections </span> <span class="n">matches</span><span class="p">,</span> <span class="n">u_track</span><span class="p">,</span> <span class="n">u_detection</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">linear_assignment</span><span class="p">(</span><span class="n">dists</span><span class="p">,</span> <span class="n">thresh</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">args</span><span class="p">.</span><span class="n">match_thresh</span><span class="p">)</span> <span class="k">for</span> <span class="n">itracked</span><span class="p">,</span> <span class="n">idet</span> <span class="ow">in</span> <span class="n">matches</span><span class="p">:</span> <span class="c1"># For each matched track/detection pair, update the Kalman state with the detection </span> <span class="n">track</span> <span class="o">=</span> <span class="n">strack_pool</span><span class="p">[</span><span class="n">itracked</span><span class="p">]</span> <span class="n">det</span> <span class="o">=</span> <span class="n">detections</span><span class="p">[</span><span class="n">idet</span><span class="p">]</span> <span class="c1"># Tracks that came from tracked_stracks in strack_pool </span> <span class="k">if</span> <span class="n">track</span><span class="p">.</span><span class="n">state</span> <span class="o">==</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">detections</span><span class="p">[</span><span class="n">idet</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span 
class="n">frame_id</span><span class="p">)</span> <span class="n">activated_starcks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="c1"># Tracks that came from self.lost_stracks in strack_pool: re-activate them </span> <span class="k">else</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">re_activate</span><span class="p">(</span><span class="n">det</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span><span class="p">,</span> <span class="n">new_id</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> <span class="n">refind_stracks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="s">''' Step 3: Second association, with low score detection boxes'''</span> <span class="c1"># associate the unmatched tracks with the low score detections </span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dets_second</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="s">'''Detections'''</span> <span class="n">detections_second</span> <span class="o">=</span> <span class="p">[</span><span class="n">STrack</span><span class="p">(</span><span class="n">STrack</span><span class="p">.</span><span class="n">tlbr_to_tlwh</span><span class="p">(</span><span class="n">tlbr</span><span class="p">),</span> <span class="n">s</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">tlbr</span><span class="p">,</span> <span class="n">s</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">dets_second</span><span class="p">,</span> <span class="n">scores_second</span><span class="p">)]</span> <span 
class="k">else</span><span class="p">:</span> <span class="n">detections_second</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Collect the tracks in strack_pool that stayed unmatched and are still in the Tracked state (e.g. the target is occluded in this frame); unmatched lost tracks are not retried </span> <span class="n">r_tracked_stracks</span> <span class="o">=</span> <span class="p">[</span><span class="n">strack_pool</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">u_track</span> <span class="k">if</span> <span class="n">strack_pool</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">state</span> <span class="o">==</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">]</span> <span class="n">dists</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">iou_distance</span><span class="p">(</span><span class="n">r_tracked_stracks</span><span class="p">,</span> <span class="n">detections_second</span><span class="p">)</span> <span class="n">matches</span><span class="p">,</span> <span class="n">u_track</span><span class="p">,</span> <span class="n">u_detection_second</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">linear_assignment</span><span class="p">(</span><span class="n">dists</span><span class="p">,</span> <span class="n">thresh</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> <span class="c1"># Match these remaining tracks against the low-confidence boxes again, using plain IoU </span> <span class="k">for</span> <span class="n">itracked</span><span class="p">,</span> <span class="n">idet</span> <span class="ow">in</span> <span class="n">matches</span><span class="p">:</span> <span class="n">track</span> <span class="o">=</span> <span class="n">r_tracked_stracks</span><span class="p">[</span><span class="n">itracked</span><span class="p">]</span> 
<span class="n">det</span> <span class="o">=</span> <span class="n">detections_second</span><span class="p">[</span><span class="n">idet</span><span class="p">]</span> <span class="k">if</span> <span class="n">track</span><span class="p">.</span><span class="n">state</span> <span class="o">==</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">det</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span><span class="p">)</span> <span class="n">activated_starcks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">re_activate</span><span class="p">(</span><span class="n">det</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span><span class="p">,</span> <span class="n">new_id</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> <span class="n">refind_stracks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="c1"># Tracks still unmatched after both associations are marked as lost </span> <span class="k">for</span> <span class="n">it</span> <span class="ow">in</span> <span class="n">u_track</span><span class="p">:</span> <span class="n">track</span> <span class="o">=</span> <span class="n">r_tracked_stracks</span><span class="p">[</span><span class="n">it</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">track</span><span class="p">.</span><span class="n">state</span> <span class="o">==</span> <span 
class="n">TrackState</span><span class="p">.</span><span class="n">Lost</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">mark_lost</span><span class="p">()</span> <span class="n">lost_stracks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="s">'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''</span> <span class="c1"># Keep the detections left unmatched by the first association (usually boxes of objects appearing for the first time) </span> <span class="n">detections</span> <span class="o">=</span> <span class="p">[</span><span class="n">detections</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">u_detection</span><span class="p">]</span> <span class="c1"># Compute the cost_matrix between these unmatched boxes and the unconfirmed tracks </span> <span class="n">dists</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">iou_distance</span><span class="p">(</span><span class="n">unconfirmed</span><span class="p">,</span> <span class="n">detections</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">.</span><span class="n">args</span><span class="p">.</span><span class="n">mot20</span><span class="p">:</span> <span class="n">dists</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">fuse_score</span><span class="p">(</span><span class="n">dists</span><span class="p">,</span> <span class="n">detections</span><span class="p">)</span> <span class="n">matches</span><span class="p">,</span> <span class="n">u_unconfirmed</span><span class="p">,</span> <span class="n">u_detection</span> <span class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">linear_assignment</span><span 
class="p">(</span><span class="n">dists</span><span class="p">,</span> <span class="n">thresh</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span> <span class="c1"># If a match is found, the track can be considered confirmed: </span> <span class="c1"># correct its Kalman prediction with the matched box and add it to activated_starcks </span> <span class="k">for</span> <span class="n">itracked</span><span class="p">,</span> <span class="n">idet</span> <span class="ow">in</span> <span class="n">matches</span><span class="p">:</span> <span class="n">unconfirmed</span><span class="p">[</span><span class="n">itracked</span><span class="p">].</span><span class="n">update</span><span class="p">(</span><span class="n">detections</span><span class="p">[</span><span class="n">idet</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span><span class="p">)</span> <span class="n">activated_starcks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">unconfirmed</span><span class="p">[</span><span class="n">itracked</span><span class="p">])</span> <span class="c1"># Unconfirmed tracks without a match are removed right away: they appeared in only one frame </span> <span class="k">for</span> <span class="n">it</span> <span class="ow">in</span> <span class="n">u_unconfirmed</span><span class="p">:</span> <span class="n">track</span> <span class="o">=</span> <span class="n">unconfirmed</span><span class="p">[</span><span class="n">it</span><span class="p">]</span> <span class="n">track</span><span class="p">.</span><span class="n">mark_removed</span><span class="p">()</span> <span class="n">removed_stracks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="s">""" Step 4: Init new stracks"""</span> <span class="c1"># Any detection still unmatched after the steps above probably means a new object has entered the scene </span> <span class="c1"># so it starts a new track (if its score reaches det_thresh), but the new track is not activated yet (except in the first frame) </span> <span class="c1"># 
In the next frame it is first put into unconfirmed, and whether a box matches it then decides whether it is activated or discarded </span> <span class="k">for</span> <span class="n">inew</span> <span class="ow">in</span> <span class="n">u_detection</span><span class="p">:</span> <span class="n">track</span> <span class="o">=</span> <span class="n">detections</span><span class="p">[</span><span class="n">inew</span><span class="p">]</span> <span class="k">if</span> <span class="n">track</span><span class="p">.</span><span class="n">score</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">det_thresh</span><span class="p">:</span> <span class="k">continue</span> <span class="n">track</span><span class="p">.</span><span class="n">activate</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">kalman_filter</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span><span class="p">)</span> <span class="n">activated_starcks</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="s">""" Step 5: Update state"""</span> <span class="c1"># For lost tracks, remove any that have been lost for more than max_time_lost (the buffer size) frames </span> <span class="k">for</span> <span class="n">track</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span class="p">:</span> <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">frame_id</span> <span class="o">-</span> <span class="n">track</span><span class="p">.</span><span class="n">end_frame</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">max_time_lost</span><span class="p">:</span> <span class="n">track</span><span class="p">.</span><span class="n">mark_removed</span><span class="p">()</span> <span class="n">removed_stracks</span><span 
class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">track</span><span class="p">)</span> <span class="c1"># print('Ramained match {} s'.format(t4-t3)) </span> <span class="c1"># Keep only the tracks still in the Tracked state (tracks marked lost or removed in this frame drop out) </span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="k">if</span> <span class="n">t</span><span class="p">.</span><span class="n">state</span> <span class="o">==</span> <span class="n">TrackState</span><span class="p">.</span><span class="n">Tracked</span><span class="p">]</span> <span class="c1"># Add the tracks activated in this frame (matched in either association, confirmed from the unconfirmed state, or newly started) </span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="o">=</span> <span class="n">joint_stracks</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">,</span> <span class="n">activated_starcks</span><span class="p">)</span> <span class="c1"># Add the lost tracks that were matched again in this frame </span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="o">=</span> <span class="n">joint_stracks</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">,</span> <span class="n">refind_stracks</span><span class="p">)</span> <span class="c1"># Remove from self.lost_stracks any track that was re-activated in this frame </span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span> <span class="o">=</span> <span class="n">sub_stracks</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span 
class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">)</span> <span class="c1"># Add the tracks lost in this frame </span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">lost_stracks</span><span class="p">)</span> <span class="c1"># Also drop lost tracks already in self.removed_stracks, i.e. those unmatched for longer than the buffer (self.removed_stracks is only extended on the next line) </span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span> <span class="o">=</span> <span class="n">sub_stracks</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">removed_stracks</span><span class="p">)</span> <span class="c1"># Update the list of removed tracks </span> <span class="bp">self</span><span class="p">.</span><span class="n">removed_stracks</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">removed_stracks</span><span class="p">)</span> <span class="c1"># Remove heavily overlapping duplicates between the tracked and lost tracks (not entirely clear to me yet; presumably it keeps one object from being represented by both a tracked and a lost track) </span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span> <span class="o">=</span> <span class="n">remove_duplicate_stracks</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">lost_stracks</span><span class="p">)</span> <span class="c1"># get scores of lost tracks </span> <span class="c1"># Final result: the successfully tracked (activated) tracks </span> <span class="n">output_stracks</span> <span class="o">=</span> <span class="p">[</span><span class="n">track</span> <span 
class="k">for</span> <span class="n">track</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">tracked_stracks</span> <span class="k">if</span> <span class="n">track</span><span class="p">.</span><span class="n">is_activated</span><span class="p">]</span> <span class="k">return</span> <span class="n">output_stracks</span> <span class="c1"># Merge the tracks of tlista and tlistb into one list, deduplicated by track_id </span><span class="k">def</span> <span class="nf">joint_stracks</span><span class="p">(</span><span class="n">tlista</span><span class="p">,</span> <span class="n">tlistb</span><span class="p">):</span> <span class="n">exists</span> <span class="o">=</span> <span class="p">{}</span> <span class="n">res</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tlista</span><span class="p">:</span> <span class="n">exists</span><span class="p">[</span><span class="n">t</span><span class="p">.</span><span class="n">track_id</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">res</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tlistb</span><span class="p">:</span> <span class="n">tid</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">track_id</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">exists</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">tid</span><span class="p">,</span> <span class="mi">0</span><span class="p">):</span> <span class="n">exists</span><span class="p">[</span><span class="n">tid</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">res</span><span 
class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">return</span> <span class="n">res</span> <span class="c1"># Return the tracks of tlista whose track_id does not appear in tlistb (set difference) </span><span class="k">def</span> <span class="nf">sub_stracks</span><span class="p">(</span><span class="n">tlista</span><span class="p">,</span> <span class="n">tlistb</span><span class="p">):</span> <span class="n">stracks</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tlista</span><span class="p">:</span> <span class="n">stracks</span><span class="p">[</span><span class="n">t</span><span class="p">.</span><span class="n">track_id</span><span class="p">]</span> <span class="o">=</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tlistb</span><span class="p">:</span> <span class="n">tid</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">track_id</span> <span class="k">if</span> <span class="n">stracks</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">tid</span><span class="p">,</span> <span class="mi">0</span><span class="p">):</span> <span class="k">del</span> <span class="n">stracks</span><span class="p">[</span><span class="n">tid</span><span class="p">]</span> <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">stracks</span><span class="p">.</span><span class="n">values</span><span class="p">())</span> <span class="c1"># If two tracks overlap heavily (IoU distance below 0.15), drop one of them # keep the one that has existed for more frames and remove the other </span><span class="k">def</span> <span class="nf">remove_duplicate_stracks</span><span class="p">(</span><span class="n">stracksa</span><span class="p">,</span> <span class="n">stracksb</span><span class="p">):</span> <span class="n">pdist</span> <span 
class="o">=</span> <span class="n">matching</span><span class="p">.</span><span class="n">iou_distance</span><span class="p">(</span><span class="n">stracksa</span><span class="p">,</span> <span class="n">stracksb</span><span class="p">)</span> <span class="n">pairs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">pdist</span> <span class="o">&lt;</span> <span class="mf">0.15</span><span class="p">)</span> <span class="n">dupa</span><span class="p">,</span> <span class="n">dupb</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(),</span> <span class="nb">list</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">pairs</span><span class="p">):</span> <span class="n">timep</span> <span class="o">=</span> <span class="n">stracksa</span><span class="p">[</span><span class="n">p</span><span class="p">].</span><span class="n">frame_id</span> <span class="o">-</span> <span class="n">stracksa</span><span class="p">[</span><span class="n">p</span><span class="p">].</span><span class="n">start_frame</span> <span class="n">timeq</span> <span class="o">=</span> <span class="n">stracksb</span><span class="p">[</span><span class="n">q</span><span class="p">].</span><span class="n">frame_id</span> <span class="o">-</span> <span class="n">stracksb</span><span class="p">[</span><span class="n">q</span><span class="p">].</span><span class="n">start_frame</span> <span class="k">if</span> <span class="n">timep</span> <span class="o">&gt;</span> <span class="n">timeq</span><span class="p">:</span> <span class="n">dupb</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">q</span><span class="p">)</span> <span 
class="k">else</span><span class="p">:</span> <span class="n">dupa</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="n">resa</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">stracksa</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dupa</span><span class="p">]</span> <span class="n">resb</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">stracksb</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dupb</span><span class="p">]</span> <span class="k">return</span> <span class="n">resa</span><span class="p">,</span> <span class="n">resb</span> </code></pre></div></div> <h2 id="matchingpy">matching.py</h2> <p>Only this one function from matching.py is used: it computes the cost_matrix with IoU as the matching metric, i.e. cost = 1 - IoU.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">iou_distance</span><span class="p">(</span><span class="n">atracks</span><span class="p">,</span> <span class="n">btracks</span><span class="p">):</span> <span class="s">""" Compute cost based on IoU :type atracks: list[STrack] :type btracks: list[STrack] :rtype cost_matrix np.ndarray """</span> <span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">atracks</span><span
class="p">)</span><span class="o">&gt;</span><span class="mi">0</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">atracks</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">ndarray</span><span class="p">))</span> <span class="ow">or</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">btracks</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">btracks</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">ndarray</span><span class="p">)):</span> <span class="n">atlbrs</span> <span class="o">=</span> <span class="n">atracks</span> <span class="n">btlbrs</span> <span class="o">=</span> <span class="n">btracks</span> <span class="k">else</span><span class="p">:</span> <span class="c1"># note: track.tlbr here is the box after the Kalman filter predict step! 
</span> <span class="n">atlbrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">track</span><span class="p">.</span><span class="n">tlbr</span> <span class="k">for</span> <span class="n">track</span> <span class="ow">in</span> <span class="n">atracks</span><span class="p">]</span> <span class="n">btlbrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">track</span><span class="p">.</span><span class="n">tlbr</span> <span class="k">for</span> <span class="n">track</span> <span class="ow">in</span> <span class="n">btracks</span><span class="p">]</span> <span class="n">_ious</span> <span class="o">=</span> <span class="n">ious</span><span class="p">(</span><span class="n">atlbrs</span><span class="p">,</span> <span class="n">btlbrs</span><span class="p">)</span> <span class="n">cost_matrix</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">_ious</span> <span class="k">return</span> <span class="n">cost_matrix</span> </code></pre></div></div> <h2 id="reference">reference</h2> <p>https://zhuanlan.zhihu.com/p/90835266</p> Sat, 16 Jul 2022 00:00:00 +0000 https://yarkable.github.io/2022/07/16/ByteTrack%E6%B3%A8%E9%87%8A%E8%AF%A6%E8%A7%A3/ https://yarkable.github.io/2022/07/16/ByteTrack%E6%B3%A8%E9%87%8A%E8%AF%A6%E8%A7%A3/ deep learning MOT ONNX and TensorRT Series <h2 id="onnx">onnx</h2> <p>An ONNX model is essentially a directed acyclic graph. With tracing, a dummy tensor is run through the network once and the nodes it passes through are recorded to form the graph.</p> <p><code class="language-plaintext highlighter-rouge">onnx_model.graph.node</code> gives all of the nodes. Every node carries attributes such as name, input and output, and Netron visualizes the model from exactly this information.</p> <p>When a PyTorch model is exported to ONNX, all of the model's input arguments must be of type torch.Tensor. In the example being converted, however, the second argument we pass, "3", is an integer, which violates this rule of PyTorch-to-ONNX export. We therefore have to modify the original model's inputs so that every argument is a torch.Tensor.</p> <p>The model that torch.onnx.export works on is actually a torch.jit.ScriptModule. To turn an ordinary PyTorch model into such a TorchScript model, there are two ways to export the computation graph: tracing (trace) and scripting (script). If an ordinary PyTorch model (torch.nn.Module) is passed to torch.onnx.export, it is exported by tracing by default. The process is shown in the figure below:</p>
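<p>As a rough illustration of the default tracing behaviour, here is a minimal sketch (not from the original post). It calls <code class="language-plaintext highlighter-rouge">torch.jit.trace</code> directly instead of <code class="language-plaintext highlighter-rouge">torch.onnx.export</code>, so it needs no onnx package; the toy module <code class="language-plaintext highlighter-rouge">RepeatAdd</code> and its loop count <code class="language-plaintext highlighter-rouge">n</code> are hypothetical. Tracing only records the operators that actually ran on the dummy tensor, so a Python loop is unrolled into a fixed number of nodes and later changes to <code class="language-plaintext highlighter-rouge">n</code> are invisible to the traced graph.</p> <pre><code class="language-Python">import torch


class RepeatAdd(torch.nn.Module):
    """Hypothetical toy module: adds 1 to its input n times in a Python loop."""

    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor):
        # plain Python control flow: tracing only sees the iterations that ran
        for _ in range(self.n):
            x = x + 1
        return x


dummy = torch.zeros(1, 3)  # the dummy input used for tracing

# Tracing runs forward once on the dummy tensor and records the executed ops,
# so with n = 2 the loop is unrolled into exactly two add nodes.
model = RepeatAdd(n=2)
traced = torch.jit.trace(model, dummy)

# Changing n afterwards changes the eager module, but not the traced graph.
model.n = 5
eager_out = model(dummy)    # every element is 5.0
traced_out = traced(dummy)  # every element is still 2.0
print(eager_out)
print(traced_out)
print(traced.graph)  # inspect the recorded (unrolled) graph
</code></pre> <p>Scripting with torch.jit.script would instead compile the Python source and keep the loop as a prim::Loop node, which is why a scripted export can keep one graph for every n.</p>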
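<p>Sometimes we want the model to follow one logic when it runs directly in PyTorch and a different logic in the exported ONNX model. For example, we can put some post-processing inside the model to simplify the code that surrounds it. torch.onnx.is_in_onnx_export() makes this possible: it returns True only while torch.onnx.export() is running. Here is an example:</p> <hr /> <p>When converting an ordinary torch.nn.Module, PyTorch on the one hand runs the forward pass with tracing and assembles the operators it meets into a computation graph; on the other hand, it translates every operator it meets into operators defined by ONNX. During this translation, one of the following can happen:</p> <ul> <li>The operator is translated one-to-one into a single ONNX operator.</li> <li>The operator has no direct ONNX counterpart and is translated into one or more ONNX operators.</li> <li>No rule for translating the operator into ONNX is defined, so an error is raised.</li> </ul> <hr /> <p>How every ONNX operator is defined can be looked up in the official operator documentation. This document is very important: whenever we hit a problem related to ONNX operators, it is the first place to consult.</p> <p>Operator documentation:</p> <p>https://github.com/onnx/onnx/blob/main/docs/Operators.md</p> <p>In PyTorch, everything related to ONNX lives in the torch.onnx directory, as shown in the figure below:</p> <p>torch.onnx directory:</p> <p>https://github.com/pytorch/pytorch/tree/master/torch/onnx</p> <p>Use torch.onnx.is_in_onnx_export() to make the model behave differently while it is being converted to ONNX.</p> <hr /> <p>ONNX model structure obtained with tracing: for different values of n, the structure of the ONNX model is different.</p> <p>With scripting, in contrast, the final ONNX model represents the loop with a Loop node, so the ONNX model has the same structure for every n.</p> <hr /> <p>In real deployment it is hard to avoid models that cannot be expressed with native PyTorch operators. In that case we have to consider extending PyTorch, i.e. supporting more ONNX operators from PyTorch.</p> <p>For a PyTorch operator to be converted to ONNX successfully, all three of the following must hold:</p> <ul> <li>The operator is implemented in PyTorch.</li> <li>There is a way to map this PyTorch operator to one or more ONNX operators.</li> <li>ONNX has the corresponding operator(s).</li> </ul> <p>In practice, any of these three pieces may be missing. The worst case is a brand-new operator that has neither a PyTorch implementation nor a PyTorch-to-ONNX mapping. There is always a way forward, though: for each of the three pieces there are ways to add support:</p> <ul> <li>PyTorch operator <ul> <li>Compose existing operators</li> <li>Add a TorchScript operator</li> <li>Add an ordinary C++ extension operator</li> </ul> </li> <li>Mapping method <ul> <li>Add a symbolic function for an ATen operator</li> <li>Add a symbolic function for a TorchScript operator</li> <li>Wrap it in a torch.autograd.Function and add a symbolic function</li> </ul> </li> <li>ONNX operator <ul> <li>Use existing ONNX operators</li> <li>Define a new ONNX operator</li> </ul> </li> </ul> <hr /> <p>After converting to ONNX, the model is usually run once with onnxruntime to verify the result.</p> <h2 id="trt-推理遇到的坑">Pitfalls in TRT inference</h2> <ol> <li>pycuda fails to install <ol> <li>Building it from source works.</li> </ol> </li> </ol> <h2 id="量化三问">Three questions about quantization</h2> <p>1) Why does quantization work?</p> <p>Because CNNs are fairly insensitive to noise, and the rounding error introduced by quantization behaves like a small amount of noise.</p> <p>2) Why use quantization?</p> <p>Models are large (AlexNet alone is around 200 MB), which puts heavy pressure on storage, while the weight range of each layer is essentially fixed and does not fluctuate much, so it fits a low-bit representation well. Quantization also reduces memory access and computation, which is a big advantage.</p> <p>3) Why not train a low-precision model directly?</p>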
<p>Because training relies on backpropagation and gradient descent, which are very hard to do in int8: learning rates, for example, are usually small fractions such as 0.01 or 0.001, which int8 simply cannot represent. Besides, the whole ecosystem is built around floating-point models, so converting an already trained float model is far more practical.</p> <blockquote> <p><a href="https://blog.csdn.net/weixin_34910922/article/details/108502449">Principles of TensorRT-based INT8 quantization (alex1801's CSDN blog)</a></p> </blockquote> Sun, 22 May 2022 00:00:00 +0000 https://yarkable.github.io/2022/05/22/ONNX%E4%B8%8ETensorRT%E7%B3%BB%E5%88%97/ https://yarkable.github.io/2022/05/22/ONNX%E4%B8%8ETensorRT%E7%B3%BB%E5%88%97/ linux object detection deep learning mmdetection Annotated walkthrough of DETR in mmdetection <h2 id="preface">preface</h2> <p>This post records the DETR training pipeline in mmdetection, including target assignment, the transformer encoder&amp;decoder, the training forward pass and the tensor shapes at each step; it is mainly for my own review. The mmdetection version is 2.11.0.</p> <h2 id="detr">DETR</h2> <p>Looking at the detector as a whole first: DETR directly inherits from <em><code class="language-plaintext highlighter-rouge">SingleStageDetector</code></em>, so the only thing that changes is the detection head, and everything important lives in TransformerHead. We start directly from forward_train.</p> <h2 id="transformerhead">TransformerHead</h2> <h3 id="forward_train">forward_train</h3> <p>Like other detection heads, it first calls itself, i.e. its own forward function, to get the class scores and box regression outputs, and then calls its own loss function. The difference is that forward_train is overridden here so that img_metas is also passed to forward.</p> <pre><code class="language-Python"># over-write because img_metas are needed as inputs for bbox_head. def forward_train(self, x, img_metas, gt_bboxes, gt_labels=None, gt_bboxes_ignore=None, proposal_cfg=None, **kwargs): """Forward function for training mode. Returns: dict[str, Tensor]: A dictionary of loss components. """ assert proposal_cfg is None, '"proposal_cfg" must be None' # forward pass results, analyzed below outs = self(x, img_metas) if gt_labels is None: loss_inputs = outs + (gt_bboxes, img_metas) else: loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) return losses </code></pre> <h3 id="forwardforward_single">forward&amp;forward_single</h3> <pre><code class="language-Python">def forward(self, feats, img_metas): """Forward function.
Args: feats (tuple[Tensor]): Features from the upstream network, each is a 4D-tensor. img_metas (list[dict]): List of image information. Returns: tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels. - all_cls_scores_list (list[Tensor]): Classification scores \ for each scale level. Each is a 4D-tensor with shape \ [nb_dec, bs, num_query, cls_out_channels]. Note \ `cls_out_channels` should include background. - all_bbox_preds_list (list[Tensor]): Sigmoid regression \ outputs for each scale level. Each is a 4D-tensor with \ normalized coordinate format (cx, cy, w, h) and shape \ [nb_dec, bs, num_query, 4]. """ # num_levels is 1 here, because DETR only uses the last feature map by default num_levels = len(feats) img_metas_list = [img_metas for _ in range(num_levels)] return multi_apply(self.forward_single, feats, img_metas_list) </code></pre> <p>Now look directly at forward_single, which holds the forward logic of the head.</p> <pre><code class="language-Python">def forward_single(self, x, img_metas): """"Forward function for a single feature level. Args: x (Tensor): Input feature from backbone's single stage, shape [bs, c, h, w]. img_metas (list[dict]): List of image information. Returns: all_cls_scores (Tensor): Outputs from the classification head, shape [nb_dec, bs, num_query, cls_out_channels]. Note cls_out_channels should include background. all_bbox_preds (Tensor): Sigmoid outputs from the regression head with normalized coordinate format (cx, cy, w, h). Shape [nb_dec, bs, num_query, 4]. """ # construct binary masks which are used for the transformer. # NOTE following the official DETR repo, non-zero values represent # ignored positions, while zero values mean valid positions.
batch_size = x.size(0) # every image in the batch shares the same batch_input_shape (the padded input size) input_img_h, input_img_w = img_metas[0]['batch_input_shape'] # start with a mask of all ones masks = x.new_ones((batch_size, input_img_h, input_img_w)) # for each image, set the mask to 0 wherever the original image has pixels # so only the padded region stays 1 for img_id in range(batch_size): img_h, img_w, _ = img_metas[img_id]['img_shape'] masks[img_id, :img_h, :img_w] = 0 # first project the feature map to the configured embedding dimension # self.input_proj = Conv2d(self.in_channels, self.embed_dims, kernel_size=1) x = self.input_proj(x) # shape:(B,embed_dims,H,W) # interpolate masks to have the same spatial shape with x # masks: [B, H, W] masks = F.interpolate( masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # position encoding # positional encoding, shape: [B, 256, H, W] pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w] # outs_dec: [nb_dec, bs, num_query, embed_dim] # self.query_embedding = nn.Embedding(self.num_query, self.embed_dims) outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight, pos_embed) # classify and regress every query # shape [num_decoder, B, num_query, num_class+1] all_cls_scores = self.fc_cls(outs_dec) # reg_ffn, the activation and then the linear layer fc_reg give 4 values per query, # which sigmoid normalizes to 0-1 in (cx, cy, w, h) format # shape [num_decoder, B, num_query, 4] all_bbox_preds = self.fc_reg(self.activate( self.reg_ffn(outs_dec))).sigmoid() return all_cls_scores, all_bbox_preds </code></pre> <h3 id="loss">loss</h3> <p>Now let's see how DETR computes the loss.</p> <pre><code class="language-Python">@force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list')) def loss(self, all_cls_scores_list, all_bbox_preds_list, gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore=None): """"Loss function. Only outputs from the last feature level are used for computing losses by default. Args: all_cls_scores_list (list[Tensor]): Classification outputs for each feature level. Each is a 4D-tensor with shape [nb_dec, bs, num_query, cls_out_channels]. all_bbox_preds_list (list[Tensor]): Sigmoid regression outputs for each feature level.
Each is a 4D-tensor with normalized coordinate format (cx, cy, w, h) and shape [nb_dec, bs, num_query, 4]. gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_labels_list (list[Tensor]): Ground truth class indices for each image with shape (num_gts, ). img_metas (list[dict]): List of image meta information. gt_bboxes_ignore (list[Tensor], optional): Bounding boxes which can be ignored for each image. Default None. Returns: dict[str, Tensor]: A dictionary of loss components. """ # NOTE by default only the outputs from the last feature scale are used. # shape: [num_decoder, B, num_query, num_class+1] all_cls_scores = all_cls_scores_list[-1] # shape: [num_decoder, B, num_query, 4] all_bbox_preds = all_bbox_preds_list[-1] assert gt_bboxes_ignore is None, \ 'Only supports for gt_bboxes_ignore setting to None.' # number of decoder layers, 6 by default num_dec_layers = len(all_cls_scores) all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] all_gt_bboxes_ignore_list = [ gt_bboxes_ignore for _ in range(num_dec_layers) ] img_metas_list = [img_metas for _ in range(num_dec_layers)] # call loss_single once for each decoder layer losses_cls, losses_bbox, losses_iou = multi_apply( self.loss_single, all_cls_scores, all_bbox_preds, all_gt_bboxes_list, all_gt_labels_list, img_metas_list, all_gt_bboxes_ignore_list) # gather the loss of every decoder layer into loss_dict loss_dict = dict() # loss from the last decoder layer loss_dict['loss_cls'] = losses_cls[-1] loss_dict['loss_bbox'] = losses_bbox[-1] loss_dict['loss_iou'] = losses_iou[-1] # loss from other decoder layers num_dec_layer = 0 for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]): loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i num_dec_layer += 1 return loss_dict </code></pre> <h3
id="loss_single">loss_single</h3> <p>The main loss logic lives here.</p> <pre><code class="language-Python">def loss_single(self, cls_scores, bbox_preds, gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list=None): """"Loss function for outputs from a single decoder layer of a single feature level. Args: cls_scores (Tensor): Box score logits from a single decoder layer for all images. Shape [bs, num_query, cls_out_channels]. bbox_preds (Tensor): Sigmoid outputs from a single decoder layer for all images, with normalized coordinate (cx, cy, w, h) and shape [bs, num_query, 4]. gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_labels_list (list[Tensor]): Ground truth class indices for each image with shape (num_gts, ). img_metas (list[dict]): List of image meta information. gt_bboxes_ignore_list (list[Tensor], optional): Bounding boxes which can be ignored for each image. Default None. Returns: dict[str, Tensor]: A dictionary of loss components for outputs from a single decoder layer.
""" num_imgs = cls_scores.size(0) cls_scores_list = [cls_scores[i] for i in range(num_imgs)] bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list) (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets labels = torch.cat(labels_list, 0) label_weights = torch.cat(label_weights_list, 0) bbox_targets = torch.cat(bbox_targets_list, 0) bbox_weights = torch.cat(bbox_weights_list, 0) # classification loss cls_scores = cls_scores.reshape(-1, self.cls_out_channels) # construct weighted avg_factor to match with the official DETR repo cls_avg_factor = num_total_pos * 1.0 + \ num_total_neg * self.bg_cls_weight loss_cls = self.loss_cls( cls_scores, labels, label_weights, avg_factor=cls_avg_factor) # Compute the average number of gt boxes across all gpus, for # normalization purposes num_total_pos = loss_cls.new_tensor([num_total_pos]) num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() # construct factors used for rescale bboxes factors = [] for img_meta, bbox_pred in zip(img_metas, bbox_preds): img_h, img_w, _ = img_meta['img_shape'] factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).repeat( bbox_pred.size(0), 1) factors.append(factor) factors = torch.cat(factors, 0) # DETR regresses the relative position of boxes (cxcywh) in the image, # thus the learning target is normalized by the image size.
So here # we need to re-scale them for calculating IoU loss bbox_preds = bbox_preds.reshape(-1, 4) bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors # regression IoU loss, GIoU loss by default loss_iou = self.loss_iou( bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) # regression L1 loss loss_bbox = self.loss_bbox( bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) return loss_cls, loss_bbox, loss_iou </code></pre> <h2 id="transformer">Transformer</h2> <p>In TransformerHead above, the feature map went through transformer encoding and decoding via <code class="language-plaintext highlighter-rouge">self.transformer</code> to produce the output. Here we analyze the components inside the transformer, configured as follows.</p> <pre><code class="language-Python">transformer=dict( type='Transformer', embed_dims=256, num_heads=8, num_encoder_layers=6, num_decoder_layers=6, feedforward_channels=2048, dropout=0.1, act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), num_fcs=2, pre_norm=False, return_intermediate_dec=True), </code></pre> <p>The main code of the Transformer class is as follows:</p> <pre><code class="language-Python">class Transformer(nn.Module): """Implements the DETR transformer. Following the official DETR implementation, this module copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MultiheadAttention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers See `paper: End-to-End Object Detection with Transformers &lt;https://arxiv.org/pdf/2005.12872&gt;`_ for details. Args: embed_dims (int): The feature dimension. num_heads (int): Parallel attention heads. Same as `nn.MultiheadAttention`. num_encoder_layers (int): Number of `TransformerEncoderLayer`. num_decoder_layers (int): Number of `TransformerDecoderLayer`. feedforward_channels (int): The hidden dimension for FFNs used in both encoder and decoder. dropout (float): Probability of an element to be zeroed. Default 0.0.
act_cfg (dict): Activation config for FFNs used in both encoder and decoder. Default ReLU. norm_cfg (dict): Config dict for normalization used in both encoder and decoder. Default layer normalization. num_fcs (int): The number of fully-connected layers in FFNs, which is used for both encoder and decoder. pre_norm (bool): Whether the normalization layer is ordered first in the encoder and decoder. Default False. return_intermediate_dec (bool): Whether to return the intermediate output from each TransformerDecoderLayer or only the last TransformerDecoderLayer. Default False. If True, the returned `hs` has shape [num_decoder_layers, bs, num_query, embed_dims]. If False, the returned `hs` will have shape [1, bs, num_query, embed_dims]. """ def __init__(self, embed_dims=512, num_heads=8, num_encoder_layers=6, num_decoder_layers=6, feedforward_channels=2048, dropout=0.0, act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), num_fcs=2, pre_norm=False, return_intermediate_dec=False): super(Transformer, self).__init__() self.embed_dims = embed_dims self.num_heads = num_heads self.num_encoder_layers = num_encoder_layers self.num_decoder_layers = num_decoder_layers self.feedforward_channels = feedforward_channels self.dropout = dropout self.act_cfg = act_cfg self.norm_cfg = norm_cfg self.num_fcs = num_fcs self.pre_norm = pre_norm self.return_intermediate_dec = return_intermediate_dec # order in which the operations are applied inside each layer if self.pre_norm: encoder_order = ('norm', 'selfattn', 'norm', 'ffn') decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn') else: encoder_order = ('selfattn', 'norm', 'ffn', 'norm') decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') # encoder and decoder self.encoder = TransformerEncoder(num_encoder_layers, embed_dims, num_heads, feedforward_channels, dropout, encoder_order, act_cfg, norm_cfg, num_fcs) self.decoder = TransformerDecoder(num_decoder_layers, embed_dims, num_heads, feedforward_channels, dropout, decoder_order,
act_cfg, norm_cfg, num_fcs, return_intermediate_dec) def init_weights(self, distribution='uniform'): """Initialize the transformer weights.""" # follow the official DETR to init parameters for m in self.modules(): if hasattr(m, 'weight') and m.weight.dim() &gt; 1: xavier_init(m, distribution=distribution) def forward(self, x, mask, query_embed, pos_embed): """Forward function for `Transformer`. Args: x (Tensor): Input query with shape [bs, c, h, w] where c = embed_dims. mask (Tensor): The key_padding_mask used for encoder and decoder, with shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. pos_embed (Tensor): The positional encoding for encoder and decoder, with the same shape as `x`. Returns: tuple[Tensor]: results of decoder containing the following tensor. - out_dec: Output from decoder. If return_intermediate_dec \ is True output has shape [num_dec_layers, bs, num_query, embed_dims], else has shape [1, bs, \ num_query, embed_dims]. - memory: Output results from encoder, with shape \ [bs, embed_dims, h, w]. 
""" bs, c, h, w = x.shape x = x.flatten(2).permute(2, 0, 1) # [bs, c, h, w] -&gt; [h*w, bs, c] pos_embed = pos_embed.flatten(2).permute(2, 0, 1) # same as above query_embed = query_embed.unsqueeze(1).repeat( 1, bs, 1) # [num_query, dim] -&gt; [num_query, bs, dim] # positions where mask is 0 hold real pixels (not padding) mask = mask.flatten(1) # [bs, h, w] -&gt; [bs, h*w] # encoded features (memory): [h*w, bs, c], the same shape as x, i.e. the encoder does not change the shape memory = self.encoder( x, pos=pos_embed, attn_mask=None, key_padding_mask=mask) # target is the decoder input, initialized to zeros with the same shape as query_embed; # the learned query_embed is added to it as the query positional encoding (query_pos) target = torch.zeros_like(query_embed) # out_dec: [num_layers, num_query, bs, dim] out_dec = self.decoder( target, memory, memory_pos=pos_embed, query_pos=query_embed, memory_attn_mask=None, target_attn_mask=None, memory_key_padding_mask=mask, target_key_padding_mask=None) # [num_layers, num_query, bs, dim] -&gt; [num_layers, bs, num_query, dim] out_dec = out_dec.transpose(1, 2) # [h*w, bs, dim] -&gt; [bs, dim, h*w] -&gt; [bs, dim, h, w] memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) return out_dec, memory </code></pre> <h3 id="transformerencoder">TransformerEncoder</h3> <pre><code class="language-Python">class TransformerEncoder(nn.Module): """Implements the encoder in DETR transformer. Args: num_layers (int): The number of `TransformerEncoderLayer`. embed_dims (int): Same as `TransformerEncoderLayer`. num_heads (int): Same as `TransformerEncoderLayer`. feedforward_channels (int): Same as `TransformerEncoderLayer`. dropout (float): Same as `TransformerEncoderLayer`. Default 0.0. order (tuple[str]): Same as `TransformerEncoderLayer`. act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU. norm_cfg (dict): Same as `TransformerEncoderLayer`. Default layer normalization. num_fcs (int): Same as `TransformerEncoderLayer`. Default 2.
""" def __init__(self, num_layers, embed_dims, num_heads, feedforward_channels, dropout=0.0, order=('selfattn', 'norm', 'ffn', 'norm'), act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), num_fcs=2): super(TransformerEncoder, self).__init__() assert isinstance(order, tuple) and len(order) == 4 assert set(order) == set(['selfattn', 'norm', 'ffn']) self.num_layers = num_layers self.embed_dims = embed_dims self.num_heads = num_heads self.feedforward_channels = feedforward_channels self.dropout = dropout self.order = order self.act_cfg = act_cfg self.norm_cfg = norm_cfg self.num_fcs = num_fcs self.pre_norm = order[0] == 'norm' self.layers = nn.ModuleList() # stack num_layers transformer encoder layers for the encoding for _ in range(num_layers): self.layers.append( TransformerEncoderLayer(embed_dims, num_heads, feedforward_channels, dropout, order, act_cfg, norm_cfg, num_fcs)) self.norm = build_norm_layer(norm_cfg, embed_dims)[1] if self.pre_norm else None def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): """Forward function for `TransformerEncoder`. Args: x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`. pos (Tensor): Positional encoding for query. Default None. Same in `TransformerEncoderLayer.forward`. attn_mask (Tensor): ByteTensor attention mask. Default None. Same in `TransformerEncoderLayer.forward`. key_padding_mask (Tensor): Same in `TransformerEncoderLayer.forward`. Default None. Returns: Tensor: Results with shape [num_key, bs, embed_dims]. """ # pass x through the encoder layers one after another for layer in self.layers: x = layer(x, pos, attn_mask, key_padding_mask) if self.norm is not None: x = self.norm(x) return x </code></pre> <h3 id="transformerencoderlayer">TransformerEncoderLayer</h3> <pre><code class="language-Python">class TransformerEncoderLayer(nn.Module): """Implements one encoder layer in DETR transformer. Args: embed_dims (int): The feature dimension. Same as `FFN`. num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs. dropout (float): Probability of an element to be zeroed. Default 0.0. order (tuple[str]): The order for encoder layer. Valid examples are ('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn', 'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm'). act_cfg (dict): The activation config for FFNs. Default ReLU. norm_cfg (dict): Config dict for normalization layer. Default layer normalization. num_fcs (int): The number of fully-connected layers for FFNs. Default 2. """ def __init__(self, embed_dims, num_heads, feedforward_channels, dropout=0.0, order=('selfattn', 'norm', 'ffn', 'norm'), act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), num_fcs=2): super(TransformerEncoderLayer, self).__init__() assert isinstance(order, tuple) and len(order) == 4 assert set(order) == set(['selfattn', 'norm', 'ffn']) self.embed_dims = embed_dims self.num_heads = num_heads self.feedforward_channels = feedforward_channels self.dropout = dropout self.order = order self.act_cfg = act_cfg self.norm_cfg = norm_cfg self.num_fcs = num_fcs self.pre_norm = order[0] == 'norm' self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout) self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, dropout) self.norms = nn.ModuleList() self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): """Forward function for `TransformerEncoderLayer`. Args: x (Tensor): The input query with shape [num_key, bs, embed_dims]. Same in `MultiheadAttention.forward`. pos (Tensor): The positional encoding for query. Default None. Same as `query_pos` in `MultiheadAttention.forward`. attn_mask (Tensor): ByteTensor mask with shape [num_key, num_key]. Same in `MultiheadAttention.forward`. Default None. key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. 
Same in `MultiheadAttention.forward`. Default None. Returns: Tensor: forwarded results with shape [num_key, bs, embed_dims]. """ norm_cnt = 0 inp_residual = x for layer in self.order: # the encoder's self-attention uses the input x as query, key and value at the same time if layer == 'selfattn': # self attention query = key = value = x x = self.self_attn( query, key, value, inp_residual if self.pre_norm else None, query_pos=pos, key_pos=pos, attn_mask=attn_mask, key_padding_mask=key_padding_mask) inp_residual = x elif layer == 'norm': x = self.norms[norm_cnt](x) norm_cnt += 1 elif layer == 'ffn': x = self.ffn(x, inp_residual if self.pre_norm else None) return x </code></pre> <h3 id="multiheadattention">MultiheadAttention</h3> <p>Inside the Transformer, this multi-head attention does most of the work. Its code is below; it mostly just calls <code class="language-plaintext highlighter-rouge">nn.MultiheadAttention</code>, with some extra handling of the positional encodings (added to query and key when given) and a residual connection: \(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O, \quad \text{where} \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)</p> <pre><code class="language-Python">class MultiheadAttention(nn.Module): """A wrapper for torch.nn.MultiheadAttention. This module implements MultiheadAttention with residual connection, and positional encoding used in DETR is also passed as input. Args: embed_dims (int): The embedding dimension. num_heads (int): Parallel attention heads. Same as `nn.MultiheadAttention`. dropout (float): A Dropout layer on attn_output_weights. Default 0.0. """ def __init__(self, embed_dims, num_heads, dropout=0.0): super(MultiheadAttention, self).__init__() assert embed_dims % num_heads == 0, 'embed_dims must be ' \ f'divisible by num_heads. got {embed_dims} and {num_heads}.' self.embed_dims = embed_dims self.num_heads = num_heads self.dropout = dropout self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout) self.dropout = nn.Dropout(dropout) def forward(self, x, key=None, value=None, residual=None, query_pos=None, key_pos=None, attn_mask=None, key_padding_mask=None): """Forward function for `MultiheadAttention`.
Args: x (Tensor): The input query with shape [num_query, bs, embed_dims]. Same in `nn.MultiheadAttention.forward`. key (Tensor): The key tensor with shape [num_key, bs, embed_dims]. Same in `nn.MultiheadAttention.forward`. Default None. If None, the `query` will be used. value (Tensor): The value tensor with same shape as `key`. Same in `nn.MultiheadAttention.forward`. Default None. If None, the `key` will be used. residual (Tensor): The tensor used for addition, with the same shape as `x`. Default None. If None, `x` will be used. query_pos (Tensor): The positional encoding for query, with the same shape as `x`. Default None. If not None, it will be added to `x` before forward function. key_pos (Tensor): The positional encoding for `key`, with the same shape as `key`. Default None. If not None, it will be added to `key` before forward function. If None, and `query_pos` has the same shape as `key`, then `query_pos` will be used for `key_pos`. attn_mask (Tensor): ByteTensor mask with shape [num_query, num_key]. Same in `nn.MultiheadAttention.forward`. Default None. key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. Same in `nn.MultiheadAttention.forward`. Default None. Returns: Tensor: forwarded results with shape [num_query, bs, embed_dims]. 
""" query = x if key is None: key = query if value is None: value = key if residual is None: residual = x if key_pos is None: if query_pos is not None and key is not None: if query_pos.shape == key.shape: key_pos = query_pos if query_pos is not None: query = query + query_pos if key_pos is not None: key = key + key_pos out = self.attn( query, key, value=value, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] return residual + self.dropout(out) </code></pre> <p>Below is a community (unofficial) PyTorch implementation that makes it easier to see what MultiheadAttention does. It uses three linear layers to project q, k and v to the target dimension, then reshapes them so that there is a head dimension, which lets the scaled dot-product attention run for all heads in parallel; finally the per-head results are concatenated.</p> <pre><code class="language-Python">import torch import torch.nn as nn import numpy as np import torch.nn.functional as F class MultiHeadAttention(nn.Module): ''' input: query --- [N, T_q, query_dim] key --- [N, T_k, key_dim] mask --- [N, T_k] T_q corresponds to H*W of the image output: out --- [N, T_q, embed_dim] scores -- [h, N, T_q, T_k] ''' def __init__(self, query_dim, key_dim, embed_dim, num_heads): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.key_dim = key_dim self.W_query = nn.Linear(in_features=query_dim, out_features=embed_dim, bias=False) self.W_key = nn.Linear(in_features=key_dim, out_features=embed_dim, bias=False) self.W_value = nn.Linear(in_features=key_dim, out_features=embed_dim, bias=False) def forward(self, query, key, mask=None): querys = self.W_query(query) # [N, T_q, embed_dim] keys = self.W_key(key) # [N, T_k, embed_dim] values = self.W_value(key) assert self.embed_dim % self.num_heads == 0 split_size = self.embed_dim // self.num_heads querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, embed_dim/h] keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, embed_dim/h] values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, embed_dim/h] ## score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(querys,
                              keys.transpose(2, 3))  # [h, N, T_q, T_k]
        # scale by sqrt(d_k) with d_k = embed_dim / h (split_size),
        # not by the raw input key_dim
        scores = scores / (split_size ** 0.5)

        ## mask
        if mask is not None:
            ## mask: [N, T_k] --&gt; [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads, 1, querys.shape[2], 1)
            # Positions where mask is 1 (True) are filled with -np.inf in scores; after
            # softmax their weights become 0, i.e. they take no part in the attention
            # ( np.exp(-np.inf) = 0 )
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, embed_dim/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, embed_dim]

        return out, scores


attention = MultiHeadAttention(256, 512, 256, 16)

## input
query = torch.randn(8, 2, 256)
key = torch.randn(8, 6, 512)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True]])

## output
out, scores = attention(query, key, mask)
print('out:', out.shape)  ## torch.Size([8, 2, 256])
print('scores:', scores.shape)  ## torch.Size([16, 8, 2, 6])
</code></pre> <h3 id="ffn">FFN</h3> <p>Not much to say here: it is a feed-forward MLP that passes the features through fully-connected layers and, by default, adds a residual connection.</p> <pre><code class="language-Python">class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be
            zeroed. Default 0.0.
        add_residual (bool, optional): Add residual connection.
            Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs &gt;= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.dropout = dropout
        self.activate = build_activation_layer(act_cfg)

        layers = nn.ModuleList()
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual

    def forward(self, x, residual=None):
        """Forward function for `FFN`."""
        # out has the same shape as the input x
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)
</code></pre> <h3 id="transformerdecoder">TransformerDecoder</h3> <p>When decoding, the decoder brings the query information (the object queries) into the computation; personally I think this part is even more important than the encoder.</p> <pre><code class="language-Python">class TransformerDecoder(nn.Module):
    """Implements the decoder in DETR transformer.

    Args:
        num_layers (int): The number of `TransformerDecoderLayer`.
        embed_dims (int): Same as `TransformerDecoderLayer`.
        num_heads (int): Same as `TransformerDecoderLayer`.
        feedforward_channels (int): Same as `TransformerDecoderLayer`.
        dropout (float): Same as `TransformerDecoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerDecoderLayer`.
        act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerDecoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerDecoderLayer`. Default 2.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 # the order is constrained by the asserts below: 6 entries
                 # drawn from 'selfattn', 'norm', 'multiheadattn' and 'ffn'
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 return_intermediate=False):
        super(TransformerDecoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.return_intermediate = return_intermediate
        self.layers = nn.ModuleList()
        # like the encoder, decoding goes step by step through a stack of layers
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self,
                x,  # an all-zero tensor with the same shape as query_embed
                memory,  # the features encoded by the encoder
                memory_pos=None,  # the positional encoding of the encoder features (pos_embed)
                query_pos=None,  # query_embed, the learned object queries
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`.
            memory (Tensor): Same in `TransformerDecoderLayer.forward`.
            memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            query_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            memory_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            memory_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.

        Returns:
            Tensor: Results with shape [num_query, bs, embed_dims].
        """
        intermediate = []
        for layer in self.layers:
            x = layer(x, memory, memory_pos, query_pos, memory_attn_mask,
                      target_attn_mask, memory_key_padding_mask,
                      target_key_padding_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(x))
        if self.norm is not None:
            x = self.norm(x)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(x)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return x.unsqueeze(0)
</code></pre> <h3 id="transformerdecoderlayer">TransformerDecoderLayer</h3> <pre><code class="language-Python">class TransformerDecoderLayer(nn.Module):
    """Implements one decoder layer in DETR transformer.

    Args:
        embed_dims (int): The feature dimension. Same as
            `TransformerEncoderLayer`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): The order for decoder layer. Valid examples are
            ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and
            ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn').
            Default the former.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerDecoderLayer, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'
        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.multihead_attn = MultiheadAttention(embed_dims, num_heads,
                                                 dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        self.norms = nn.ModuleList()
        # 3 norm layers in official DETR's TransformerDecoderLayer
        for _ in range(3):
            self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoderLayer`.

        Args:
            x (Tensor): Input query with shape [num_query, bs, embed_dims].
            memory (Tensor): Tensor got from `TransformerEncoder`, with shape
                [num_key, bs, embed_dims].
            memory_pos (Tensor): The positional encoding for `memory`. Default
                None. Same as `key_pos` in `MultiheadAttention.forward`.
            query_pos (Tensor): The positional encoding for `query`. Default
                None. Same as `query_pos` in `MultiheadAttention.forward`.
            memory_attn_mask (Tensor): ByteTensor mask for `memory`, with
                shape [num_query, num_key]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            target_attn_mask (Tensor): ByteTensor mask for `x`, with shape
                [num_query, num_query]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            memory_key_padding_mask (Tensor): ByteTensor for `memory`, with
                shape [bs, num_key]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.
            target_key_padding_mask (Tensor): ByteTensor for `x`, with shape
                [bs, num_query]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        norm_cnt = 0
        inp_residual = x
        # Follows the flow chart in the appendix of the DETR paper:
        # self-attention first, then cross-attention
        for layer in self.order:
            if layer == 'selfattn':
                query = key = value = x
                x = self.self_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos,
                    key_pos=query_pos,
                    attn_mask=target_attn_mask,
                    key_padding_mask=target_key_padding_mask)
                inp_residual = x
            elif layer == 'norm':
                x = self.norms[norm_cnt](x)
                norm_cnt += 1
            # This also calls MultiheadAttention, but q comes from the decoder
            # while k and v come from the encoder memory, so it is
            # cross-attention rather than self-attention
            elif layer == 'multiheadattn':
                query = x
                key = value = memory
                x = self.multihead_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos,
                    key_pos=memory_pos,
                    attn_mask=memory_attn_mask,
                    key_padding_mask=memory_key_padding_mask)
                inp_residual = x
            elif layer == 'ffn':
                x = self.ffn(x, inp_residual if self.pre_norm else None)
        return x
</code></pre> <h2 id="sinepositionalencoding">SinePositionalEncoding</h2> <p>This class generates positional encodings for the positions of the feature map that hold real pixels, to make up for the positional information a transformer would otherwise lose. The config uses the sine form by default, and to keep the values from growing too large the cumulative positions are normalized to (0, 1] and then multiplied by <code class="language-plaintext highlighter-rouge">scale</code> (default 2π).</p> <pre><code class="language-Python">positional_encoding=dict(
    type='SinePositionalEncoding', num_feats=128, normalize=True),
</code></pre> <p>See the comments in the code. <del>(I feel my understanding here is still not very deep.)</del></p> <blockquote> <p>The x- and y-direction positional encodings are concatenated at the end because the x-direction encoding alone cannot give every pixel a unique encoding; only after adding the y-direction encoding does each position get a distinct one. (For example, the first elements of the first and second rows have the same x-direction encoding, but different y-direction encodings.)</p> </blockquote> <pre><code class="language-Python">@POSITIONAL_ENCODING.register_module()
class SinePositionalEncoding(nn.Module):
    """Position encoding with sine and cosine functions.

    See `End-to-End Object Detection with Transformers
    &lt;https://arxiv.org/pdf/2005.12872&gt;`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Default 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Default False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Default 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Default 1e-6.
    """

    def __init__(self,
                 num_feats,
                 temperature=10000,
                 normalize=False,
                 scale=2 * math.pi,
                 eps=1e-6):
        super(SinePositionalEncoding, self).__init__()
        if normalize:
            assert isinstance(scale, (float, int)), 'when normalize is set,' \
                'scale should be provided and in float or int type, ' \
                f'found {type(scale)}'
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps

    def forward(self, mask):
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        # not_mask is 1 where there are real pixels and 0 on the padding
        # shape: [B, H, W]
        not_mask = ~mask
        # cumulative sum along y, (1,1,1) -&gt; (1,2,3)
        # (1,1,1,...)   # y_embed
        # (2,2,2,...)
        # (3,3,3,...)
        # (...)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        # cumulative sum along x, (1,1,1) -&gt; (1,2,3)
        # (1,2,3,...)   # x_embed
        # (1,2,3,...)
        # (1,2,3,...)
        # (...)
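        # Worked example (illustrative values, not from the original post;
        # it assumes, as in mmdet, that padding sits at the right/bottom):
        # a 2 x 3 feature map whose last column is padding (mask == True).
        #   not_mask = [[1, 1, 0],      y_embed = [[1, 1, 0],
        #               [1, 1, 0]]                 [2, 2, 0]]
        # The x-direction cumsum computed just below then gives
        #   x_embed  = [[1, 2, 2],
        #               [1, 2, 2]]
        # x_embed[:, :, -1:] is the number of valid pixels per row (2), so
        # with normalize=True the valid columns map to about [0.5, 1.0] * scale:
        # the valid region always spans roughly (0, scale], however much
        # padding the image has.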
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Normalize the encoding. y_embed[:, -1:, :] -&gt; shape [B, 1, W]
            # holds the largest (last) cumulative value along y; eps is added
            # to avoid dividing by 0.
            # The result is then multiplied by scale (default 2*pi), so the
            # final values lie between 0 and 2*pi.
            # Note that for a list, l[-1] and l[-1:] differ: the former returns
            # a single value, the latter a list holding one value (here the
            # slice keeps the dimension so that broadcasting works).
            y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
        # dim_t -&gt; [0, ..., 127]
        # num_feats is half of the feature channels (embed_dims = 256), because
        # the x and y encodings are concatenated and then added to the features
        # (inside attention, as query_pos / key_pos)
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
        # Follows the formula:
        # 2 * (dim_t // 2) -&gt; [0, 0, 2, 2, ..., 126, 126]
        # The sine on each channel has a different period: the larger the
        # channel index, the longer the period. The value computed here is
        # the reciprocal of w in sin(w * x), since the positions are divided
        # by dim_t below.
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
        # shape: [B, H, W, 128]
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Encoding of every position along x and y: sin on the even channels,
        # cos on the odd channels, interleaved
        # shape: [B, H, W, 128]
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        # shape: [B, H, W, 256] -&gt; [B, 256, H, W]
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
</code></pre> <blockquote> <p>Some easy-to-follow blog posts on Positional Encoding:</p> </blockquote> <blockquote> <p>https://zhuanlan.zhihu.com/p/166244505</p> </blockquote> <blockquote> <p>https://blog.csdn.net/weixin_42715977/article/details/122135262</p> </blockquote> <h2 id="ffn-1">FFN</h2> <p>After the vectors are decoded in TransformerHead, each query goes through an FFN to produce its final prediction. In DETR the FFN is instantiated as follows:</p> <pre><code class="language-Python">self.reg_ffn = FFN(
    self.embed_dims,  # 256, i.e. embed_dims, the channel count of the projected features
    self.embed_dims,
    self.num_fcs,  # 2
    self.act_cfg,  # ReLU
    dropout=0.0,
    add_residual=False)
</code></pre> <p>The FFN code is below: it applies two FC layers and returns the output.</p> <pre><code class="language-Python">class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be
            zeroed. Default 0.0.
        add_residual (bool, optional): Add residual connection.
            Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs &gt;= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.dropout = dropout
        self.activate = build_activation_layer(act_cfg)

        layers = nn.ModuleList()
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual

    def forward(self, x, residual=None):
        """Forward function for `FFN`."""
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'add_residual={self.add_residual})'
        return repr_str
</code></pre> <h2 id="hungarianassigner">HungarianAssigner</h2> <pre><code class="language-Python">import torch

from ..builder import BBOX_ASSIGNERS
from ..match_costs import build_match_cost
from ..transforms import bbox_cxcywh_to_xyxy
from .assign_result import AssignResult
from .base_assigner import BaseAssigner

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are weighted sum of three components:
    classification cost, regression L1 cost and regression iou cost. The
    targets don't include the no_object, so generally there are more
    predictions than targets. After the one-to-one matching, the un-matched
    are treated as backgrounds. Thus each query prediction will be assigned
    with `0` or a positive integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_weight (int | float, optional): The scale factor for classification
            cost. Default 1.0.
        bbox_weight (int | float, optional): The scale factor for regression
            L1 cost. Default 1.0.
        iou_weight (int | float, optional): The scale factor for regression
            iou cost. Default 1.0.
        iou_calculator (dict | optional): The config for the iou calculation.
            Default type `BboxOverlaps2D`.
        iou_mode (str | optional): "iou" (intersection over union), "iof"
            (intersection over foreground), or "giou" (generalized
            intersection over union). Default "giou".
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. The `assigned_gt_inds` with -1 means don't care,
        0 means negative sample, and positive number is the index (1-based)
        of assigned gt.
        The assignment is done in the following steps, the order matters.

        1. assign every prediction to -1
        2. compute the weighted costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
            img_meta (dict): Meta information for current image.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`. Default None.
            eps (int | float, optional): A value added to the denominator for
                numerical stability. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        # shape: [num_query]
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
        img_h, img_w, _ = img_meta['img_shape']
        factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)

        # 2. compute the weighted costs
        # classification and bboxcost.
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        # the L1 cost works on normalized coordinates
        normalize_gt_bboxes = gt_bboxes / factor
        reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
        # regression iou cost, defaultly giou is used in official DETR.
        # the IoU cost is measured with GIoU and needs absolute xyxy coordinates
        bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
        iou_cost = self.iou_cost(bboxes, gt_bboxes)
        # weighted sum of above three costs
        # shape: [num_query, num_gt]
        cost = cls_cost + reg_cost + iou_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')
        # Based on the cost, the Hungarian algorithm matches each gt with one
        # query so that the total cost of the matched queries is minimal
        # (see the demo in the next subsection)
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        # shape: [num_gt]
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        # shape: [num_gt]
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        # (every query is first assigned to the background class)
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        # matched queries get the index of their gt (1-based)
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        # the classification label each matched query is responsible for
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        # queries that are not matched are treated as background
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
</code></pre> <h3 id="demo">demo</h3> <p>Suppose the rows below are queries (num_query) and the columns are gts (num_gt). Each gt must be matched to exactly one query, and the matching should minimize the total cost. The optimal choice is the entries 1, 2 and 2 (total cost 5) at positions [0,1], [1,0] and [2,2], i.e. row_ind=[0,1,2] and col_ind=[1,0,2].</p> <pre><code class="language-Python">import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]])
row_ind, col_ind = linear_sum_assignment(cost)
print(col_ind)  # [1 0 2]
print(cost[row_ind, col_ind].sum())  # 5
</code></pre> <h2 id="costfunction">CostFunction</h2> <p>When DETR does bipartite matching, it assigns a label to each prediction according to the value of the cost functions. It mainly uses three cost functions, all in <em><code class="language-plaintext highlighter-rouge">mmdet/core/bbox/match_costs/match_cost.py</code></em>; they are introduced below.</p> <h3 id="classificationcost">ClassificationCost</h3> <p>See the comments in the code.</p> <pre><code class="language-Python">@MATCH_COST.register_module()
class ClassificationCost(object):
    """ClsSoftmaxCost.

    Args:
        weight (int | float, optional): loss_weight

    Examples:
        &gt;&gt;&gt; from mmdet.core.bbox.match_costs.match_cost import \
        ...     ClassificationCost
        &gt;&gt;&gt; import torch
        &gt;&gt;&gt; self = ClassificationCost()
        &gt;&gt;&gt; cls_pred = torch.rand(4, 3)
        &gt;&gt;&gt; gt_labels = torch.tensor([0, 1, 2])
        &gt;&gt;&gt; factor = torch.tensor([10, 8, 10, 8])
        &gt;&gt;&gt; self(cls_pred, gt_labels)
        tensor([[-0.3430, -0.3525, -0.3045],
                [-0.3077, -0.2931, -0.3992],
                [-0.3664, -0.3455, -0.2881],
                [-0.3343, -0.2701, -0.3956]])
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, cls_pred, gt_labels):
        """
        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight
        """
        # Following the official DETR repo, contrary to the loss that
        # NLL is used, we approximate it in 1 - cls_score[gt_label].
        # The 1 is a constant that doesn't change the matching,
        # so it can be omitted.
        # shape: [num_query, num_class]
        cls_score = cls_pred.softmax(-1)
        # shape: [num_query, num_gt]
        # cost of every prediction for every gt_label: the higher the score,
        # the smaller the cost
        cls_cost = -cls_score[:, gt_labels]
        return cls_cost * self.weight
</code></pre> <h3 id="ioucost">IoUCost</h3> <p>See the comments in the code.</p> <pre><code class="language-Python">@MATCH_COST.register_module()
class IoUCost(object):
    """IoUCost.

    Args:
        iou_mode (str, optional): iou mode such as 'iou' | 'giou'
        weight (int | float, optional): loss weight

    Examples:
        &gt;&gt;&gt; from mmdet.core.bbox.match_costs.match_cost import IoUCost
        &gt;&gt;&gt; import torch
        &gt;&gt;&gt; self = IoUCost()
        &gt;&gt;&gt; bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
        &gt;&gt;&gt; gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        &gt;&gt;&gt; self(bboxes, gt_bboxes)
        tensor([[-0.1250,  0.1667],
                [ 0.1667, -0.5000]])
    """

    def __init__(self, iou_mode='giou', weight=1.):
        self.weight = weight
        self.iou_mode = iou_mode

    def __call__(self, bboxes, gt_bboxes):
        """
        Args:
            bboxes (Tensor): Predicted boxes with unnormalized coordinates
                (x1, y1, x2, y2). Shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].

        Returns:
            torch.Tensor: iou_cost value with weight
        """
        # overlaps: [num_query, num_gt]
        # pairwise IoU between every prediction and every gt (is_aligned=False)
        overlaps = bbox_overlaps(
            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
        # The 1 is a constant that doesn't change the matching, so omitted.
        # the larger the IoU, the smaller the cost
        iou_cost = -overlaps
        return iou_cost * self.weight
</code></pre> <h3 id="bboxl1cost">BBoxL1Cost</h3> <p>See the comments in the code.</p> <pre><code class="language-Python">@MATCH_COST.register_module()
class BBoxL1Cost(object):
    """BBoxL1Cost.

    Args:
        weight (int | float, optional): loss_weight
        box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN

    Examples:
        &gt;&gt;&gt; from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
        &gt;&gt;&gt; import torch
        &gt;&gt;&gt; self = BBoxL1Cost()
        &gt;&gt;&gt; bbox_pred = torch.rand(1, 4)
        &gt;&gt;&gt; gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        &gt;&gt;&gt; factor = torch.tensor([10, 8, 10, 8])
        &gt;&gt;&gt; self(bbox_pred, gt_bboxes, factor)
        tensor([[1.6172, 1.6422]])
    """

    def __init__(self, weight=1., box_format='xyxy'):
        self.weight = weight
        assert box_format in ['xyxy', 'xywh']
        self.box_format = box_format

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with normalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].

        Returns:
            torch.Tensor: bbox_cost value with weight
        """
        # note: these are normalized coordinates in the range 0-1
        if self.box_format == 'xywh':
            gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
        elif self.box_format == 'xyxy':
            bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
        # L1 distance from every predicted box to every gt box: the larger
        # the distance, the larger the cost
        # shape: [num_query, num_gt]
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight
</code></pre> <h2 id="reference">reference</h2> <p>https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html#scipy.optimize.linear_sum_assignment</p> <p>https://zhuanlan.zhihu.com/p/348060767</p> <p>https://zhuanlan.zhihu.com/p/345985277</p> Wed, 30 Mar 2022 00:00:00 +0000 https://yarkable.github.io/2022/03/30/mmdetection%E4%B9%8BDETR%E6%B3%A8%E9%87%8A%E8%AF%A6%E8%A7%A3/ https://yarkable.github.io/2022/03/30/mmdetection%E4%B9%8BDETR%E6%B3%A8%E9%87%8A%E8%AF%A6%E8%A7%A3/ linux object detection deep learning mmdetection