From 106a0f1be88eb2a23df015e35d9b6c2e1ba75419 Mon Sep 17 00:00:00 2001 From: Igor Ryabchikov Date: Sun, 28 Nov 2021 21:33:57 +0300 Subject: [PATCH] fairmot tracker corrections --- docker-compose.yml | 6 +- test/tracking_test/docker-compose.yml | 3 +- tracking/docker-build-context/Dockerfile | 15 +- .../byte_track/requirements.txt | 4 +- .../fairmot/DCNv2/setup.py | 2 +- .../docker-build-context/fairmot/setup.py | 54 +---- .../fairmot/src/lib/tracker/multitracker.py | 97 ++++++--- tracking/run.py | 184 ++++++++++++++---- 8 files changed, 251 insertions(+), 114 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index e98e494..126c0d7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -120,8 +120,10 @@ services: runtime: nvidia environment: - NVIDIA_VISIBLE_DEVICES=${DEVBEH_TRACKING_NVIDIA_VISIBLE_DEVICES:-0} +# https://drive.google.com/u/0/uc?export=download&confirm=2ySX&id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U - bytrack command: 'python run.py tracking --log_dir /opt/tracking/logs - --weights_url ${DEVBEH_POSE3D_WEIGHTS_URL:-https://drive.google.com/u/0/uc?export=download&confirm=2ySX&id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U} + --fairmot_weights_url ${DEVBEH_FAIRMOT_WEIGHTS_URL:-https://drive.google.com/u/0/uc?export=download&confirm=VHcw&id=1iqRQjsG9BawIl8SlFomMg5iwkb6nqSpi} + --use_fairmot --bootstrap_servers ${DEVBEH_KAFKA_BOOTSTRAP_SERVERS:-localhost:9095} --username "${DEVBEH_KAFKA_USERNAME:-}" --password "${DEVBEH_KAFKA_PASSWORD:-}" @@ -135,4 +137,4 @@ volumes: zookeeper_data: driver: local kafka_data: - driver: local + driver: local \ No newline at end of file diff --git a/test/tracking_test/docker-compose.yml b/test/tracking_test/docker-compose.yml index ffad694..e191d27 100644 --- a/test/tracking_test/docker-compose.yml +++ b/test/tracking_test/docker-compose.yml @@ -55,7 +55,8 @@ services: environment: - NVIDIA_VISIBLE_DEVICES=${DEVBEH_TRACKING_NVIDIA_VISIBLE_DEVICES:-0} command: 'python run.py tracking --log_dir /opt/tracking/logs - --weights_url ${DEVBEH_TRACKING_WEIGHTS_URL:-https://drive.google.com/u/0/uc?export=download&confirm=2ySX&id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U} + --fairmot_weights_url ${DEVBEH_FAIRMOT_WEIGHTS_URL:-https://drive.google.com/u/0/uc?export=download&confirm=VHcw&id=1iqRQjsG9BawIl8SlFomMg5iwkb6nqSpi} + --use_fairmot --frames_commit_latency 100 --bootstrap_servers kafka:29097' networks: diff --git a/tracking/docker-build-context/Dockerfile b/tracking/docker-build-context/Dockerfile index 4a8a780..5e5473e 100644 --- a/tracking/docker-build-context/Dockerfile +++ b/tracking/docker-build-context/Dockerfile @@ -9,12 +9,25 @@ RUN apt install -y libcurl4-openssl-dev libssl-dev libcairo2-dev libgirepository # cv2 RUN apt install -y ffmpeg libsm6 libxext6 libxrender-dev +RUN pip install torch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 + # Install byte_track COPY byte_track ./byte_track WORKDIR /opt/tracking/install/byte_track RUN pip install -r requirements.txt RUN python setup.py develop -RUN pip install cython; pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' +RUN pip install cython +RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' + +# Install fairmot +WORKDIR /opt/tracking/install +COPY fairmot ./fairmot +WORKDIR /opt/tracking/install/fairmot +RUN pip install -r requirements.txt +RUN python setup.py build develop + +WORKDIR /opt/tracking/install/fairmot/DCNv2 +RUN ./make.sh WORKDIR /opt/tracking/install COPY additional-requirements.txt ./ diff --git a/tracking/docker-build-context/byte_track/requirements.txt b/tracking/docker-build-context/byte_track/requirements.txt index 14c7ac8..dc778cd 100644 --- a/tracking/docker-build-context/byte_track/requirements.txt +++ b/tracking/docker-build-context/byte_track/requirements.txt @@ -1,11 +1,11 @@ # TODO: Update with exact module version numpy -torch>=1.8 +# torch>=1.8 opencv_python loguru scikit-image tqdm -torchvision>=0.10.0 +# torchvision>=0.10.0 Pillow thop ninja diff --git a/tracking/docker-build-context/fairmot/DCNv2/setup.py b/tracking/docker-build-context/fairmot/DCNv2/setup.py index 887cce2..d3d029d 100644 --- a/tracking/docker-build-context/fairmot/DCNv2/setup.py +++ b/tracking/docker-build-context/fairmot/DCNv2/setup.py @@ -35,7 +35,7 @@ def get_extensions(): "-D__CUDA_NO_HALF2_OPERATORS__", ] else: - # raise NotImplementedError('Cuda is not available') + raise NotImplementedError('Cuda is not available') pass sources = [os.path.join(extensions_dir, s) for s in sources] diff --git a/tracking/docker-build-context/fairmot/setup.py b/tracking/docker-build-context/fairmot/setup.py index c33ef17..6a3f38b 100644 --- a/tracking/docker-build-context/fairmot/setup.py +++ b/tracking/docker-build-context/fairmot/setup.py @@ -1,10 +1,11 @@ #!/usr/bin/python from __future__ import print_function + import os -import sys -import re import os.path as op +import sys + from setuptools import find_packages, setup # change directory to this module path @@ -23,53 +24,16 @@ def readme(fname): return open(op.join(script_dir, fname)).read() -def find_version(fname): - version_file = readme(fname) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - version_file, re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") - - setup( - name="graphormer", - version=find_version("src/__init__.py"), - description="graphormer", + name="fairmot", + version="0.1", + description="fairmot", long_description=readme('README.md'), - packages=find_packages(), + package_dir = {'': 'src/lib'}, + # packages=find_packages(), classifiers=[ 'Intended Audience :: Developers', "Programming Language :: Python", 'Topic :: Software Development', ] -) - - -# https://dzone.com/articles/executable-package-pip-install -# python setup.py bdist_wheel -# python -m pip install dist/dokr-0.1-py3-none-any.whl - -# import setuptools -# -# with open("README.md", "r") as fh: -# long_description = fh.read() -# -# -# setuptools.setup( -# name='dokr', -# version='0.1', -# scripts=['dokr'] , -# author="Deepak Kumar", -# author_email="deepak.kumar.iet@gmail.com", -# description="A Docker and AWS utility package", -# long_description=long_description, -# long_description_content_type="text/markdown", -# url="https://github.com/javatechy/dokr", -# packages=setuptools.find_packages(), -# classifiers=[ -# "Programming Language :: Python :: 3", -# "License :: OSI Approved :: MIT License", -# "Operating System :: OS Independent", -# ], -# ) \ No newline at end of file +) \ No newline at end of file diff --git a/tracking/docker-build-context/fairmot/src/lib/tracker/multitracker.py b/tracking/docker-build-context/fairmot/src/lib/tracker/multitracker.py index 6dd8b53..26ee99c 100644 --- a/tracking/docker-build-context/fairmot/src/lib/tracker/multitracker.py +++ b/tracking/docker-build-context/fairmot/src/lib/tracker/multitracker.py @@ -1,27 +1,24 @@ +import numpy as np +from collections import deque import itertools import os import os.path as osp import time -from collections import deque - -import cv2 -import numpy as np import torch +import cv2 import torch.nn.functional as F -from models import * -from models.decode import mot_decode + from models.model import create_model, load_model -from models.utils import _tranpose_and_gather_feat -from tracking_utils.kalman_filter import KalmanFilter -from tracking_utils.log import logger +from models.decode import mot_decode from tracking_utils.utils import * -from utils.image import get_affine_transform -from utils.post_process import ctdet_post_process - +from tracking_utils.log import logger +from tracking_utils.kalman_filter import KalmanFilter +from models import * from tracker import matching - from .basetrack import BaseTrack, TrackState - +from utils.post_process import ctdet_post_process +from utils.image import get_affine_transform +from models.utils import _tranpose_and_gather_feat class STrack(BaseTrack): shared_kalman = KalmanFilter() @@ -34,6 +31,7 @@ class STrack(BaseTrack): self.is_activated = False self.score = score + self.score_list = [] self.tracklet_len = 0 self.smooth_feat = None @@ -83,6 +81,7 @@ class STrack(BaseTrack): #self.is_activated = True self.frame_id = frame_id self.start_frame = frame_id + self.score_list.append(self.score) def re_activate(self, new_track, frame_id, new_id=False): self.mean, self.covariance = self.kalman_filter.update( @@ -96,6 +95,8 @@ class STrack(BaseTrack): self.frame_id = frame_id if new_id: self.track_id = self.next_id() + self.score = new_track.score + self.score_list.append(self.score) def update(self, new_track, frame_id, update_feature=True): """ @@ -115,10 +116,12 @@ class STrack(BaseTrack): self.is_activated = True self.score = new_track.score + self.score_list.append(self.score) if update_feature: self.update_features(new_track.curr_feat) @property + # @jit(nopython=True) def tlwh(self): """Get current position in bounding box format `(top left x, top left y, width, height)`. @@ -131,6 +134,7 @@ class STrack(BaseTrack): return ret @property + # @jit(nopython=True) def tlbr(self): """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., `(top left, bottom right)`. @@ -140,6 +144,7 @@ class STrack(BaseTrack): return ret @staticmethod + # @jit(nopython=True) def tlwh_to_xyah(tlwh): """Convert bounding box to format `(center x, center y, aspect ratio, height)`, where the aspect ratio is `width / height`. @@ -153,12 +158,14 @@ class STrack(BaseTrack): return self.tlwh_to_xyah(self.tlwh) @staticmethod + # @jit(nopython=True) def tlbr_to_tlwh(tlbr): ret = np.asarray(tlbr).copy() ret[2:] -= ret[:2] return ret @staticmethod + # @jit(nopython=True) def tlwh_to_tlbr(tlwh): ret = np.asarray(tlwh).copy() ret[2:] += ret[:2] @@ -169,24 +176,17 @@ class STrack(BaseTrack): class JDETracker(object): - def __init__(self, opt, frame_rate=30): + def __init__(self, opt, model, frame_rate=30): self.opt = opt - if opt.gpus[0] >= 0: - opt.device = torch.device('cuda') - else: - opt.device = torch.device('cpu') - print('Creating model...') - self.model = create_model(opt.arch, opt.heads, opt.head_conv) - self.model = load_model(self.model, opt.load_model) - self.model = self.model.to(opt.device) - self.model.eval() + self.model = model self.tracked_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack] self.frame_id = 0 - self.det_thresh = opt.conf_thres + #self.det_thresh = opt.conf_thres + self.det_thresh = opt.conf_thres + 0.1 self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer) self.max_time_lost = self.buffer_size self.max_per_image = opt.K @@ -256,6 +256,12 @@ class JDETracker(object): dets = self.merge_outputs([dets])[1] remain_inds = dets[:, 4] > self.opt.conf_thres + inds_low = dets[:, 4] > 0.2 + #inds_low = dets[:, 4] > self.opt.conf_thres + inds_high = dets[:, 4] < self.opt.conf_thres + inds_second = np.logical_and(inds_low, inds_high) + dets_second = dets[inds_second] + id_feature_second = id_feature[inds_second] dets = dets[remain_inds] id_feature = id_feature[remain_inds] @@ -290,13 +296,12 @@ class JDETracker(object): ''' Step 2: First association, with embedding''' strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) # Predict the current location with KF - #for strack in strack_pool: - #strack.predict() STrack.multi_predict(strack_pool) dists = matching.embedding_distance(strack_pool, detections) + #dists = matching.fuse_iou(dists, strack_pool, detections) #dists = matching.iou_distance(strack_pool, detections) dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections) - matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.4) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.opt.match_thres) for itracked, idet in matches: track = strack_pool[itracked] @@ -324,8 +329,29 @@ class JDETracker(object): track.re_activate(det, self.frame_id, new_id=False) refind_stracks.append(track) + # association the untrack to the low score detections + if len(dets_second) > 0: + '''Detections''' + detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for + (tlbrs, f) in zip(dets_second[:, :5], id_feature_second)] + else: + detections_second = [] + second_tracked_stracks = [r_tracked_stracks[i] for i in u_track if r_tracked_stracks[i].state == TrackState.Tracked] + dists = matching.iou_distance(second_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4) + for itracked, idet in matches: + track = second_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + for it in u_track: - track = r_tracked_stracks[it] + #track = r_tracked_stracks[it] + track = second_tracked_stracks[it] if not track.state == TrackState.Lost: track.mark_lost() lost_stracks.append(track) @@ -365,6 +391,7 @@ class JDETracker(object): self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) self.removed_stracks.extend(removed_stracks) self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + #self.tracked_stracks = remove_fp_stracks(self.tracked_stracks) # get scores of lost tracks output_stracks = [track for track in self.tracked_stracks if track.is_activated] @@ -416,3 +443,15 @@ def remove_duplicate_stracks(stracksa, stracksb): resa = [t for i, t in enumerate(stracksa) if not i in dupa] resb = [t for i, t in enumerate(stracksb) if not i in dupb] return resa, resb + + +def remove_fp_stracks(stracksa, n_frame=10): + remain = [] + for t in stracksa: + score_5 = t.score_list[-n_frame:] + score_5 = np.array(score_5, dtype=np.float32) + index = score_5 < 0.45 + num = np.sum(index) + if num < n_frame: + remain.append(t) + return remain \ No newline at end of file diff --git a/tracking/run.py b/tracking/run.py index de1d6b6..e4ab30b 100644 --- a/tracking/run.py +++ b/tracking/run.py @@ -7,6 +7,9 @@ from unittest import mock import cv2 import numpy as np +import torch +from models.model import create_model, load_model +from tracker.multitracker import JDETracker, STrack as FairmotSTrack from yolox.tracker.byte_tracker import BYTETracker, STrack from common import download_file, logging_levels @@ -14,30 +17,34 @@ from common.communication.kafka_common import KafkaRequestProcessor from common.communication.messages_pb2 import PerceptionRequest, PerceptionResponse, TrackingEntry, DetectionEntry CURRENT_DIR_PATH = os.path.dirname(os.path.abspath(__file__)) +PERSON_CLASS = 1 # noinspection DuplicatedCode def process_tracking_requests(requests: Iterable[PerceptionRequest], results: Iterable[PerceptionResponse], - tracker_factory, trackers): + tracker_factory, fairmot_tracker_factory, use_fairmot, trackers): for req, res in zip(requests, results): if req.finished: - track(req.manager_id, req.video_id, tracker_factory, trackers, None, None) + track(req.manager_id, req.video_id, tracker_factory, fairmot_tracker_factory, use_fairmot, trackers, None, + None) else: image = np.fromstring(req.tracking.image, np.uint8) image = cv2.imdecode(image, cv2.IMREAD_ANYCOLOR) logging.info('Processing image of shape %s', image.shape) - tracking_entries = track(req.manager_id, req.video_id, tracker_factory, trackers, image, - req.tracking.entries) + tracking_entries = track(req.manager_id, req.video_id, tracker_factory, fairmot_tracker_factory, + use_fairmot, trackers, image, req.tracking.entries) for tracking_entry in tracking_entries: res.tracking.entries.append(tracking_entry) -def track(manager_id, video_id, tracker_factory, trackers, image, detections): +def track(manager_id, video_id, tracker_factory, fairmot_tracker_factory, use_fairmot, trackers, image, detections): """ Обрабатывает очередной кадр видеозаписи, сопоставляя объекты с прошлым кадром. :param manager_id: идентификатор менеджера, передающего видепоток :param video_id: идентификатор видеопотока :param tracker_factory: фабрика для создания трекеров + :param fairmot_tracker_factory: фабрика для создания трекеров fairmot или None + :param use_fairmot: следует ли использовать fairmot для отслеживания людей :param trackers: хранит трекеры для видеозаписей {(manager_id, video_id): tracker} :param image: изображение или None, если видеопоток завершился :param detections: набор DetectionEntry @@ -51,34 +58,89 @@ def track(manager_id, video_id, tracker_factory, trackers, image, detections): return [] if tracker_key not in trackers.keys(): - trackers[tracker_key] = tracker_factory() + byte_tracker = tracker_factory() + fairmot_tracker = fairmot_tracker_factory() if use_fairmot else None + trackers[tracker_key] = (byte_tracker, fairmot_tracker) + + detection_info_list = [] + detection_classes_list = [] + fairmot_detection_info_list = [] - detection_info = np.zeros((len(detections), 5), dtype=float) - detection_classes = np.zeros((len(detections))) for i, detection in enumerate(detections): - detection_info[i, 0] = detection.box_top_left_x - detection_info[i, 1] = detection.box_top_left_y - detection_info[i, 2] = detection.box_bottom_right_x - detection_info[i, 3] = detection.box_bottom_right_y - detection_info[i, 4] = detection.score - detection_classes[i] = detection.class_id - - tracker: BYTETracker = trackers[tracker_key] + if use_fairmot and detection.class_id == PERSON_CLASS: + fairmot_detection_info_list.append(np.expand_dims(np.array([detection.box_top_left_x, + detection.box_top_left_y, + detection.box_bottom_right_x, + detection.box_bottom_right_y, + detection.score], dtype=float), axis=0)) + else: + detection_classes_list.append(detection.class_id) + detection_info_list.append(np.expand_dims(np.array([detection.box_top_left_x, + detection.box_top_left_y, + detection.box_bottom_right_x, + detection.box_bottom_right_y, + detection.score], dtype=float), axis=0)) + + fairmot_detection_info = np.concatenate(fairmot_detection_info_list) if len(fairmot_detection_info_list) > 0 else np.zeros((0, 5), dtype=float) + detection_info = np.concatenate(detection_info_list) if len(detection_info_list) > 0 else np.zeros((0, 5), dtype=float) + detection_classes = np.array(detection_classes_list) + + tracker: BYTETracker = trackers[tracker_key][0] height_width = (image.shape[0], image.shape[1]) tracks: Iterable[STrack] = tracker.update(detection_info, height_width, height_width, classes=detection_classes) result = [] for t in tracks: tlbr = t.tlbr - entry = TrackingEntry(id=t.track_id) + entry = TrackingEntry(id=correct_track_id(t.track_id, use_fairmot)) entry.detection.CopyFrom(DetectionEntry(box_top_left_x=int(tlbr[0]), box_top_left_y=int(tlbr[1]), box_bottom_right_x=int(tlbr[2]), box_bottom_right_y=int(tlbr[3]), class_id=int(t.class_id), score=t.score)) result.append(entry) + + if use_fairmot: + img0 = image + img, _, _, _ = letterbox(img0) + + # Normalize RGB + img = img[:, :, ::-1].transpose(2, 0, 1) + img = np.ascontiguousarray(img, dtype=np.float32) + img /= 255.0 + + blob = torch.from_numpy(img).cuda().unsqueeze(0) + + fairmot_tracker: JDETracker = trackers[tracker_key][1] + fairmot_tracks: Iterable[FairmotSTrack] = fairmot_tracker.update(blob, img0) + + for t in fairmot_tracks: + tlbr = t.tlbr + entry = TrackingEntry(id=t.track_id * 2) + entry.detection.CopyFrom(DetectionEntry(box_top_left_x=int(tlbr[0]), box_top_left_y=int(tlbr[1]), + box_bottom_right_x=int(tlbr[2]), box_bottom_right_y=int(tlbr[3]), + class_id=PERSON_CLASS, score=t.score)) + result.append(entry) + return result -def get_tracker_factory(weights_path, fps): +def letterbox(img, height=608, width=1088, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded rectangular + shape = img.shape[:2] # shape = [height, width] + ratio = min(float(height) / shape[0], float(width) / shape[1]) + new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # new_shape = [width, height] + dw = (width - new_shape[0]) / 2 # width padding + dh = (height - new_shape[1]) / 2 # height padding + top, bottom = round(dh - 0.1), round(dh + 0.1) + left, right = round(dw - 0.1), round(dw + 0.1) + img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded rectangular + return img, ratio, dw, dh + + +def correct_track_id(id, use_fairmot): + return id * 2 + 1 if use_fairmot else id + + +def get_tracker_factory(fps): return lambda: create_tracker(fps) @@ -92,18 +154,70 @@ def create_tracker(fps): return BYTETracker(args, fps) +def get_fairmot_tracker_factory(fairmot_weights_path, fps): + device = torch.device('cuda') + ltrb = True + reid_dim = 128 + num_classes = 1 + arch = 'dla_34' + heads = {'hm': num_classes, + 'wh': 2 if not ltrb else 4, + 'id': reid_dim} + heads.update({'reg': 2}) + head_conv = 256 + + model = create_model(arch, heads, head_conv) + model = load_model(model, fairmot_weights_path) + model = model.to(device) + model.eval() + + return lambda: create_fairmot_tracker(model, fps) + + +def create_fairmot_tracker(model, fps): + opt = mock.Mock() + # mean и std не используются + opt.mean = [0.408, 0.447, 0.470] + opt.std = [0.289, 0.274, 0.278] + opt.img_size = (1088, 608) # input w,h + opt.K = 500 + # 0.4 - дефолт. Попробовать + opt.conf_thres = 0.6 + # todo: попробовать другие параметры (30 - default) + opt.track_buffer = 150 + opt.reid_dim = 128 + opt.num_classes = 1 + opt.nID = 14455 + # Во сколько раз высота и ширина выходного набора свойств меньше входного изображения + opt.down_ratio = 4 + # regress left, top, right, bottom of bbox. Т.е. это повышение точности детектировани рамок при помощи отдельной head. + # В нашем случае не важно, так как мы уже предоставляем результаты детектирования + opt.ltrb = True + opt.reg_offset = True + opt.input_w = opt.img_size[0] + opt.input_h = opt.img_size[1] + opt.output_h = opt.input_h // opt.down_ratio + opt.output_w = opt.input_w // opt.down_ratio + opt.input_res = max(opt.input_h, opt.input_w) + opt.output_res = max(opt.output_h, opt.output_w) + opt.match_thres = 0.4 + + return JDETracker(opt, model, frame_rate=fps) + + # noinspection DuplicatedCode def run(parser: argparse.ArgumentParser): parser.add_argument('--name', default='tracking', help='instance name used as part of the client id') parser.add_argument('--log_dir', help='logs dir path', required=True) parser.add_argument('--logging_level', dest='logging_level', choices=logging_levels.keys(), default='INFO', help='logging level. One of ' + str(logging_levels.keys())) - parser.add_argument('--weights_path', default='',# todo + parser.add_argument('--fairmot_weights_path', default='',# todo help='path to the model weights file. Default: ' - 'PROJECT_DIR_PATH/tracking/models/bytetrack_x_mot20.tar') - parser.add_argument('--weights_url', default='', help='url to the model weights file that ' - 'should be downloaded and placed to the ' - 'weights_path if it does not already exist') + 'PROJECT_DIR_PATH/tracking/models/fairmot_dla34.pth') + parser.add_argument('--fairmot_weights_url', default='', help='url to the model weights file that ' + 'should be downloaded and placed to the ' + 'fairmot_weights_path if it does not already exist') + parser.add_argument('--use_fairmot', action='store_true', help='enables application of fairmot for people tracking') parser.add_argument('--bootstrap_servers', type=str, required=True, help='comma separated kafka bootstrap servers. Example: kafka1.local:9092,kafka2.local:9092') parser.add_argument('--username', type=str, default='', help='kafka SASL username or empty str') @@ -121,19 +235,21 @@ def run(parser: argparse.ArgumentParser): logging.basicConfig(format="%(asctime)s: %(levelname)s - %(message)s", filename=os.path.join(args.log_dir, 'tracking.log'), level=logging_levels[args.logging_level]) - weights_path = args.weights_path if args.weights_path else \ - os.path.join(CURRENT_DIR_PATH, 'models/bytetrack_x_mot20.tar') + fairmot_weights_path = args.fairmot_weights_path if args.fairmot_weights_path else \ + os.path.join(CURRENT_DIR_PATH, 'models/fairmot_dla34.pth') - if os.path.exists(weights_path) and not os.path.isfile(weights_path): - raise ValueError("weights_path '{}' must denote a file or mustn't exist".format(weights_path)) + if os.path.exists(fairmot_weights_path) and not os.path.isfile(fairmot_weights_path): + raise ValueError("fairmot_weights_path '{}' must denote a file or mustn't exist".format(fairmot_weights_path)) - if not os.path.exists(weights_path): - logging.info("Downloading model weights from '{}' to '{}'".format(args.weights_url, weights_path)) - download_file(args.weights_url, weights_path) - logging.info('Model weights download finished') + if not os.path.exists(fairmot_weights_path): + logging.info("Downloading fairmot model weights from '{}' to '{}'".format(args.fairmot_weights_url, + fairmot_weights_path)) + download_file(args.fairmot_weights_url, fairmot_weights_path) + logging.info('Fairmot model weights download finished') - logging.info("Initializing tracking model") - tracker_factory = get_tracker_factory(weights_path, args.fps) + logging.info("Initializing tracking models") + tracker_factory = get_tracker_factory(args.fps) + fairmot_tracker_factory = get_fairmot_tracker_factory(fairmot_weights_path, args.fps) if args.use_fairmot else None logging.info("Model was initialized") trackers = {} @@ -147,6 +263,8 @@ def run(parser: argparse.ArgumentParser): stream_max_idle_time_ms=600_000, frames_commit_latency=args.frames_commit_latency, request_handler=lambda req, res: process_tracking_requests(req, res, tracker_factory, + fairmot_tracker_factory, + args.use_fairmot, trackers)) logging.info("Starting tracking service '%s'", args.name) -- GitLab