commit c8504c8e3458ad3ecc0242f5670896c15a8bf3fd
Author: zhurui <274461951@qq.com>
Date: Thu Jul 4 17:00:21 2024 +0800
first commit
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7bc69e3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,22 @@
+work_dirs/
+predicts/
+output/
+data/
+data
+
+__pycache__/
+*/*.un~
+.*.swp
+
+
+
+*.egg-info/
+*.egg
+
+output.txt
+.vscode/*
+.DS_Store
+tmp.*
+*.pt
+*.pth
+*.un~
diff --git a/INSTALL.md b/INSTALL.md
new file mode 100644
index 0000000..63025f8
--- /dev/null
+++ b/INSTALL.md
@@ -0,0 +1,74 @@
+
+# Install
+
+1. Clone the RESA repository
+ ```
+ git clone https://github.com/zjulearning/resa.git
+ ```
+    We refer to this directory as `$RESA_ROOT`.
+
+2. Create a conda virtual environment and activate it (conda is optional)
+
+ ```Shell
+ conda create -n resa python=3.8 -y
+ conda activate resa
+ ```
+
+3. Install dependencies
+
+ ```Shell
+    # Install PyTorch first; the cudatoolkit version should match the CUDA version on your system. (You can also use pip to install PyTorch and torchvision.)
+ conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
+
+ # Or you can install via pip
+ pip install torch torchvision
+
+ # Install python packages
+ pip install -r requirements.txt
+ ```
+
+4. Data preparation
+
+    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3). Then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`, and create links to them under the `data` directory.
+
+ ```Shell
+    cd $RESA_ROOT
+    mkdir -p data
+    ln -s $CULANEROOT data/CULane
+ ln -s $TUSIMPLEROOT data/tusimple
+ ```
+
+    For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the JSON annotations.
+
+ ```Shell
+ python scripts/convert_tusimple.py --root $TUSIMPLEROOT
+ # this will generate segmentations and two list files: train_gt.txt and test.txt
+ ```
+
+    For CULane, you should have a directory structure like this:
+ ```
+ $RESA_ROOT/data/CULane/driver_xx_xxframe # data folders x6
+ $RESA_ROOT/data/CULane/laneseg_label_w16 # lane segmentation labels
+ $RESA_ROOT/data/CULane/list # data lists
+ ```
+
+    For Tusimple, you should have a directory structure like this:
+ ```
+ $RESA_ROOT/data/tusimple/clips # data folders
+    $RESA_ROOT/data/tusimple/label_data_xxxx.json # label json files x4
+ $RESA_ROOT/data/tusimple/test_tasks_0627.json # test tasks json file
+ $RESA_ROOT/data/tusimple/test_label.json # test label json file
+ ```
+
+5. Install CULane evaluation tools.
+
+    This tool requires OpenCV C++. Please follow [this guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or simply install it with `sudo apt-get install libopencv-dev`.
+
+
+ Then compile the evaluation tool of CULane.
+ ```Shell
+ cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
+ make
+ cd -
+ ```
+
+    Note that the default `opencv` version is 3. If you use OpenCV 2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8df642d
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021 Tu Zheng
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c2063f6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,148 @@
+# RESA
+PyTorch implementation of the paper "[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719)".
+
+Our paper has been accepted at AAAI 2021.
+
+**News**: RESA is also released in [LaneDet](https://github.com/Turoad/lanedet); we recommend giving LaneDet a try as well.
+
+## Introduction
+![intro](intro.png "intro")
+- RESA shifts sliced feature maps recurrently in vertical and horizontal directions,
+enabling each pixel to gather global information (see the sketch below).
+- RESA achieves SOTA results on the CULane and Tusimple datasets.
+
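+For intuition, here is a minimal, self-contained sketch of the shift-and-aggregate idea. It is only an illustration (the actual module lives in `models/resa.py`); the class name, the shared convolution weights, and the exact shift schedule are simplifying assumptions.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FeatureShiftSketch(nn.Module):
+    """Toy RESA-style aggregator: slices of the feature map are shifted
+    vertically/horizontally, convolved, and added back, so each pixel can
+    gather information from the whole map after a few iterations."""
+    def __init__(self, channels=128, kernel=9, iters=4, alpha=2.0):
+        super().__init__()
+        self.iters, self.alpha = iters, alpha
+        # vertical passes convolve along the width, horizontal passes along the height
+        self.conv_v = nn.Conv2d(channels, channels, (1, kernel), padding=(0, kernel // 2), bias=False)
+        self.conv_h = nn.Conv2d(channels, channels, (kernel, 1), padding=(kernel // 2, 0), bias=False)
+
+    def forward(self, x):
+        _, _, h, w = x.shape
+        for i in range(self.iters):
+            # shift distance grows each iteration: h/16, h/8, h/4, h/2 for iters=4
+            sh = h // 2 ** (self.iters - i)
+            sw = w // 2 ** (self.iters - i)
+            for shift, dim, conv in [(sh, 2, self.conv_v), (-sh, 2, self.conv_v),
+                                     (sw, 3, self.conv_h), (-sw, 3, self.conv_h)]:
+                x = x + self.alpha * F.relu(conv(torch.roll(x, shift, dims=dim)))
+        return x
+
+feat = torch.randn(1, 128, 36, 100)   # e.g. a 1/8-resolution backbone feature map
+print(FeatureShiftSketch()(feat).shape)  # torch.Size([1, 128, 36, 100])
+```
+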
+## Get started
+1. Clone the RESA repository
+ ```
+ git clone https://github.com/zjulearning/resa.git
+ ```
+    We refer to this directory as `$RESA_ROOT`.
+
+2. Create a conda virtual environment and activate it (conda is optional)
+
+ ```Shell
+ conda create -n resa python=3.8 -y
+ conda activate resa
+ ```
+
+3. Install dependencies
+
+ ```Shell
+    # Install PyTorch first; the cudatoolkit version should match the CUDA version on your system. (You can also use pip to install PyTorch and torchvision.)
+ conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
+
+ # Or you can install via pip
+ pip install torch torchvision
+
+ # Install python packages
+ pip install -r requirements.txt
+ ```
+
+4. Data preparation
+
+    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3). Then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`, and create links to them under the `data` directory.
+
+ ```Shell
+ cd $RESA_ROOT
+ mkdir -p data
+ ln -s $CULANEROOT data/CULane
+ ln -s $TUSIMPLEROOT data/tusimple
+ ```
+
+    For CULane, you should have a directory structure like this:
+ ```
+ $CULANEROOT/driver_xx_xxframe # data folders x6
+ $CULANEROOT/laneseg_label_w16 # lane segmentation labels
+ $CULANEROOT/list # data lists
+ ```
+
+    For Tusimple, you should have a directory structure like this:
+ ```
+ $TUSIMPLEROOT/clips # data folders
+    $TUSIMPLEROOT/label_data_xxxx.json # label json files x4
+ $TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
+ $TUSIMPLEROOT/test_label.json # test label json file
+
+ ```
+
+    For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the JSON annotations.
+
+ ```Shell
+ python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
+ # this will generate seg_label directory
+ ```
+
+5. Install CULane evaluation tools.
+
+    This tool requires OpenCV C++. Please follow [this guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or simply install it with `sudo apt-get install libopencv-dev`.
+
+
+ Then compile the evaluation tool of CULane.
+ ```Shell
+ cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
+ make
+ cd -
+ ```
+
+    Note that the default `opencv` version is 3. If you use OpenCV 2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
+
+
+## Training
+
+For training, run
+
+```Shell
+python main.py [configs/path_to_your_config] --gpus [gpu_ids]
+```
+
+
+For example, run
+```Shell
+python main.py configs/culane.py --gpus 0 1 2 3
+```
+
+## Testing
+For testing, run
+```Shell
+python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
+```
+
+For example, run
+```Shell
+python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3
+
+python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
+```
+
+
+We provide two trained ResNet models, one for Tusimple and one for CULane. Download our best-performing models here:
+Tusimple: [GoogleDrive](https://drive.google.com/file/d/1M1xi82y0RoWUwYYG9LmZHXWSD2D60o0D/view?usp=sharing)/[BaiduDrive(code:s5ii)](https://pan.baidu.com/s/1CgJFrt9OHe-RUNooPpHRGA),
+CULane: [GoogleDrive](https://drive.google.com/file/d/1pcqq9lpJ4ixJgFVFndlPe42VgVsjgn0Q/view?usp=sharing)/[BaiduDrive(code:rlwj)](https://pan.baidu.com/s/1ODKAZxpKrZIPXyaNnxcV3g)
+
+## Visualization
+Just add `--view` to the testing command.
+
+For example:
+```Shell
+python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3 --view
+```
+You will get the result in the directory: `work_dirs/[DATASET]/xxx/vis`.
+
+## Citation
+If you use our method, please consider citing:
+```BibTeX
+@inproceedings{zheng2021resa,
+ title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
+ author={Zheng, Tu and Fang, Hao and Zhang, Yi and Tang, Wenjian and Yang, Zheng and Liu, Haifeng and Cai, Deng},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={35},
+ number={4},
+ pages={3547--3554},
+ year={2021}
+}
+```
+
+
diff --git a/configs/culane.py b/configs/culane.py
new file mode 100644
index 0000000..2bc022f
--- /dev/null
+++ b/configs/culane.py
@@ -0,0 +1,88 @@
+net = dict(
+ type='RESANet',
+)
+
+backbone = dict(
+ type='ResNetWrapper',
+ resnet='resnet50',
+ pretrained=True,
+ replace_stride_with_dilation=[False, True, True],
+ out_conv=True,
+ fea_stride=8,
+)
+
+resa = dict(
+ type='RESA',
+ alpha=2.0,
+ iter=4,
+ input_channel=128,
+ conv_stride=9,
+)
+
+#decoder = 'PlainDecoder'
+decoder = 'BUSD'
+
+trainer = dict(
+ type='RESA'
+)
+
+evaluator = dict(
+ type='CULane',
+)
+
+optimizer = dict(
+ type='sgd',
+ lr=0.025,
+ weight_decay=1e-4,
+ momentum=0.9
+)
+
+epochs = 12
+batch_size = 8
+total_iter = (88880 // batch_size) * epochs
+import math
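+# poly learning-rate decay: LambdaLR scales the base lr by (1 - _iter / total_iter) ** 0.9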
+scheduler = dict(
+ type = 'LambdaLR',
+ lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
+)
+
+loss_type = 'dice_loss'
+seg_loss_weight = 2.
+eval_ep = 6
+save_ep = epochs
+
+bg_weight = 0.4
+
+img_norm = dict(
+ mean=[103.939, 116.779, 123.68],
+ std=[1., 1., 1.]
+)
+
+img_height = 288
+img_width = 800
+cut_height = 240
+
+dataset_path = './data/CULane'
+dataset = dict(
+ train=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='train_gt.txt',
+ ),
+ val=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='test.txt',
+ ),
+ test=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='test.txt',
+ )
+)
+
+
+workers = 12
+num_classes = 4 + 1
+ignore_label = 255
+log_interval = 500
diff --git a/configs/culane_copy.py b/configs/culane_copy.py
new file mode 100644
index 0000000..d2478b6
--- /dev/null
+++ b/configs/culane_copy.py
@@ -0,0 +1,97 @@
+net = dict(
+ type='RESANet',
+)
+
+# backbone = dict(
+# type='ResNetWrapper',
+# resnet='resnet50',
+# pretrained=True,
+# replace_stride_with_dilation=[False, True, True],
+# out_conv=True,
+# fea_stride=8,
+# )
+
+backbone = dict(
+ type='ResNetWrapper',
+ resnet='resnet34',
+ pretrained=True,
+ replace_stride_with_dilation=[False, False, False],
+ out_conv=False,
+ fea_stride=8,
+)
+
+resa = dict(
+ type='RESA',
+ alpha=2.0,
+ iter=4,
+ input_channel=128,
+ conv_stride=9,
+)
+
+#decoder = 'PlainDecoder'
+decoder = 'BUSD'
+
+trainer = dict(
+ type='RESA'
+)
+
+evaluator = dict(
+ type='CULane',
+)
+
+optimizer = dict(
+ type='sgd',
+ lr=0.025,
+ weight_decay=1e-4,
+ momentum=0.9
+)
+
+epochs = 20
+batch_size = 8
+total_iter = (88880 // batch_size) * epochs
+import math
+scheduler = dict(
+ type = 'LambdaLR',
+ lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
+)
+
+loss_type = 'dice_loss'
+seg_loss_weight = 2.
+eval_ep = 1
+save_ep = epochs
+
+bg_weight = 0.4
+
+img_norm = dict(
+ mean=[103.939, 116.779, 123.68],
+ std=[1., 1., 1.]
+)
+
+img_height = 288
+img_width = 800
+cut_height = 240
+
+dataset_path = './data/CULane'
+dataset = dict(
+ train=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='train_gt.txt',
+ ),
+ val=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='test.txt',
+ ),
+ test=dict(
+ type='CULane',
+ img_path=dataset_path,
+ data_list='test.txt',
+ )
+)
+
+
+workers = 12
+num_classes = 4 + 1
+ignore_label = 255
+log_interval = 500
diff --git a/configs/tusimple.py b/configs/tusimple.py
new file mode 100644
index 0000000..a075c19
--- /dev/null
+++ b/configs/tusimple.py
@@ -0,0 +1,93 @@
+net = dict(
+ type='RESANet',
+)
+
+backbone = dict(
+ type='ResNetWrapper',
+ resnet='resnet34',
+ pretrained=True,
+ replace_stride_with_dilation=[False, True, True],
+ out_conv=True,
+ fea_stride=8,
+)
+
+resa = dict(
+ type='RESA',
+ alpha=2.0,
+ iter=5,
+ input_channel=128,
+ conv_stride=9,
+)
+
+decoder = 'BUSD'
+
+trainer = dict(
+ type='RESA'
+)
+
+evaluator = dict(
+ type='Tusimple',
+ thresh = 0.60
+)
+
+optimizer = dict(
+ type='sgd',
+ lr=0.020,
+ weight_decay=1e-4,
+ momentum=0.9
+)
+
+total_iter = 181400
+import math
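+# poly learning-rate decay: LambdaLR scales the base lr by (1 - _iter / total_iter) ** 0.9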
+scheduler = dict(
+ type = 'LambdaLR',
+ lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
+)
+
+bg_weight = 0.4
+
+img_norm = dict(
+ mean=[103.939, 116.779, 123.68],
+ std=[1., 1., 1.]
+)
+
+img_height = 368
+img_width = 640
+cut_height = 160
+seg_label = "seg_label"
+
+dataset_path = './data/tusimple'
+test_json_file = './data/tusimple/test_label.json'
+
+dataset = dict(
+ train=dict(
+ type='TuSimple',
+ img_path=dataset_path,
+ data_list='train_val_gt.txt',
+ ),
+ val=dict(
+ type='TuSimple',
+ img_path=dataset_path,
+ data_list='test_gt.txt'
+ ),
+ test=dict(
+ type='TuSimple',
+ img_path=dataset_path,
+ data_list='test_gt.txt'
+ )
+)
+
+
+loss_type = 'cross_entropy'
+seg_loss_weight = 1.0
+
+
+batch_size = 4
+workers = 12
+num_classes = 6 + 1
+ignore_label = 255
+epochs = 300
+log_interval = 100
+eval_ep = 1
+save_ep = epochs
+log_note = ''
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000..b94b540
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,4 @@
+from .registry import build_dataset, build_dataloader
+
+from .tusimple import TuSimple
+from .culane import CULane
diff --git a/datasets/base_dataset.py b/datasets/base_dataset.py
new file mode 100644
index 0000000..33a7f18
--- /dev/null
+++ b/datasets/base_dataset.py
@@ -0,0 +1,86 @@
+import os.path as osp
+import os
+import numpy as np
+import cv2
+import torch
+from torch.utils.data import Dataset
+import torchvision
+import utils.transforms as tf
+from .registry import DATASETS
+
+
+@DATASETS.register_module
+class BaseDataset(Dataset):
+ def __init__(self, img_path, data_list, list_path='list', cfg=None):
+ self.cfg = cfg
+ self.img_path = img_path
+ self.list_path = osp.join(img_path, list_path)
+ self.data_list = data_list
+ self.is_training = ('train' in data_list)
+
+ self.img_name_list = []
+ self.full_img_path_list = []
+ self.label_list = []
+ self.exist_list = []
+
+ self.transform = self.transform_train() if self.is_training else self.transform_val()
+
+ self.init()
+
+ def transform_train(self):
+ raise NotImplementedError()
+
+ def transform_val(self):
+ val_transform = torchvision.transforms.Compose([
+ tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
+ tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
+ self.cfg.img_norm['std'], (1, ))),
+ ])
+ return val_transform
+
+ def view(self, img, coords, file_path=None):
+ for coord in coords:
+ for x, y in coord:
+ if x <= 0 or y <= 0:
+ continue
+ x, y = int(x), int(y)
+ cv2.circle(img, (x, y), 4, (255, 0, 0), 2)
+
+ if file_path is not None:
+ if not os.path.exists(osp.dirname(file_path)):
+ os.makedirs(osp.dirname(file_path))
+ cv2.imwrite(file_path, img)
+
+
+ def init(self):
+ raise NotImplementedError()
+
+
+ def __len__(self):
+ return len(self.full_img_path_list)
+
+ def __getitem__(self, idx):
+ img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
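+        # crop away the top `cut_height` rows of the image (and, for training, the label)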
+ img = img[self.cfg.cut_height:, :, :]
+
+ if self.is_training:
+ label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
+ if len(label.shape) > 2:
+ label = label[:, :, 0]
+ label = label.squeeze()
+ label = label[self.cfg.cut_height:, :]
+ exist = self.exist_list[idx]
+ if self.transform:
+ img, label = self.transform((img, label))
+ label = torch.from_numpy(label).contiguous().long()
+ else:
+ img, = self.transform((img,))
+
+ img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
+ meta = {'full_img_path': self.full_img_path_list[idx],
+ 'img_name': self.img_name_list[idx]}
+
+ data = {'img': img, 'meta': meta}
+ if self.is_training:
+ data.update({'label': label, 'exist': exist})
+ return data
diff --git a/datasets/culane.py b/datasets/culane.py
new file mode 100644
index 0000000..c64a1bb
--- /dev/null
+++ b/datasets/culane.py
@@ -0,0 +1,72 @@
+import os
+import os.path as osp
+import numpy as np
+import torchvision
+import utils.transforms as tf
+from .base_dataset import BaseDataset
+from .registry import DATASETS
+import cv2
+import torch
+
+
+@DATASETS.register_module
+class CULane(BaseDataset):
+ def __init__(self, img_path, data_list, cfg=None):
+ super().__init__(img_path, data_list, cfg=cfg)
+ self.ori_imgh = 590
+ self.ori_imgw = 1640
+
+ def init(self):
+ with open(osp.join(self.list_path, self.data_list)) as f:
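+            # train_gt.txt lines: <image path> <seg-label path> <4 lane-existence flags>;
+            # test lists contain only the image path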
+ for line in f:
+ line_split = line.strip().split(" ")
+ self.img_name_list.append(line_split[0])
+ self.full_img_path_list.append(self.img_path + line_split[0])
+ if not self.is_training:
+ continue
+ self.label_list.append(self.img_path + line_split[1])
+ self.exist_list.append(
+ np.array([int(line_split[2]), int(line_split[3]),
+ int(line_split[4]), int(line_split[5])]))
+
+ def transform_train(self):
+ train_transform = torchvision.transforms.Compose([
+ tf.GroupRandomRotation(degree=(-2, 2)),
+ tf.GroupRandomHorizontalFlip(),
+ tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
+ tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
+ self.cfg.img_norm['std'], (1, ))),
+ ])
+ return train_transform
+
+ def probmap2lane(self, probmaps, exists, pts=18):
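+        # Convert per-lane probability maps into image-space lane points: for each
+        # existing lane, sample `pts` rows (spaced 20 px apart in the original
+        # 1640x590 frame), take the argmax column where the probability exceeds
+        # the threshold, and map the result back to original-image coordinates.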
+ coords = []
+ probmaps = probmaps[1:, ...]
+ exists = exists > 0.5
+ for probmap, exist in zip(probmaps, exists):
+ if exist == 0:
+ continue
+ probmap = cv2.blur(probmap, (9, 9), borderType=cv2.BORDER_REPLICATE)
+ thr = 0.3
+ coordinate = np.zeros(pts)
+ cut_height = self.cfg.cut_height
+ for i in range(pts):
+ line = probmap[round(
+ self.cfg.img_height-i*20/(self.ori_imgh-cut_height)*self.cfg.img_height)-1]
+
+ if np.max(line) > thr:
+ coordinate[i] = np.argmax(line)+1
+ if np.sum(coordinate > 0) < 2:
+ continue
+
+ img_coord = np.zeros((pts, 2))
+ img_coord[:, :] = -1
+ for idx, value in enumerate(coordinate):
+ if value > 0:
+ img_coord[idx][0] = round(value*self.ori_imgw/self.cfg.img_width-1)
+ img_coord[idx][1] = round(self.ori_imgh-idx*20-1)
+
+ img_coord = img_coord.astype(int)
+ coords.append(img_coord)
+
+ return coords
diff --git a/datasets/registry.py b/datasets/registry.py
new file mode 100644
index 0000000..103a3ed
--- /dev/null
+++ b/datasets/registry.py
@@ -0,0 +1,36 @@
+from utils import Registry, build_from_cfg
+
+import torch
+from torch import nn
+
+DATASETS = Registry('datasets')
+
+def build(cfg, registry, default_args=None):
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return nn.Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+def build_dataset(split_cfg, cfg):
+ args = split_cfg.copy()
+ args.pop('type')
+ args = args.to_dict()
+ args['cfg'] = cfg
+ return build(split_cfg, DATASETS, default_args=args)
+
+def build_dataloader(split_cfg, cfg, is_train=True):
+ if is_train:
+ shuffle = True
+ else:
+ shuffle = False
+
+ dataset = build_dataset(split_cfg, cfg)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset, batch_size = cfg.batch_size, shuffle = shuffle,
+ num_workers = cfg.workers, pin_memory = False, drop_last = False)
+
+ return data_loader
diff --git a/datasets/tusimple.py b/datasets/tusimple.py
new file mode 100644
index 0000000..3097a23
--- /dev/null
+++ b/datasets/tusimple.py
@@ -0,0 +1,150 @@
+import os.path as osp
+import numpy as np
+import cv2
+import torchvision
+import utils.transforms as tf
+from .base_dataset import BaseDataset
+from .registry import DATASETS
+
+
+@DATASETS.register_module
+class TuSimple(BaseDataset):
+ def __init__(self, img_path, data_list, cfg=None):
+ super().__init__(img_path, data_list, 'seg_label/list', cfg)
+
+ def transform_train(self):
+ input_mean = self.cfg.img_norm['mean']
+ train_transform = torchvision.transforms.Compose([
+ tf.GroupRandomRotation(),
+ tf.GroupRandomHorizontalFlip(),
+ tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
+ tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
+ self.cfg.img_norm['std'], (1, ))),
+ ])
+ return train_transform
+
+
+ def init(self):
+ with open(osp.join(self.list_path, self.data_list)) as f:
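+            # training lines: <image path> <seg-label path> <6 lane-existence flags>;
+            # for val/test lists only the image path is used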
+ for line in f:
+ line_split = line.strip().split(" ")
+ self.img_name_list.append(line_split[0])
+ self.full_img_path_list.append(self.img_path + line_split[0])
+ if not self.is_training:
+ continue
+ self.label_list.append(self.img_path + line_split[1])
+ self.exist_list.append(
+ np.array([int(line_split[2]), int(line_split[3]),
+ int(line_split[4]), int(line_split[5]),
+ int(line_split[6]), int(line_split[7])
+ ]))
+
+ def fix_gap(self, coordinate):
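+        # Fill missing (negative) x-coordinates between the first and last valid
+        # points of a lane by linear interpolation, so the lane has no gaps.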
+ if any(x > 0 for x in coordinate):
+ start = [i for i, x in enumerate(coordinate) if x > 0][0]
+ end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
+ lane = coordinate[start:end+1]
+ if any(x < 0 for x in lane):
+ gap_start = [i for i, x in enumerate(
+ lane[:-1]) if x > 0 and lane[i+1] < 0]
+ gap_end = [i+1 for i,
+ x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
+ gap_id = [i for i, x in enumerate(lane) if x < 0]
+ if len(gap_start) == 0 or len(gap_end) == 0:
+ return coordinate
+ for id in gap_id:
+ for i in range(len(gap_start)):
+ if i >= len(gap_end):
+ return coordinate
+ if id > gap_start[i] and id < gap_end[i]:
+ gap_width = float(gap_end[i] - gap_start[i])
+ lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
+ gap_end[i] - id) / gap_width * lane[gap_start[i]])
+ if not all(x > 0 for x in lane):
+ print("Gaps still exist!")
+ coordinate[start:end+1] = lane
+ return coordinate
+
+ def is_short(self, lane):
+ start = [i for i, x in enumerate(lane) if x > 0]
+ if not start:
+ return 1
+ else:
+ return 0
+
+ def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
+ """
+ Arguments:
+ ----------
+ prob_map: prob map for single lane, np array size (h, w)
+ resize_shape: reshape size target, (H, W)
+
+ Return:
+ ----------
+ coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
+ """
+ if resize_shape is None:
+ resize_shape = prob_map.shape
+ h, w = prob_map.shape
+ H, W = resize_shape
+ H -= self.cfg.cut_height
+
+ coords = np.zeros(pts)
+ coords[:] = -1.0
+ for i in range(pts):
+ y = int((H - 10 - i * y_px_gap) * h / H)
+ if y < 0:
+ break
+ line = prob_map[y, :]
+ id = np.argmax(line)
+ if line[id] > thresh:
+ coords[i] = int(id / w * W)
+ if (coords > 0).sum() < 2:
+ coords = np.zeros(pts)
+ self.fix_gap(coords)
+ #print(coords.shape)
+
+ return coords
+
+ def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
+ """
+ Arguments:
+ ----------
+ seg_pred: np.array size (5, h, w)
+ resize_shape: reshape size target, (H, W)
+ exist: list of existence, e.g. [0, 1, 1, 0]
+ smooth: whether to smooth the probability or not
+ y_px_gap: y pixel gap for sampling
+ pts: how many points for one lane
+ thresh: probability threshold
+
+ Return:
+ ----------
+ coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
+ """
+ if resize_shape is None:
+ resize_shape = seg_pred.shape[1:] # seg_pred (5, h, w)
+ _, h, w = seg_pred.shape
+ H, W = resize_shape
+ coordinates = []
+
+ for i in range(self.cfg.num_classes - 1):
+ prob_map = seg_pred[i + 1]
+ if smooth:
+ prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
+ coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
+ if self.is_short(coords):
+ continue
+ coordinates.append(
+ [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
+ range(pts)])
+
+
+ if len(coordinates) == 0:
+ coords = np.zeros(pts)
+ coordinates.append(
+ [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
+ range(pts)])
+ #print(coordinates)
+
+ return coordinates
diff --git a/intro.png b/intro.png
new file mode 100644
index 0000000..28dc664
Binary files /dev/null and b/intro.png differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..578d37b
--- /dev/null
+++ b/main.py
@@ -0,0 +1,73 @@
+import os
+import os.path as osp
+import time
+import shutil
+import torch
+import torchvision
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+import torch.optim
+import cv2
+import numpy as np
+import models
+import argparse
+from utils.config import Config
+from runner.runner import Runner
+from datasets import build_dataloader
+
+
+def main():
+ args = parse_args()
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)
+
+ cfg = Config.fromfile(args.config)
+ cfg.gpus = len(args.gpus)
+
+ cfg.load_from = args.load_from
+ cfg.finetune_from = args.finetune_from
+ cfg.view = args.view
+
+ cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type
+
+ cudnn.benchmark = True
+ cudnn.fastest = True
+
+ runner = Runner(cfg)
+
+ if args.validate:
+ val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
+ runner.validate(val_loader)
+ else:
+ runner.train()
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a detector')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument(
+ '--work_dirs', type=str, default='work_dirs',
+ help='work dirs')
+ parser.add_argument(
+ '--load_from', default=None,
+ help='the checkpoint file to resume from')
+ parser.add_argument(
+ '--finetune_from', default=None,
+ help='whether to finetune from the checkpoint')
+ parser.add_argument(
+ '--validate',
+ action='store_true',
+ help='whether to evaluate the checkpoint during training')
+ parser.add_argument(
+ '--view',
+ action='store_true',
+ help='whether to show visualization result')
+    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
+ parser.add_argument('--seed', type=int,
+ default=None, help='random seed')
+ args = parser.parse_args()
+
+ return args
+
+
+if __name__ == '__main__':
+ main()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..fc812be
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from .resa import *
diff --git a/models/decoder.py b/models/decoder.py
new file mode 100644
index 0000000..f4228a7
--- /dev/null
+++ b/models/decoder.py
@@ -0,0 +1,129 @@
+from torch import nn
+import torch.nn.functional as F
+
+class PlainDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(PlainDecoder, self).__init__()
+ self.cfg = cfg
+
+ self.dropout = nn.Dropout2d(0.1)
+ self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.conv8(x)
+ x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
+ mode='bilinear', align_corners=False)
+
+ return x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class non_bottleneck_1d(nn.Module):
+ def __init__(self, chann, dropprob, dilated):
+ super().__init__()
+
+ self.conv3x1_1 = nn.Conv2d(
+ chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
+
+ self.conv1x3_1 = nn.Conv2d(
+ chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
+
+ self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
+ dilation=(dilated, 1))
+
+ self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
+ dilation=(1, dilated))
+
+ self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.dropout = nn.Dropout2d(dropprob)
+
+ def forward(self, input):
+ output = self.conv3x1_1(input)
+ output = F.relu(output)
+ output = self.conv1x3_1(output)
+ output = self.bn1(output)
+ output = F.relu(output)
+
+ output = self.conv3x1_2(output)
+ output = F.relu(output)
+ output = self.conv1x3_2(output)
+ output = self.bn2(output)
+
+ if (self.dropout.p != 0):
+ output = self.dropout(output)
+
+ # +input = identity (residual connection)
+ return F.relu(output + input)
+
+
+class UpsamplerBlock(nn.Module):
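+    # Two-branch up-sampler: a learned ConvTranspose2d path refined by two
+    # non_bottleneck_1d blocks, plus a 1x1-conv + bilinear-interpolation
+    # shortcut; the two branches are summed in forward().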
+ def __init__(self, ninput, noutput, up_width, up_height):
+ super().__init__()
+
+ self.conv = nn.ConvTranspose2d(
+ ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
+
+ self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)
+
+ self.follows = nn.ModuleList()
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+
+ # interpolate
+ self.up_width = up_width
+ self.up_height = up_height
+ self.interpolate_conv = conv1x1(ninput, noutput)
+ self.interpolate_bn = nn.BatchNorm2d(
+ noutput, eps=1e-3, track_running_stats=True)
+
+ def forward(self, input):
+ output = self.conv(input)
+ output = self.bn(output)
+ out = F.relu(output)
+ for follow in self.follows:
+ out = follow(out)
+
+ interpolate_output = self.interpolate_conv(input)
+ interpolate_output = self.interpolate_bn(interpolate_output)
+ interpolate_output = F.relu(interpolate_output)
+
+ interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
+ mode='bilinear', align_corners=False)
+
+ return out + interpolate
+
+class BUSD(nn.Module):
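+    # Bilateral Up-Sampling Decoder: three UpsamplerBlocks restore the 1/8-resolution,
+    # 128-channel feature map to full input resolution before a 1x1 classification conv.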
+ def __init__(self, cfg):
+ super().__init__()
+ img_height = cfg.img_height
+ img_width = cfg.img_width
+ num_classes = cfg.num_classes
+
+ self.layers = nn.ModuleList()
+
+ self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
+ up_height=int(img_height)//4, up_width=int(img_width)//4))
+ self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
+ up_height=int(img_height)//2, up_width=int(img_width)//2))
+ self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
+ up_height=int(img_height)//1, up_width=int(img_width)//1))
+
+ self.output_conv = conv1x1(16, num_classes)
+
+ def forward(self, input):
+ output = input
+
+ for layer in self.layers:
+ output = layer(output)
+
+ output = self.output_conv(output)
+
+ return output
diff --git a/models/decoder_copy.py b/models/decoder_copy.py
new file mode 100644
index 0000000..209a966
--- /dev/null
+++ b/models/decoder_copy.py
@@ -0,0 +1,135 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+
+class PlainDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(PlainDecoder, self).__init__()
+ self.cfg = cfg
+
+ self.dropout = nn.Dropout2d(0.1)
+ self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.conv8(x)
+ x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
+ mode='bilinear', align_corners=False)
+
+ return x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class non_bottleneck_1d(nn.Module):
+ def __init__(self, chann, dropprob, dilated):
+ super().__init__()
+
+ self.conv3x1_1 = nn.Conv2d(
+ chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
+
+ self.conv1x3_1 = nn.Conv2d(
+ chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
+
+ self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
+ dilation=(dilated, 1))
+
+ self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
+ dilation=(1, dilated))
+
+ self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.dropout = nn.Dropout2d(dropprob)
+
+ def forward(self, input):
+ output = self.conv3x1_1(input)
+ output = F.relu(output)
+ output = self.conv1x3_1(output)
+ output = self.bn1(output)
+ output = F.relu(output)
+
+ output = self.conv3x1_2(output)
+ output = F.relu(output)
+ output = self.conv1x3_2(output)
+ output = self.bn2(output)
+
+ if (self.dropout.p != 0):
+ output = self.dropout(output)
+
+ # +input = identity (residual connection)
+ return F.relu(output + input)
+
+
+class UpsamplerBlock(nn.Module):
+ def __init__(self, ninput, noutput, up_width, up_height):
+ super().__init__()
+
+ self.conv = nn.ConvTranspose2d(
+ ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
+
+ self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)
+
+ self.follows = nn.ModuleList()
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+
+ # interpolate
+ self.up_width = up_width
+ self.up_height = up_height
+ self.interpolate_conv = conv1x1(ninput, noutput)
+ self.interpolate_bn = nn.BatchNorm2d(
+ noutput, eps=1e-3, track_running_stats=True)
+
+ def forward(self, input):
+ output = self.conv(input)
+ output = self.bn(output)
+ out = F.relu(output)
+ for follow in self.follows:
+ out = follow(out)
+
+ interpolate_output = self.interpolate_conv(input)
+ interpolate_output = self.interpolate_bn(interpolate_output)
+ interpolate_output = F.relu(interpolate_output)
+
+ interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
+ mode='bilinear', align_corners=False)
+
+ return out + interpolate
+
+class BUSD(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ img_height = cfg.img_height
+ img_width = cfg.img_width
+ num_classes = cfg.num_classes
+
+ self.layers = nn.ModuleList()
+
+ self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
+ up_height=int(img_height)//4, up_width=int(img_width)//4))
+ self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
+ up_height=int(img_height)//2, up_width=int(img_width)//2))
+ self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
+ up_height=int(img_height)//1, up_width=int(img_width)//1))
+
+ self.output_conv = conv1x1(32, num_classes)
+
+ def forward(self, input):
+ x = input[0]
+ output = input[1]
+
+ for i,layer in enumerate(self.layers):
+ output = layer(output)
+ if i == 0:
+ output = torch.cat((x, output), dim=1)
+
+
+
+ output = self.output_conv(output)
+
+ return output
diff --git a/models/decoder_copy2.py b/models/decoder_copy2.py
new file mode 100644
index 0000000..9d092a4
--- /dev/null
+++ b/models/decoder_copy2.py
@@ -0,0 +1,143 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+class PlainDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(PlainDecoder, self).__init__()
+ self.cfg = cfg
+
+ self.dropout = nn.Dropout2d(0.1)
+ self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.conv8(x)
+ x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
+ mode='bilinear', align_corners=False)
+
+ return x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class non_bottleneck_1d(nn.Module):
+ def __init__(self, chann, dropprob, dilated):
+ super().__init__()
+
+ self.conv3x1_1 = nn.Conv2d(
+ chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
+
+ self.conv1x3_1 = nn.Conv2d(
+ chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
+
+ self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
+ dilation=(dilated, 1))
+
+ self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
+ dilation=(1, dilated))
+
+ self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
+
+ self.dropout = nn.Dropout2d(dropprob)
+
+ def forward(self, input):
+ output = self.conv3x1_1(input)
+ output = F.relu(output)
+ output = self.conv1x3_1(output)
+ output = self.bn1(output)
+ output = F.relu(output)
+
+ output = self.conv3x1_2(output)
+ output = F.relu(output)
+ output = self.conv1x3_2(output)
+ output = self.bn2(output)
+
+ if (self.dropout.p != 0):
+ output = self.dropout(output)
+
+ # +input = identity (residual connection)
+ return F.relu(output + input)
+
+
+class UpsamplerBlock(nn.Module):
+ def __init__(self, ninput, noutput, up_width, up_height):
+ super().__init__()
+
+ self.conv = nn.ConvTranspose2d(
+ ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
+
+ self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)
+
+ self.follows = nn.ModuleList()
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+ self.follows.append(non_bottleneck_1d(noutput, 0, 1))
+
+ # interpolate
+ self.up_width = up_width
+ self.up_height = up_height
+ self.interpolate_conv = conv1x1(ninput, noutput)
+ self.interpolate_bn = nn.BatchNorm2d(
+ noutput, eps=1e-3, track_running_stats=True)
+
+ def forward(self, input):
+ output = self.conv(input)
+ output = self.bn(output)
+ out = F.relu(output)
+ for follow in self.follows:
+ out = follow(out)
+
+ interpolate_output = self.interpolate_conv(input)
+ interpolate_output = self.interpolate_bn(interpolate_output)
+ interpolate_output = F.relu(interpolate_output)
+
+ interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
+ mode='bilinear', align_corners=False)
+
+ return out + interpolate
+
+class BUSD(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ img_height = cfg.img_height
+ img_width = cfg.img_width
+ num_classes = cfg.num_classes
+
+ self.layers = nn.ModuleList()
+
+ self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
+ up_height=int(img_height)//4, up_width=int(img_width)//4))
+ self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
+ up_height=int(img_height)//2, up_width=int(img_width)//2))
+ self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
+ up_height=int(img_height)//1, up_width=int(img_width)//1))
+ self.out1 = conv1x1(128, 64)
+ self.out2 = conv1x1(64, 32)
+ self.output_conv = conv1x1(16, num_classes)
+
+
+ def forward(self, input):
+ out1 = input[0]
+ out2 = input[1]
+ output = input[2]
+
+ for i,layer in enumerate(self.layers):
+ if i == 0:
+ output = layer(output)
+ output = torch.cat((out2, output), dim=1)
+ output = self.out1(output)
+ elif i == 1:
+ output = layer(output)
+ output = torch.cat((out1, output), dim=1)
+ output = self.out2(output)
+ else:
+ output = layer(output)
+
+ output = self.output_conv(output)
+
+ return output
diff --git a/models/mobilenetv2.py b/models/mobilenetv2.py
new file mode 100644
index 0000000..69d7b0d
--- /dev/null
+++ b/models/mobilenetv2.py
@@ -0,0 +1,422 @@
+from functools import partial
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch import nn, Tensor
+
+
+from torchvision.transforms._presets import ImageClassification
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import Weights, WeightsEnum
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
+import warnings
+from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar
+import collections
+from itertools import repeat
+M = TypeVar("M", bound=nn.Module)
+
+BUILTIN_MODELS = {}
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+ def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+ key = name if name is not None else fn.__name__
+ if key in BUILTIN_MODELS:
+ raise ValueError(f"An entry is already registered under the name '{key}'.")
+ BUILTIN_MODELS[key] = fn
+ return fn
+
+ return wrapper
+
+def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
+ """
+ Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
+ Otherwise, we will make a tuple of length n, all with value of x.
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
+
+ Args:
+ x (Any): input value
+ n (int): length of the resulting tuple
+ """
+ if isinstance(x, collections.abc.Iterable):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+class ConvNormActivation(torch.nn.Sequential):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]] = 3,
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Optional[Union[int, Tuple[int, ...], str]] = None,
+ groups: int = 1,
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ inplace: Optional[bool] = True,
+ bias: Optional[bool] = None,
+ conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
+ ) -> None:
+
+ if padding is None:
+ if isinstance(kernel_size, int) and isinstance(dilation, int):
+ padding = (kernel_size - 1) // 2 * dilation
+ else:
+ _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
+ kernel_size = _make_ntuple(kernel_size, _conv_dim)
+ dilation = _make_ntuple(dilation, _conv_dim)
+ padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
+ if bias is None:
+ bias = norm_layer is None
+
+ layers = [
+ conv_layer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ ]
+
+ if norm_layer is not None:
+ layers.append(norm_layer(out_channels))
+
+ if activation_layer is not None:
+ params = {} if inplace is None else {"inplace": inplace}
+ layers.append(activation_layer(**params))
+ super().__init__(*layers)
+ _log_api_usage_once(self)
+ self.out_channels = out_channels
+
+ if self.__class__ == ConvNormActivation:
+ warnings.warn(
+ "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
+ )
+
+
+class Conv2dNormActivation(ConvNormActivation):
+ """
+ Configurable block used for Convolution2d-Normalization-Activation blocks.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
+ kernel_size: (int, optional): Size of the convolving kernel. Default: 3
+ stride (int, optional): Stride of the convolution. Default: 1
+ padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
+ activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
+ dilation (int): Spacing between kernel elements. Default: 1
+ inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+ bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 3,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Optional[Union[int, Tuple[int, int], str]] = None,
+ groups: int = 1,
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ inplace: Optional[bool] = True,
+ bias: Optional[bool] = None,
+ ) -> None:
+
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ norm_layer,
+ activation_layer,
+ dilation,
+ inplace,
+ bias,
+ torch.nn.Conv2d,
+ )
+
+__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
+
+
+# necessary for backwards compatibility
+class InvertedResidual(nn.Module):
+ def __init__(
+ self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super().__init__()
+ self.stride = stride
+ if stride not in [1, 2]:
+ raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers: List[nn.Module] = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(
+ Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+ )
+ layers.extend(
+ [
+ # dw
+ Conv2dNormActivation(
+ hidden_dim,
+ hidden_dim,
+ stride=stride,
+ groups=hidden_dim,
+ norm_layer=norm_layer,
+ activation_layer=nn.ReLU6,
+ ),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ norm_layer(oup),
+ ]
+ )
+ self.conv = nn.Sequential(*layers)
+ self.out_channels = oup
+ self._is_cn = stride > 1
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(
+ self,
+ num_classes: int = 1000,
+ width_mult: float = 1.0,
+ inverted_residual_setting: Optional[List[List[int]]] = None,
+ round_nearest: int = 8,
+ block: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ dropout: float = 0.2,
+ ) -> None:
+ """
+ MobileNet V2 main class
+
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ norm_layer: Module specifying the normalization layer to use
+            dropout (float): The dropout probability
+
+ """
+ super().__init__()
+ _log_api_usage_once(self)
+
+ if block is None:
+ block = InvertedResidual
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ input_channel = 32
+ last_channel = 1280
+
+ if inverted_residual_setting is None:
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 1], # **
+ [6, 96, 3, 1],
+ [6, 160, 3, 1], # **
+ [6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError(
+ f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
+ )
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features: List[nn.Module] = [
+ Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+ ]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
+ input_channel = output_channel
+ # building last several layers
+ features.append(
+ Conv2dNormActivation(
+ input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
+ )
+ )
+ # make it nn.Sequential
+ self.features = nn.Sequential(*features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(p=dropout),
+ nn.Linear(self.last_channel, num_classes),
+ )
+
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ x = self.features(x)
+ # Cannot use "squeeze" as batch-size can be 1
+ # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
+ # x = torch.flatten(x, 1)
+ # x = self.classifier(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+_COMMON_META = {
+ "num_params": 3504872,
+ "min_size": (1, 1),
+ "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class MobileNet_V2_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.878,
+ "acc@5": 90.286,
+ }
+ },
+ "_ops": 0.301,
+ "_file_size": 13.555,
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 72.154,
+ "acc@5": 90.822,
+ }
+ },
+ "_ops": 0.301,
+ "_file_size": 13.598,
+ "_docs": """
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
+ `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+# @register_model()
+# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
+def mobilenet_v2(
+ *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
+) -> MobileNetV2:
+ """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
+ Bottlenecks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.MobileNet_V2_Weights` below for
+ more details, and possible values. By default, the ``IMAGENET1K_V1``
+ pre-trained weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.MobileNet_V2_Weights
+ :members:
+ """
+ weights = MobileNet_V2_Weights.verify(weights)
+
+ if weights is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+ model = MobileNetV2(**kwargs)
+
+ if weights is not None:
+ model.load_state_dict(weights.get_state_dict(progress=progress))
+
+ return model
+
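+# Minimal usage sketch (editor's addition): the builder mirrors the torchvision API, so a
+# randomly initialised network with a custom head is obtained by disabling the weights;
+# leaving the default instead downloads the ImageNet-1K checkpoint.
+def _demo_builder():
+    return mobilenet_v2(weights=None, num_classes=10)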
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class MobileNetv2Wrapper(nn.Module):
+ def __init__(self):
+ super(MobileNetv2Wrapper, self).__init__()
+ weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)
+
+ self.model = MobileNetV2()
+
+ if weights is not None:
+ self.model.load_state_dict(weights.get_state_dict(progress=True))
+ self.out = conv1x1(
+ 1280, 128)
+
+ def forward(self, x):
+ # print(x.shape)
+ x = self.model(x)
+ # print(x.shape)
+ if self.out:
+ x = self.out(x)
+ # print(x.shape)
+ return x
+
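+# Shape sketch (editor's addition): with the reduced strides above the backbone keeps 1/8
+# resolution, so for a 288x800 crop (the CULane input size typically used with RESA -- an
+# assumption here) the wrapper returns a 128-channel map of size 36x100. Note that the
+# constructor downloads the ImageNet checkpoint.
+def _demo_wrapper_shapes():
+    feats = MobileNetv2Wrapper()(torch.randn(1, 3, 288, 800))
+    assert feats.shape == (1, 128, 36, 100)
+    return feats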
diff --git a/models/mobilenetv2_copy2.py b/models/mobilenetv2_copy2.py
new file mode 100644
index 0000000..c17829f
--- /dev/null
+++ b/models/mobilenetv2_copy2.py
@@ -0,0 +1,436 @@
+from functools import partial
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch import nn, Tensor
+
+
+from torchvision.transforms._presets import ImageClassification
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import Weights, WeightsEnum
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
+import warnings
+from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar
+import collections
+from itertools import repeat
+M = TypeVar("M", bound=nn.Module)
+
+BUILTIN_MODELS = {}
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+ def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+ key = name if name is not None else fn.__name__
+ if key in BUILTIN_MODELS:
+ raise ValueError(f"An entry is already registered under the name '{key}'.")
+ BUILTIN_MODELS[key] = fn
+ return fn
+
+ return wrapper
+
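+# Usage sketch (editor's addition): the decorator mimics torchvision's model registry;
+# "tiny_net" below is a made-up name used only to show the register/lookup round trip.
+def _demo_registry() -> None:
+    @register_model("tiny_net")
+    def tiny_net() -> nn.Module:
+        return nn.Conv2d(3, 8, 3)
+    assert BUILTIN_MODELS["tiny_net"]() is not None
+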
+def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
+ """
+ Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
+ Otherwise, we will make a tuple of length n, all with value of x.
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
+
+ Args:
+ x (Any): input value
+ n (int): length of the resulting tuple
+ """
+ if isinstance(x, collections.abc.Iterable):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
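+# Quick sketch (editor's addition) of the broadcasting rule described in the docstring:
+# scalars are repeated into an n-tuple while iterables pass through unchanged.
+def _demo_make_ntuple() -> None:
+    assert _make_ntuple(3, 2) == (3, 3)
+    assert _make_ntuple((1, 2), 2) == (1, 2)
+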
+class ConvNormActivation(torch.nn.Sequential):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]] = 3,
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Optional[Union[int, Tuple[int, ...], str]] = None,
+ groups: int = 1,
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ inplace: Optional[bool] = True,
+ bias: Optional[bool] = None,
+ conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
+ ) -> None:
+
+ if padding is None:
+ if isinstance(kernel_size, int) and isinstance(dilation, int):
+ padding = (kernel_size - 1) // 2 * dilation
+ else:
+ _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
+ kernel_size = _make_ntuple(kernel_size, _conv_dim)
+ dilation = _make_ntuple(dilation, _conv_dim)
+ padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
+ if bias is None:
+ bias = norm_layer is None
+
+ layers = [
+ conv_layer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ ]
+
+ if norm_layer is not None:
+ layers.append(norm_layer(out_channels))
+
+ if activation_layer is not None:
+ params = {} if inplace is None else {"inplace": inplace}
+ layers.append(activation_layer(**params))
+ super().__init__(*layers)
+ _log_api_usage_once(self)
+ self.out_channels = out_channels
+
+ if self.__class__ == ConvNormActivation:
+ warnings.warn(
+ "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
+ )
+
+
+class Conv2dNormActivation(ConvNormActivation):
+ """
+ Configurable block used for Convolution2d-Normalization-Activation blocks.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
+ kernel_size: (int, optional): Size of the convolving kernel. Default: 3
+ stride (int, optional): Stride of the convolution. Default: 1
+ padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
+ activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
+ dilation (int): Spacing between kernel elements. Default: 1
+ inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+ bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 3,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Optional[Union[int, Tuple[int, int], str]] = None,
+ groups: int = 1,
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ inplace: Optional[bool] = True,
+ bias: Optional[bool] = None,
+ ) -> None:
+
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ norm_layer,
+ activation_layer,
+ dilation,
+ inplace,
+ bias,
+ torch.nn.Conv2d,
+ )
+
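+# Padding sketch (editor's addition): when ``padding`` is left as ``None`` the block picks
+# (kernel_size - 1) // 2 * dilation, i.e. "same"-style padding for stride-1 convolutions.
+def _demo_default_padding() -> None:
+    block = Conv2dNormActivation(8, 16, kernel_size=3, dilation=2)  # padding resolves to 2
+    y = block(torch.randn(1, 8, 32, 32))
+    assert y.shape == (1, 16, 32, 32)
+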
+__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
+
+
+# necessary for backwards compatibility
+class InvertedResidual(nn.Module):
+ def __init__(
+ self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super().__init__()
+ self.stride = stride
+ if stride not in [1, 2]:
+ raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers: List[nn.Module] = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(
+ Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+ )
+ layers.extend(
+ [
+ # dw
+ Conv2dNormActivation(
+ hidden_dim,
+ hidden_dim,
+ stride=stride,
+ groups=hidden_dim,
+ norm_layer=norm_layer,
+ activation_layer=nn.ReLU6,
+ ),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ norm_layer(oup),
+ ]
+ )
+ self.conv = nn.Sequential(*layers)
+ self.out_channels = oup
+ self._is_cn = stride > 1
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(
+ self,
+ num_classes: int = 1000,
+ width_mult: float = 1.0,
+ inverted_residual_setting: Optional[List[List[int]]] = None,
+ round_nearest: int = 8,
+ block: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ dropout: float = 0.2,
+ ) -> None:
+ """
+ MobileNet V2 main class
+
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ norm_layer: Module specifying the normalization layer to use
+ dropout (float): The dropout probability
+
+ """
+ super().__init__()
+ _log_api_usage_once(self)
+
+ if block is None:
+ block = InvertedResidual
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ input_channel = 32
+ last_channel = 1280
+
+ if inverted_residual_setting is None:
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 1],
+ [6, 32, 3, 1],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
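+ # Note (editor's comment): relative to the torchvision default (s = 1, 2, 2, 2, 1, 2, 1)
+ # only two stride-2 stages remain, so with the stride-2 stem the deepest features sit at
+ # 1/8 resolution. With the default width_mult=1.0, the taps used in _forward_impl
+ # (indices 0, 10 and 18 of self.features) therefore yield maps at strides 2, 4 and 8
+ # with 32, 64 and 1280 channels respectively.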
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError(
+ f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
+ )
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features: List[nn.Module] = [
+ Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+ ]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
+ input_channel = output_channel
+ # building last several layers
+ features.append(
+ Conv2dNormActivation(
+ input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
+ )
+ )
+ # make it nn.Sequential
+ self.features = nn.Sequential(*features)
+ # self.layer1 = nn.Sequential(*features[:])
+ # self.layer2 = features[57:120]
+ # self.layer3 = features[120:]
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(p=dropout),
+ nn.Linear(self.last_channel, num_classes),
+ )
+
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ out_layers = []
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ # print("layer {}, output size {}".format(i, x.shape))
+ # keep the stem output and the last features of the stride-4 and stride-8 stages
+ if i in [0, 10, 18]:
+ out_layers.append(x)
+ # x = self.features(x)
+ # Cannot use "squeeze" as batch-size can be 1
+ # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
+ # x = torch.flatten(x, 1)
+ # x = self.classifier(x)
+ return out_layers
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
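+# Shape sketch (editor's addition): with the tap indices above and a 288x800 input (the
+# CULane crop size typically used in this repo -- an assumption here) the three returned
+# feature maps come out at strides 2, 4 and 8.
+def _demo_multiscale_shapes() -> None:
+    f1, f2, f3 = MobileNetV2()(torch.randn(1, 3, 288, 800))
+    assert f1.shape == (1, 32, 144, 400)
+    assert f2.shape == (1, 64, 72, 200)
+    assert f3.shape == (1, 1280, 36, 100)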
+
+_COMMON_META = {
+ "num_params": 3504872,
+ "min_size": (1, 1),
+ "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class MobileNet_V2_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.878,
+ "acc@5": 90.286,
+ }
+ },
+ "_ops": 0.301,
+ "_file_size": 13.555,
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 72.154,
+ "acc@5": 90.822,
+ }
+ },
+ "_ops": 0.301,
+ "_file_size": 13.598,
+ "_docs": """
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
+ `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+# @register_model()
+# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
+def mobilenet_v2(
+ *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
+) -> MobileNetV2:
+ """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
+ Bottlenecks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.MobileNet_V2_Weights` below for
+ more details, and possible values. By default, the ``IMAGENET1K_V1``
+ pre-trained weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.MobileNet_V2_Weights
+ :members:
+ """
+ weights = MobileNet_V2_Weights.verify(weights)
+
+ if weights is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+ model = MobileNetV2(**kwargs)
+
+ if weights is not None:
+ model.load_state_dict(weights.get_state_dict(progress=progress))
+
+ return model
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class MobileNetv2Wrapper(nn.Module):
+ def __init__(self):
+ super(MobileNetv2Wrapper, self).__init__()
+ weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)
+
+ self.model = MobileNetV2()
+
+ if weights is not None:
+ self.model.load_state_dict(weights.get_state_dict(progress=True))
+ self.out3 = conv1x1(1280, 128)
+
+ def forward(self, x):
+ # print(x.shape)
+ out_layers = self.model(x)
+ # print(x.shape)
+
+ # out_layers[0] = self.out1(out_layers[0])
+ # out_layers[1] = self.out2(out_layers[1])
+ out_layers[2] = self.out3(out_layers[2])
+ # print(x.shape)
+ return out_layers
+
+
diff --git a/models/registry.py b/models/registry.py
new file mode 100644
index 0000000..7b65eff
--- /dev/null
+++ b/models/registry.py
@@ -0,0 +1,16 @@
+import torch.nn as nn
+
+from utils import Registry, build_from_cfg
+
+NET = Registry('net')
+
+def build(cfg, registry, default_args=None):
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return nn.Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+def build_net(cfg):
+ return build(cfg.net, NET, default_args=dict(cfg=cfg))
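+
+# Usage sketch (editor's addition, file/attribute names below are hypothetical): a config
+# whose ``net`` entry names a registered class, e.g. ``cfg.net = dict(type='RESANet')``,
+# is instantiated through build_from_cfg; RESANet registers itself with
+# ``@NET.register_module`` in models/resa.py.
+#
+#     cfg = Config.fromfile('configs/culane_mobilenetv2.py')  # hypothetical config file
+#     net = build_net(cfg)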
diff --git a/models/resa.py b/models/resa.py
new file mode 100644
index 0000000..93a0f30
--- /dev/null
+++ b/models/resa.py
@@ -0,0 +1,142 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from models.registry import NET
+# from .resnet_copy import ResNetWrapper
+# from .resnet import ResNetWrapper
+from .decoder_copy2 import BUSD, PlainDecoder
+# from .decoder import BUSD, PlainDecoder
+# from .mobilenetv2 import MobileNetv2Wrapper
+from .mobilenetv2_copy2 import MobileNetv2Wrapper
+
+
+class RESA(nn.Module):
+ def __init__(self, cfg):
+ super(RESA, self).__init__()
+ self.iter = cfg.resa.iter
+ chan = cfg.resa.input_channel
+ fea_stride = cfg.backbone.fea_stride
+ self.height = cfg.img_height // fea_stride
+ self.width = cfg.img_width // fea_stride
+ self.alpha = cfg.resa.alpha
+ conv_stride = cfg.resa.conv_stride
+
+ for i in range(self.iter):
+ conv_vert1 = nn.Conv2d(
+ chan, chan, (1, conv_stride),
+ padding=(0, conv_stride//2), groups=1, bias=False)
+ conv_vert2 = nn.Conv2d(
+ chan, chan, (1, conv_stride),
+ padding=(0, conv_stride//2), groups=1, bias=False)
+
+ setattr(self, 'conv_d'+str(i), conv_vert1)
+ setattr(self, 'conv_u'+str(i), conv_vert2)
+
+ conv_hori1 = nn.Conv2d(
+ chan, chan, (conv_stride, 1),
+ padding=(conv_stride//2, 0), groups=1, bias=False)
+ conv_hori2 = nn.Conv2d(
+ chan, chan, (conv_stride, 1),
+ padding=(conv_stride//2, 0), groups=1, bias=False)
+
+ setattr(self, 'conv_r'+str(i), conv_hori1)
+ setattr(self, 'conv_l'+str(i), conv_hori2)
+
+ idx_d = (torch.arange(self.height) + self.height //
+ 2**(self.iter - i)) % self.height
+ setattr(self, 'idx_d'+str(i), idx_d)
+
+ idx_u = (torch.arange(self.height) - self.height //
+ 2**(self.iter - i)) % self.height
+ setattr(self, 'idx_u'+str(i), idx_u)
+
+ idx_r = (torch.arange(self.width) + self.width //
+ 2**(self.iter - i)) % self.width
+ setattr(self, 'idx_r'+str(i), idx_r)
+
+ idx_l = (torch.arange(self.width) - self.width //
+ 2**(self.iter - i)) % self.width
+ setattr(self, 'idx_l'+str(i), idx_l)
+
+ def forward(self, x):
+ x = x.clone()
+
+ for direction in ['d', 'u']:
+ for i in range(self.iter):
+ conv = getattr(self, 'conv_' + direction + str(i))
+ idx = getattr(self, 'idx_' + direction + str(i))
+ x.add_(self.alpha * F.relu(conv(x[..., idx, :])))
+
+ for direction in ['r', 'l']:
+ for i in range(self.iter):
+ conv = getattr(self, 'conv_' + direction + str(i))
+ idx = getattr(self, 'idx_' + direction + str(i))
+ x.add_(self.alpha * F.relu(conv(x[..., idx])))
+
+ return x
+
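+# Shift-size sketch (editor's addition): in iteration i the rows/columns are rolled by
+# height // 2**(iter - i), so information first travels in small hops and then in hops of
+# up to half the feature map. For a 36-row map and 4 iterations (plausible CULane
+# settings, assumed here) the row shifts are:
+def _demo_resa_shifts(height: int = 36, iters: int = 4) -> list:
+    return [height // 2 ** (iters - i) for i in range(iters)]  # -> [2, 4, 9, 18]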
+
+
+class ExistHead(nn.Module):
+ def __init__(self, cfg=None):
+ super(ExistHead, self).__init__()
+ self.cfg = cfg
+
+ self.dropout = nn.Dropout2d(0.1) # ???
+ self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)
+
+ stride = cfg.backbone.fea_stride * 2
+ self.fc9 = nn.Linear(
+ int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
+ self.fc10 = nn.Linear(128, cfg.num_classes-1)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.conv8(x)
+
+ x = F.softmax(x, dim=1)
+ x = F.avg_pool2d(x, 2, stride=2, padding=0)
+ x = x.view(-1, x.numel() // x.shape[0])
+ x = self.fc9(x)
+ x = F.relu(x)
+ x = self.fc10(x)
+ x = torch.sigmoid(x)
+
+ return x
+
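+# Note (editor's comment): conv8 maps the 128-channel RESA features to num_classes
+# channels and the stride-2 average pooling halves the spatial size once more, so the
+# flattened input to fc9 has num_classes * (img_width // (fea_stride * 2)) *
+# (img_height // (fea_stride * 2)) elements -- the expression used to size the layer above.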
+
+@NET.register_module
+class RESANet(nn.Module):
+ def __init__(self, cfg):
+ super(RESANet, self).__init__()
+ self.cfg = cfg
+ # self.backbone = ResNetWrapper(resnet='resnet34',pretrained=True,
+ # replace_stride_with_dilation=[False, False, False],
+ # out_conv=False)
+ self.backbone = MobileNetv2Wrapper()
+ self.resa = RESA(cfg)
+ self.decoder = eval(cfg.decoder)(cfg)
+ self.heads = ExistHead(cfg)
+
+ def forward(self, batch):
+ # x1, fea, _, _ = self.backbone(batch)
+ # fea = self.resa(fea)
+ # # print(fea.shape)
+ # seg = self.decoder([x1,fea])
+ # # print(seg.shape)
+ # exist = self.heads(fea)
+
+ fea1,fea2,fea = self.backbone(batch)
+ # print('fea1',fea1.shape)
+ # print('fea2',fea2.shape)
+ # print('fea',fea.shape)
+ fea = self.resa(fea)
+ # print(fea.shape)
+ seg = self.decoder([fea1,fea2,fea])
+ # print(seg.shape)
+ exist = self.heads(fea)
+
+ output = {'seg': seg, 'exist': exist}
+
+ return output
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000..3f10cfc
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,377 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.hub import load_state_dict_from_url
+
+
+# This code is borrowed from torchvision.
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError(
+ 'BasicBlock only supports groups=1 and base_width=64')
+ # if dilation > 1:
+ # raise NotImplementedError(
+ # "Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, dilation=dilation)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetWrapper(nn.Module):
+
+ def __init__(self, cfg):
+ super(ResNetWrapper, self).__init__()
+ self.cfg = cfg
+ self.in_channels = [64, 128, 256, 512]
+ if 'in_channels' in cfg.backbone:
+ self.in_channels = cfg.backbone.in_channels
+ self.model = eval(cfg.backbone.resnet)(
+ pretrained=cfg.backbone.pretrained,
+ replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
+ self.out = None
+ if cfg.backbone.out_conv:
+ out_channel = 512
+ for chan in reversed(self.in_channels):
+ if chan < 0: continue
+ out_channel = chan
+ break
+ self.out = conv1x1(
+ out_channel * self.model.expansion, 128)
+
+ def forward(self, x):
+ x = self.model(x)
+ if self.out:
+ x = self.out(x)
+ return x
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, in_channels=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.in_channels = in_channels
+ self.layer1 = self._make_layer(block, in_channels[0], layers[0])
+ self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ if in_channels[3] > 0:
+ self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.expansion = block.expansion
+
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ if self.in_channels[3] > 0:
+ x = self.layer4(x)
+
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ # x = self.fc(x)
+
+ return x
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/models/resnet_copy.py b/models/resnet_copy.py
new file mode 100644
index 0000000..4290e59
--- /dev/null
+++ b/models/resnet_copy.py
@@ -0,0 +1,432 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.hub import load_state_dict_from_url
+
+
+
+model_urls = {
+ 'resnet18':
+ 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34':
+ 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50':
+ 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101':
+ 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152':
+ 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d':
+ 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d':
+ 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2':
+ 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2':
+ 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError(
+ 'BasicBlock only supports groups=1 and base_width=64')
+ # if dilation > 1:
+ # raise NotImplementedError(
+ # "Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, dilation=dilation)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+
+class ResNetWrapper(nn.Module):
+ def __init__(self,
+ resnet='resnet18',
+ pretrained=True,
+ replace_stride_with_dilation=[False, False, False],
+ out_conv=False,
+ fea_stride=8,
+ out_channel=128,
+ in_channels=[64, 128, 256, 512],
+ cfg=None):
+ super(ResNetWrapper, self).__init__()
+ self.cfg = cfg
+ self.in_channels = in_channels
+
+ self.model = eval(resnet)(
+ pretrained=pretrained,
+ replace_stride_with_dilation=replace_stride_with_dilation,
+ in_channels=self.in_channels)
+ self.out = None
+ if out_conv:
+ out_channel = 512
+ for chan in reversed(self.in_channels):
+ if chan < 0: continue
+ out_channel = chan
+ break
+ self.out = conv1x1(out_channel * self.model.expansion,
+ cfg.featuremap_out_channel)
+
+ def forward(self, x):
+ x = self.model(x)
+ if self.out:
+ x[-1] = self.out(x[-1])
+ return x
+
+
+class ResNet(nn.Module):
+ def __init__(self,
+ block,
+ layers,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ norm_layer=None,
+ in_channels=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(
+ replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3,
+ self.inplanes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.in_channels = in_channels
+ self.layer1 = self._make_layer(block, in_channels[0], layers[0])
+ self.layer2 = self._make_layer(block,
+ in_channels[1],
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ in_channels[2],
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ if in_channels[3] > 0:
+ self.layer4 = self._make_layer(
+ block,
+ in_channels[3],
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.expansion = block.expansion
+
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight,
+ mode='fan_out',
+ nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out_layers = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ # out_layers = []
+ for name in ['layer1', 'layer2', 'layer3', 'layer4']:
+ if not hasattr(self, name):
+ continue
+ layer = getattr(self, name)
+ x = layer(x)
+ out_layers.append(x)
+
+ return out_layers
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ print('pretrained model: ', model_urls[arch])
+ # state_dict = torch.load(model_urls[arch])['net']
+ state_dict = load_state_dict_from_url(model_urls[arch])
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
+ progress, **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
+ progress, **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
+ progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
+ progress, **kwargs)
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000..92c0165
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,8 @@
+pandas
+addict
+scikit-learn
+opencv-python
+pytorch_warmup
+scikit-image
+tqdm
+termcolor
diff --git a/runner/__init__.py b/runner/__init__.py
new file mode 100644
index 0000000..1644223
--- /dev/null
+++ b/runner/__init__.py
@@ -0,0 +1,4 @@
+from .evaluator import *
+from .resa_trainer import *
+
+from .registry import build_evaluator
diff --git a/runner/evaluator/__init__.py b/runner/evaluator/__init__.py
new file mode 100644
index 0000000..308528c
--- /dev/null
+++ b/runner/evaluator/__init__.py
@@ -0,0 +1,2 @@
+from .tusimple.tusimple import Tusimple
+from .culane.culane import CULane
diff --git a/runner/evaluator/culane/culane.py b/runner/evaluator/culane/culane.py
new file mode 100644
index 0000000..534a87d
--- /dev/null
+++ b/runner/evaluator/culane/culane.py
@@ -0,0 +1,158 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from runner.logger import get_logger
+
+from runner.registry import EVALUATOR
+import json
+import os
+import subprocess
+from shutil import rmtree
+import cv2
+import numpy as np
+
+def check():
+ import sys
+ FNULL = open(os.devnull, 'w')
+ result = subprocess.call(
+ './runner/evaluator/culane/lane_evaluation/evaluate', stdout=FNULL, stderr=FNULL)
+ if result > 1:
+ print('There is something wrong with evaluate tool, please compile it.')
+ sys.exit()
+
+def read_helper(path):
+ lines = open(path, 'r').readlines()[1:]
+ lines = ' '.join(lines)
+ values = lines.split(' ')[1::2]
+ keys = lines.split(' ')[0::2]
+ keys = [key[:-1] for key in keys]
+ res = {k : v for k,v in zip(keys,values)}
+ return res
+
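+# Format sketch (editor's addition): read_helper expects the evaluator to write
+# space-separated "key: value" pairs after a header line, e.g. (numbers made up):
+#
+#   file: out0_normal.txt
+#   tp: 2839 fp: 315 fn: 402 precision: 0.9001 recall: 0.8759 Fmeasure: 0.8878
+#
+# and returns them as a dict of strings, {'tp': '2839', 'fp': '315', ...}.
+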
+def call_culane_eval(data_dir, output_path='./output'):
+ if data_dir[-1] != '/':
+ data_dir = data_dir + '/'
+ detect_dir=os.path.join(output_path, 'lines')+'/'
+
+ w_lane=30
+ iou=0.5 # set iou to 0.3 or 0.5
+ im_w=1640
+ im_h=590
+ frame=1
+ list0 = os.path.join(data_dir,'list/test_split/test0_normal.txt')
+ list1 = os.path.join(data_dir,'list/test_split/test1_crowd.txt')
+ list2 = os.path.join(data_dir,'list/test_split/test2_hlight.txt')
+ list3 = os.path.join(data_dir,'list/test_split/test3_shadow.txt')
+ list4 = os.path.join(data_dir,'list/test_split/test4_noline.txt')
+ list5 = os.path.join(data_dir,'list/test_split/test5_arrow.txt')
+ list6 = os.path.join(data_dir,'list/test_split/test6_curve.txt')
+ list7 = os.path.join(data_dir,'list/test_split/test7_cross.txt')
+ list8 = os.path.join(data_dir,'list/test_split/test8_night.txt')
+ if not os.path.exists(os.path.join(output_path,'txt')):
+ os.mkdir(os.path.join(output_path,'txt'))
+ out0 = os.path.join(output_path,'txt','out0_normal.txt')
+ out1 = os.path.join(output_path,'txt','out1_crowd.txt')
+ out2 = os.path.join(output_path,'txt','out2_hlight.txt')
+ out3 = os.path.join(output_path,'txt','out3_shadow.txt')
+ out4 = os.path.join(output_path,'txt','out4_noline.txt')
+ out5 = os.path.join(output_path,'txt','out5_arrow.txt')
+ out6 = os.path.join(output_path,'txt','out6_curve.txt')
+ out7 = os.path.join(output_path,'txt','out7_cross.txt')
+ out8 = os.path.join(output_path,'txt','out8_night.txt')
+
+ eval_cmd = './runner/evaluator/culane/lane_evaluation/evaluate'
+
+ # run the evaluation binary once per test split
+ lists = [list0, list1, list2, list3, list4, list5, list6, list7, list8]
+ outs = [out0, out1, out2, out3, out4, out5, out6, out7, out8]
+ for list_i, out_i in zip(lists, outs):
+ os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s'%(eval_cmd,data_dir,detect_dir,data_dir,list_i,w_lane,iou,im_w,im_h,frame,out_i))
+ res_all = {}
+ res_all['normal'] = read_helper(out0)
+ res_all['crowd']= read_helper(out1)
+ res_all['night']= read_helper(out8)
+ res_all['noline'] = read_helper(out4)
+ res_all['shadow'] = read_helper(out3)
+ res_all['arrow']= read_helper(out5)
+ res_all['hlight'] = read_helper(out2)
+ res_all['curve']= read_helper(out6)
+ res_all['cross']= read_helper(out7)
+ return res_all
+
+@EVALUATOR.register_module
+class CULane(nn.Module):
+ def __init__(self, cfg):
+ super(CULane, self).__init__()
+ # Firstly, check the evaluation tool
+ check()
+ self.cfg = cfg
+ self.blur = torch.nn.Conv2d(
+ 5, 5, 9, padding=4, bias=False, groups=5).cuda()
+ torch.nn.init.constant_(self.blur.weight, 1 / 81)
+ self.logger = get_logger('resa')
+ self.out_dir = os.path.join(self.cfg.work_dir, 'lines')
+ if cfg.view:
+ self.view_dir = os.path.join(self.cfg.work_dir, 'vis')
+
+ def evaluate(self, dataset, output, batch):
+ seg, exists = output['seg'], output['exist']
+ predictmaps = F.softmax(seg, dim=1).cpu().numpy()
+ exists = exists.cpu().numpy()
+ batch_size = seg.size(0)
+ img_name = batch['meta']['img_name']
+ img_path = batch['meta']['full_img_path']
+ for i in range(batch_size):
+ coords = dataset.probmap2lane(predictmaps[i], exists[i])
+ outname = self.out_dir + img_name[i][:-4] + '.lines.txt'
+ outdir = os.path.dirname(outname)
+ if not os.path.exists(outdir):
+ os.makedirs(outdir)
+ f = open(outname, 'w')
+ for coord in coords:
+ for x, y in coord:
+ if x < 0 and y < 0:
+ continue
+ f.write('%d %d ' % (x, y))
+ f.write('\n')
+ f.close()
+
+ if self.cfg.view:
+ img = cv2.imread(img_path[i]).astype(np.float32)
+ dataset.view(img, coords, self.view_dir+img_name[i])
+
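+ # Note (editor's comment): each .lines.txt file holds one lane per line as
+ # space-separated "x y" pixel pairs (e.g. "512 590 508 580 ..."), which is the
+ # format the compiled CULane evaluation tool consumes.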
+
+ def summarize(self):
+ self.logger.info('summarize result...')
+ eval_list_path = os.path.join(
+ self.cfg.dataset_path, "list", self.cfg.dataset.val.data_list)
+ #prob2lines(self.prob_dir, self.out_dir, eval_list_path, self.cfg)
+ res = call_culane_eval(self.cfg.dataset_path, output_path=self.cfg.work_dir)
+ TP,FP,FN = 0,0,0
+ out_str = 'Copypaste: '
+ for k, v in res.items():
+ val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
+ val_tp, val_fp, val_fn = int(v['tp']), int(v['fp']), int(v['fn'])
+ val_p, val_r, val_f1 = float(v['precision']), float(v['recall']), float(v['Fmeasure'])
+ TP += val_tp
+ FP += val_fp
+ FN += val_fn
+ self.logger.info(k + ': ' + str(v))
+ out_str += k
+ for metric, value in v.items():
+ out_str += ' ' + str(value).rstrip('\n')
+ out_str += ' '
+ P = TP * 1.0 / (TP + FP + 1e-9)
+ R = TP * 1.0 / (TP + FN + 1e-9)
+ F = 2*P*R/(P + R + 1e-9)
+ overall_result_str = ('Overall Precision: %f Recall: %f F1: %f' % (P, R, F))
+ self.logger.info(overall_result_str)
+ out_str = out_str + overall_result_str
+ self.logger.info(out_str)
+
+ # delete the tmp output
+ rmtree(self.out_dir)
diff --git a/runner/evaluator/culane/lane_evaluation/.gitignore b/runner/evaluator/culane/lane_evaluation/.gitignore
new file mode 100644
index 0000000..b501d98
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/.gitignore
@@ -0,0 +1,2 @@
+build/
+evaluate
diff --git a/runner/evaluator/culane/lane_evaluation/Makefile b/runner/evaluator/culane/lane_evaluation/Makefile
new file mode 100644
index 0000000..d4457b9
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/Makefile
@@ -0,0 +1,50 @@
+PROJECT_NAME:= evaluate
+
+# config ----------------------------------
+OPENCV_VERSION := 3
+
+INCLUDE_DIRS := include
+LIBRARY_DIRS := lib /usr/local/lib
+
+COMMON_FLAGS := -DCPU_ONLY
+CXXFLAGS := -std=c++11 -fopenmp
+LDFLAGS := -fopenmp -Wl,-rpath,./lib
+BUILD_DIR := build
+
+
+# make rules -------------------------------
+CXX ?= g++
+BUILD_DIR ?= ./build
+
+LIBRARIES += opencv_core opencv_highgui opencv_imgproc
+ifeq ($(OPENCV_VERSION), 3)
+ LIBRARIES += opencv_imgcodecs
+endif
+
+CXXFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+LDFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(LIBRARY_DIRS),-L$(includedir)) $(foreach library,$(LIBRARIES),-l$(library))
+SRC_DIRS += $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print)
+CXX_SRCS += $(shell find src/ -name "*.cpp")
+CXX_TARGETS:=$(patsubst %.cpp, $(BUILD_DIR)/%.o, $(CXX_SRCS))
+ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix $(BUILD_DIR)/, $(SRC_DIRS)))
+
+.PHONY: all
+all: $(PROJECT_NAME)
+
+.PHONY: $(ALL_BUILD_DIRS)
+$(ALL_BUILD_DIRS):
+ @mkdir -p $@
+
+$(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS)
+ @echo "CXX" $<
+ @$(CXX) $(CXXFLAGS) -c -o $@ $<
+
+$(PROJECT_NAME): $(CXX_TARGETS)
+ @echo "CXX/LD" $@
+ @$(CXX) -o $@ $^ $(LDFLAGS)
+
+.PHONY: clean
+clean:
+ @rm -rf $(CXX_TARGETS)
+ @rm -rf $(PROJECT_NAME)
+ @rm -rf $(BUILD_DIR)
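Typical build and smoke-test steps for the tool defined by this Makefile (run from `$RESA_ROOT`; overriding `OPENCV_VERSION` on the command line should also work, since command-line variable assignments take precedence over the `:=` default, but editing the Makefile as described in INSTALL.md is the documented route):

```Shell
cd runner/evaluator/culane/lane_evaluation
make                     # builds ./evaluate with OPENCV_VERSION := 3
make clean
make OPENCV_VERSION=2    # rebuild against OpenCV 2 instead
./evaluate -h            # print the option summary
```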
diff --git a/runner/evaluator/culane/lane_evaluation/include/counter.hpp b/runner/evaluator/culane/lane_evaluation/include/counter.hpp
new file mode 100644
index 0000000..430e1d4
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/include/counter.hpp
@@ -0,0 +1,47 @@
+#ifndef COUNTER_HPP
+#define COUNTER_HPP
+
+#include "lane_compare.hpp"
+#include "hungarianGraph.hpp"
+#include <opencv2/core/core.hpp>
+#include <vector>
+#include <tuple>
+#include <iostream>
+#include <algorithm>
+
+using namespace std;
+using namespace cv;
+
+// before coming to use functions of this class, the lanes should resize to im_width and im_height using resize_lane() in lane_compare.hpp
+class Counter
+{
+ public:
+ Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int _lane_width=10):tp(0),fp(0),fn(0){
+ im_width = _im_width;
+ im_height = _im_height;
+ sim_threshold = _iou_threshold;
+ lane_compare = new LaneCompare(_im_width, _im_height, _lane_width, LaneCompare::IOU);
+ };
+ double get_precision(void);
+ double get_recall(void);
+ long getTP(void);
+ long getFP(void);
+ long getFN(void);
+ void setTP(long);
+ void setFP(long);
+ void setFN(long);
+ // direct add tp, fp, tn and fn
+ // first match with hungarian
+    tuple<vector<int>, long, long, long, long> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes);
+    void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2);
+
+ private:
+ double sim_threshold;
+ int im_width;
+ int im_height;
+ long tp;
+ long fp;
+ long fn;
+ LaneCompare *lane_compare;
+};
+#endif
diff --git a/runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp b/runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp
new file mode 100644
index 0000000..40c3ead
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp
@@ -0,0 +1,71 @@
+#ifndef HUNGARIAN_GRAPH_HPP
+#define HUNGARIAN_GRAPH_HPP
+#include <vector>
+using namespace std;
+
+struct pipartiteGraph {
+    vector<vector<double> > mat;
+    vector<bool> leftUsed, rightUsed;
+    vector<double> leftWeight, rightWeight;
+    vector<int> rightMatch, leftMatch;
+ int leftNum, rightNum;
+ bool matchDfs(int u) {
+ leftUsed[u] = true;
+ for (int v = 0; v < rightNum; v++) {
+ if (!rightUsed[v] && fabs(leftWeight[u] + rightWeight[v] - mat[u][v]) < 1e-2) {
+ rightUsed[v] = true;
+ if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) {
+ rightMatch[v] = u;
+ leftMatch[u] = v;
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+ void resize(int leftNum, int rightNum) {
+ this->leftNum = leftNum;
+ this->rightNum = rightNum;
+ leftMatch.resize(leftNum);
+ rightMatch.resize(rightNum);
+ leftUsed.resize(leftNum);
+ rightUsed.resize(rightNum);
+ leftWeight.resize(leftNum);
+ rightWeight.resize(rightNum);
+ mat.resize(leftNum);
+ for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum);
+ }
+ void match() {
+ for (int i = 0; i < leftNum; i++) leftMatch[i] = -1;
+ for (int i = 0; i < rightNum; i++) rightMatch[i] = -1;
+ for (int i = 0; i < rightNum; i++) rightWeight[i] = 0;
+ for (int i = 0; i < leftNum; i++) {
+ leftWeight[i] = -1e5;
+ for (int j = 0; j < rightNum; j++) {
+ if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j];
+ }
+ }
+
+ for (int u = 0; u < leftNum; u++) {
+ while (1) {
+ for (int i = 0; i < leftNum; i++) leftUsed[i] = false;
+ for (int i = 0; i < rightNum; i++) rightUsed[i] = false;
+ if (matchDfs(u)) break;
+ double d = 1e10;
+ for (int i = 0; i < leftNum; i++) {
+ if (leftUsed[i] ) {
+ for (int j = 0; j < rightNum; j++) {
+ if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]);
+ }
+ }
+ }
+ if (d == 1e10) return ;
+ for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d;
+ for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d;
+ }
+ }
+ }
+};
+
+
+#endif // HUNGARIAN_GRAPH_HPP
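`pipartiteGraph` is a hand-rolled Kuhn-Munkres (Hungarian) solver: it finds the maximum-weight matching between annotated and detected lanes given their IoU similarity matrix. For intuition only, the same assignment can be reproduced with SciPy (SciPy is not a dependency of this repo, so this is purely an illustration):

```python
# Maximum-weight bipartite matching on a toy similarity matrix.
import numpy as np
from scipy.optimize import linear_sum_assignment

similarity = np.array([[0.9, 0.1, 0.0],    # rows: annotated lanes
                       [0.2, 0.8, 0.3]])   # cols: detected lanes
rows, cols = linear_sum_assignment(-similarity)  # negate to maximize total IoU
print(dict(zip(rows.tolist(), cols.tolist())))   # {0: 0, 1: 1}
```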
diff --git a/runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp b/runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp
new file mode 100644
index 0000000..02ddfce
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp
@@ -0,0 +1,51 @@
+#ifndef LANE_COMPARE_HPP
+#define LANE_COMPARE_HPP
+
+#include "spline.hpp"
+#include <vector>
+#include <iostream>
+#include <opencv2/core/core.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+
+#if CV_VERSION_EPOCH == 2
+#define OPENCV2
+#elif CV_VERSION_MAJOR == 3
+#define OPENCV3
+#else
+#error Not support this OpenCV version
+#endif
+
+#ifdef OPENCV3
+#include <opencv2/imgcodecs/imgcodecs.hpp>
+#elif defined(OPENCV2)
+#include <opencv2/highgui/highgui.hpp>
+#endif
+
+using namespace std;
+using namespace cv;
+
+class LaneCompare{
+ public:
+ enum CompareMode{
+ IOU,
+ Caltech
+ };
+
+ LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){
+ im_width = _im_width;
+ im_height = _im_height;
+ compare_mode = _compare_mode;
+ lane_width = _lane_width;
+ }
+
+    double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2);
+    void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height);
+ private:
+ CompareMode compare_mode;
+ int im_width;
+ int im_height;
+ int lane_width;
+ Spline splineSolver;
+};
+
+#endif
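In `IOU` mode each lane is rasterized as a polyline of thickness `lane_width` and a pair of lanes is scored by the IoU of the two masks. A rough NumPy/OpenCV sketch of that idea (the 1640x590 resolution and the width of 30 are typical CULane settings assumed here, and the points are invented):

```python
import cv2
import numpy as np

def lane_iou(lane1, lane2, im_w=1640, im_h=590, lane_width=30):
    """IoU of two lanes drawn as thick polylines (illustrative only)."""
    m1 = np.zeros((im_h, im_w), dtype=np.uint8)
    m2 = np.zeros((im_h, im_w), dtype=np.uint8)
    p1 = np.array(lane1, dtype=np.int32).reshape(-1, 1, 2)
    p2 = np.array(lane2, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(m1, [p1], False, 1, lane_width)
    cv2.polylines(m2, [p2], False, 1, lane_width)
    inter = np.logical_and(m1, m2).sum()
    union = np.logical_or(m1, m2).sum()
    return inter / union if union else 0.0

a = [(100, 580), (300, 400), (500, 300)]
b = [(110, 580), (310, 400), (510, 300)]
print(lane_iou(a, b))
```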
diff --git a/runner/evaluator/culane/lane_evaluation/include/spline.hpp b/runner/evaluator/culane/lane_evaluation/include/spline.hpp
new file mode 100644
index 0000000..0ae73ef
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/include/spline.hpp
@@ -0,0 +1,28 @@
+#ifndef SPLINE_HPP
+#define SPLINE_HPP
+#include <vector>
+#include <iostream>
+#include <cmath>
+#include <opencv2/core/core.hpp>
+
+using namespace cv;
+using namespace std;
+
+struct Func {
+ double a_x;
+ double b_x;
+ double c_x;
+ double d_x;
+ double a_y;
+ double b_y;
+ double c_y;
+ double d_y;
+ double h;
+};
+class Spline {
+public:
+    vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times);
+    vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step);
+    vector<Func> cal_fun(const vector<Point2f> &point_v);
+};
+#endif
diff --git a/runner/evaluator/culane/lane_evaluation/src/counter.cpp b/runner/evaluator/culane/lane_evaluation/src/counter.cpp
new file mode 100644
index 0000000..f4fa6a7
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/src/counter.cpp
@@ -0,0 +1,134 @@
+/*************************************************************************
+ > File Name: counter.cpp
+ > Author: Xingang Pan, Jun Li
+ > Mail: px117@ie.cuhk.edu.hk
+ > Created Time: Thu Jul 14 20:23:08 2016
+ ************************************************************************/
+
+#include "counter.hpp"
+
+double Counter::get_precision(void)
+{
+    cerr<<"tp: "<<tp<<" fp: "<<fp<<" fn: "<<fn<<endl;
+    return tp + fp == 0 ? 0 : tp / double(tp + fp);
+}
+
+tuple<vector<int>, long, long, long, long> Counter::count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes)
+{
+    vector<int> anno_match(anno_lanes.size(), -1);
+    vector<int> detect_match;
+ if(anno_lanes.empty())
+ {
+ return make_tuple(anno_match, 0, detect_lanes.size(), 0, 0);
+ }
+
+ if(detect_lanes.empty())
+ {
+ return make_tuple(anno_match, 0, 0, 0, anno_lanes.size());
+ }
+ // hungarian match first
+
+ // first calc similarity matrix
+    vector<vector<double> > similarity(anno_lanes.size(), vector<double>(detect_lanes.size(), 0));
+    for(int i=0; i<anno_lanes.size(); i++)
+    {
+        const vector<Point2f> &curr_anno_lane = anno_lanes[i];
+        for(int j=0; j<detect_lanes.size(); j++)
+        {
+            const vector<Point2f> &curr_detect_lane = detect_lanes[j];
+            similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane);
+ }
+ }
+
+
+
+ makeMatch(similarity, anno_match, detect_match);
+
+
+ int curr_tp = 0;
+ // count and add
+    for(int i=0; i<anno_lanes.size(); i++)
+    {
+        if(anno_match[i]>=0 && similarity[i][anno_match[i]] > sim_threshold)
+ {
+ curr_tp++;
+ }
+ else
+ {
+ anno_match[i] = -1;
+ }
+ }
+ int curr_fn = anno_lanes.size() - curr_tp;
+ int curr_fp = detect_lanes.size() - curr_tp;
+ return make_tuple(anno_match, curr_tp, curr_fp, 0, curr_fn);
+}
+
+
+void Counter::makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2) {
+ int m = similarity.size();
+ int n = similarity[0].size();
+ pipartiteGraph gra;
+ bool have_exchange = false;
+ if (m > n) {
+ have_exchange = true;
+ swap(m, n);
+ }
+ gra.resize(m, n);
+ for (int i = 0; i < gra.leftNum; i++) {
+ for (int j = 0; j < gra.rightNum; j++) {
+ if(have_exchange)
+ gra.mat[i][j] = similarity[j][i];
+ else
+ gra.mat[i][j] = similarity[i][j];
+ }
+ }
+ gra.match();
+ match1 = gra.leftMatch;
+ match2 = gra.rightMatch;
+ if (have_exchange) swap(match1, match2);
+}
diff --git a/runner/evaluator/culane/lane_evaluation/src/evaluate.cpp b/runner/evaluator/culane/lane_evaluation/src/evaluate.cpp
new file mode 100644
index 0000000..ae95bb4
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/src/evaluate.cpp
@@ -0,0 +1,302 @@
+/*************************************************************************
+ > File Name: evaluate.cpp
+ > Author: Xingang Pan, Jun Li
+ > Mail: px117@ie.cuhk.edu.hk
+ > Created Time: Thu Jul 14 18:28:45 2016
+ ************************************************************************/
+
+#include "counter.hpp"
+#include "spline.hpp"
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+#include <cstdlib>
+#include <unistd.h>
+#include <opencv2/opencv.hpp>
+using namespace std;
+using namespace cv;
+
+void help(void) {
+ cout << "./evaluate [OPTIONS]" << endl;
+ cout << "-h : print usage help" << endl;
+ cout << "-a : directory for annotation files (default: "
+ "/data/driving/eval_data/anno_label/)" << endl;
+ cout << "-d : directory for detection files (default: "
+ "/data/driving/eval_data/predict_label/)" << endl;
+ cout << "-i : directory for image files (default: "
+ "/data/driving/eval_data/img/)" << endl;
+ cout << "-l : list of images used for evaluation (default: "
+ "/data/driving/eval_data/img/all.txt)" << endl;
+ cout << "-w : width of the lanes (default: 10)" << endl;
+ cout << "-t : threshold of iou (default: 0.4)" << endl;
+ cout << "-c : cols (max image width) (default: 1920)"
+ << endl;
+ cout << "-r : rows (max image height) (default: 1080)"
+ << endl;
+ cout << "-s : show visualization" << endl;
+ cout << "-f : start frame in the test set (default: 1)"
+ << endl;
+}
+
+void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes);
+void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
+               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
+               int width_lane, string save_path = "");
+
+int main(int argc, char **argv) {
+ // process params
+ string anno_dir = "/data/driving/eval_data/anno_label/";
+ string detect_dir = "/data/driving/eval_data/predict_label/";
+ string im_dir = "/data/driving/eval_data/img/";
+ string list_im_file = "/data/driving/eval_data/img/all.txt";
+ string output_file = "./output.txt";
+ int width_lane = 10;
+ double iou_threshold = 0.4;
+ int im_width = 1920;
+ int im_height = 1080;
+ int oc;
+ bool show = false;
+ int frame = 1;
+ string save_path = "";
+ while ((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:p:")) != -1) {
+ switch (oc) {
+ case 'h':
+ help();
+ return 0;
+ case 'a':
+ anno_dir = optarg;
+ break;
+ case 'd':
+ detect_dir = optarg;
+ break;
+ case 'i':
+ im_dir = optarg;
+ break;
+ case 'l':
+ list_im_file = optarg;
+ break;
+ case 'w':
+ width_lane = atoi(optarg);
+ break;
+ case 't':
+ iou_threshold = atof(optarg);
+ break;
+ case 'c':
+ im_width = atoi(optarg);
+ break;
+ case 'r':
+ im_height = atoi(optarg);
+ break;
+ case 's':
+ show = true;
+ break;
+ case 'p':
+ save_path = optarg;
+ break;
+ case 'f':
+ frame = atoi(optarg);
+ break;
+ case 'o':
+ output_file = optarg;
+ break;
+ }
+ }
+
+ cout << "------------Configuration---------" << endl;
+ cout << "anno_dir: " << anno_dir << endl;
+ cout << "detect_dir: " << detect_dir << endl;
+ cout << "im_dir: " << im_dir << endl;
+ cout << "list_im_file: " << list_im_file << endl;
+ cout << "width_lane: " << width_lane << endl;
+ cout << "iou_threshold: " << iou_threshold << endl;
+ cout << "im_width: " << im_width << endl;
+ cout << "im_height: " << im_height << endl;
+ cout << "-----------------------------------" << endl;
+ cout << "Evaluating the results..." << endl;
+ // this is the max_width and max_height
+
+ if (width_lane < 1) {
+ cerr << "width_lane must be positive" << endl;
+ help();
+ return 1;
+ }
+
+ ifstream ifs_im_list(list_im_file, ios::in);
+ if (ifs_im_list.fail()) {
+ cerr << "Error: file " << list_im_file << " not exist!" << endl;
+ return 1;
+ }
+
+ Counter counter(im_width, im_height, iou_threshold, width_lane);
+
+  vector<int> anno_match;
+ string sub_im_name;
+ // pre-load filelist
+  vector<string> filelists;
+ while (getline(ifs_im_list, sub_im_name)) {
+ filelists.push_back(sub_im_name);
+ }
+ ifs_im_list.close();
+
+  vector<tuple<vector<int>, long, long, long, long>> tuple_lists;
+ tuple_lists.resize(filelists.size());
+
+#pragma omp parallel for
+ for (size_t i = 0; i < filelists.size(); i++) {
+ auto sub_im_name = filelists[i];
+ string full_im_name = im_dir + sub_im_name;
+ string sub_txt_name =
+ sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt";
+ string anno_file_name = anno_dir + sub_txt_name;
+ string detect_file_name = detect_dir + sub_txt_name;
+    vector<vector<Point2f>> anno_lanes;
+    vector<vector<Point2f>> detect_lanes;
+ read_lane_file(anno_file_name, anno_lanes);
+ read_lane_file(detect_file_name, detect_lanes);
+    // cerr << i << ": " << full_im_name << endl;
+    tuple_lists[i] = counter.count_im_pair(anno_lanes, detect_lanes);
+    if (show) {
+      auto anno_match = get<0>(tuple_lists[i]);
+      visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane);
+      waitKey(0);
+    }
+ if (save_path != "") {
+ auto anno_match = get<0>(tuple_lists[i]);
+ visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane,
+ save_path);
+ }
+ }
+
+ long tp = 0, fp = 0, tn = 0, fn = 0;
+ for (auto result : tuple_lists) {
+ tp += get<1>(result);
+ fp += get<2>(result);
+ // tn = get<3>(result);
+ fn += get<4>(result);
+ }
+ counter.setTP(tp);
+ counter.setFP(fp);
+ counter.setFN(fn);
+
+ double precision = counter.get_precision();
+ double recall = counter.get_recall();
+ double F = 2 * precision * recall / (precision + recall);
+ cerr << "finished process file" << endl;
+ cout << "precision: " << precision << endl;
+ cout << "recall: " << recall << endl;
+ cout << "Fmeasure: " << F << endl;
+ cout << "----------------------------------" << endl;
+
+ ofstream ofs_out_file;
+ ofs_out_file.open(output_file, ios::out);
+ ofs_out_file << "file: " << output_file << endl;
+ ofs_out_file << "tp: " << counter.getTP() << " fp: " << counter.getFP()
+ << " fn: " << counter.getFN() << endl;
+ ofs_out_file << "precision: " << precision << endl;
+ ofs_out_file << "recall: " << recall << endl;
+ ofs_out_file << "Fmeasure: " << F << endl << endl;
+ ofs_out_file.close();
+ return 0;
+}
+
+void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes) {
+ lanes.clear();
+ ifstream ifs_lane(file_name, ios::in);
+ if (ifs_lane.fail()) {
+ return;
+ }
+
+ string str_line;
+ while (getline(ifs_lane, str_line)) {
+    vector<Point2f> curr_lane;
+ stringstream ss;
+ ss << str_line;
+ double x, y;
+ while (ss >> x >> y) {
+ curr_lane.push_back(Point2f(x, y));
+ }
+ lanes.push_back(curr_lane);
+ }
+
+ ifs_lane.close();
+}
+
+void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
+               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
+               int width_lane, string save_path) {
+ Mat img = imread(full_im_name, 1);
+ Mat img2 = imread(full_im_name, 1);
+  vector<Point2f> curr_lane;
+  vector<Point2f> p_interp;
+ Spline splineSolver;
+ Scalar color_B = Scalar(255, 0, 0);
+ Scalar color_G = Scalar(0, 255, 0);
+ Scalar color_R = Scalar(0, 0, 255);
+ Scalar color_P = Scalar(255, 0, 255);
+ Scalar color;
+ for (int i = 0; i < anno_lanes.size(); i++) {
+ curr_lane = anno_lanes[i];
+ if (curr_lane.size() == 2) {
+ p_interp = curr_lane;
+ } else {
+ p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
+ }
+ if (anno_match[i] >= 0) {
+ color = color_G;
+ } else {
+ color = color_G;
+ }
+ for (int n = 0; n < p_interp.size() - 1; n++) {
+ line(img, p_interp[n], p_interp[n + 1], color, width_lane);
+ line(img2, p_interp[n], p_interp[n + 1], color, 2);
+ }
+ }
+ bool detected;
+ for (int i = 0; i < detect_lanes.size(); i++) {
+ detected = false;
+ curr_lane = detect_lanes[i];
+ if (curr_lane.size() == 2) {
+ p_interp = curr_lane;
+ } else {
+ p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
+ }
+ for (int n = 0; n < anno_lanes.size(); n++) {
+ if (anno_match[n] == i) {
+ detected = true;
+ break;
+ }
+ }
+ if (detected == true) {
+ color = color_B;
+ } else {
+ color = color_R;
+ }
+ for (int n = 0; n < p_interp.size() - 1; n++) {
+ line(img, p_interp[n], p_interp[n + 1], color, width_lane);
+ line(img2, p_interp[n], p_interp[n + 1], color, 2);
+ }
+ }
+ if (save_path != "") {
+ size_t pos = 0;
+ string s = full_im_name;
+ std::string token;
+ std::string delimiter = "/";
+    vector<string> names;
+ while ((pos = s.find(delimiter)) != std::string::npos) {
+ token = s.substr(0, pos);
+ names.emplace_back(token);
+ s.erase(0, pos + delimiter.length());
+ }
+ names.emplace_back(s);
+ string file_name = names[3] + '_' + names[4] + '_' + names[5];
+ // cout << file_name << endl;
+ imwrite(save_path + '/' + file_name, img);
+ } else {
+ namedWindow("visualize", 1);
+ imshow("visualize", img);
+ namedWindow("visualize2", 1);
+ imshow("visualize2", img2);
+ }
+}
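For reference, an example invocation of the compiled binary, matching the flag layout used by the `call_culane_eval` wrapper earlier in this diff (the paths are placeholders, and 30/0.5/1640/590 are the values commonly used for CULane rather than anything fixed by this file):

```Shell
./evaluate -a $CULANEROOT/ -d $WORK_DIR/lines/ -i $CULANEROOT/ \
           -l $CULANEROOT/list/test_split/test0_normal.txt \
           -w 30 -t 0.5 -c 1640 -r 590 -f 1 -o out0_normal.txt
```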
diff --git a/runner/evaluator/culane/lane_evaluation/src/lane_compare.cpp b/runner/evaluator/culane/lane_evaluation/src/lane_compare.cpp
new file mode 100644
index 0000000..83d08b9
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/src/lane_compare.cpp
@@ -0,0 +1,73 @@
+/*************************************************************************
+ > File Name: lane_compare.cpp
+ > Author: Xingang Pan, Jun Li
+ > Mail: px117@ie.cuhk.edu.hk
+ > Created Time: Fri Jul 15 10:26:32 2016
+ ************************************************************************/
+
+#include "lane_compare.hpp"
+
+double LaneCompare::get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2)
+{
+ if(lane1.size()<2 || lane2.size()<2)
+ {
+        cerr<<"lane size must be greater or equal to 2"<<endl;
+        return 0;
+    }
+    vector<Point2f> p_interp1;
+    vector<Point2f> p_interp2;
+ if(lane1.size() == 2)
+ {
+ p_interp1 = lane1;
+ }
+ else
+ {
+ p_interp1 = splineSolver.splineInterpTimes(lane1, 50);
+ }
+
+ if(lane2.size() == 2)
+ {
+ p_interp2 = lane2;
+ }
+ else
+ {
+ p_interp2 = splineSolver.splineInterpTimes(lane2, 50);
+ }
+
+    // rasterize both lanes as thick polylines and score them by mask IoU
+    Mat im1 = Mat::zeros(im_height, im_width, CV_8UC1);
+    Mat im2 = Mat::zeros(im_height, im_width, CV_8UC1);
+    Scalar color_white = Scalar(1);
+    for(int n=0; n+1<(int)p_interp1.size(); n++)
+        line(im1, p_interp1[n], p_interp1[n+1], color_white, lane_width);
+    for(int n=0; n+1<(int)p_interp2.size(); n++)
+        line(im2, p_interp2[n], p_interp2[n+1], color_white, lane_width);
+    double inter_sum = cv::sum(im1.mul(im2)).val[0];
+    double union_sum = cv::sum(im1).val[0] + cv::sum(im2).val[0] - inter_sum;
+    return union_sum == 0 ? 0 : inter_sum / union_sum;
+}
+
+void LaneCompare::resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height)
+{
+ if(curr_width == im_width && curr_height == im_height)
+ {
+ return;
+ }
+ double x_scale = im_width/(double)curr_width;
+ double y_scale = im_height/(double)curr_height;
+    for(int n=0; n<(int)curr_lane.size(); n++)
+    {
+        curr_lane[n].x *= x_scale;
+        curr_lane[n].y *= y_scale;
+    }
+}
diff --git a/runner/evaluator/culane/lane_evaluation/src/spline.cpp b/runner/evaluator/culane/lane_evaluation/src/spline.cpp
new file mode 100644
--- /dev/null
+++ b/runner/evaluator/culane/lane_evaluation/src/spline.cpp
+#include <iostream>
+#include <cmath>
+#include "spline.hpp"
+using namespace std;
+using namespace cv;
+
+vector<Point2f> Spline::splineInterpTimes(const vector<Point2f>& tmp_line, int times) {
+    vector<Point2f> res;
+
+ if(tmp_line.size() == 2) {
+ double x1 = tmp_line[0].x;
+ double y1 = tmp_line[0].y;
+ double x2 = tmp_line[1].x;
+ double y2 = tmp_line[1].y;
+
+ for (int k = 0; k <= times; k++) {
+ double xi = x1 + double((x2 - x1) * k) / times;
+ double yi = y1 + double((y2 - y1) * k) / times;
+ res.push_back(Point2f(xi, yi));
+ }
+ }
+
+ else if(tmp_line.size() > 2)
+ {
+        vector<Func> tmp_func;
+ tmp_func = this->cal_fun(tmp_line);
+ if (tmp_func.empty()) {
+ cout << "in splineInterpTimes: cal_fun failed" << endl;
+ return res;
+ }
+ for(int j = 0; j < tmp_func.size(); j++)
+ {
+ double delta = tmp_func[j].h / times;
+ for(int k = 0; k < times; k++)
+ {
+ double t1 = delta*k;
+ double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
+ double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
+ res.push_back(Point2f(x1, y1));
+ }
+ }
+ res.push_back(tmp_line[tmp_line.size() - 1]);
+ }
+ else {
+ cerr << "in splineInterpTimes: not enough points" << endl;
+ }
+ return res;
+}
+vector<Point2f> Spline::splineInterpStep(vector<Point2f> tmp_line, double step) {
+    vector<Point2f> res;
+ /*
+ if (tmp_line.size() == 2) {
+ double x1 = tmp_line[0].x;
+ double y1 = tmp_line[0].y;
+ double x2 = tmp_line[1].x;
+ double y2 = tmp_line[1].y;
+
+ for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) {
+ double xi;
+ if (yi == y1) xi = x1;
+ else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1;
+ res.push_back(Point2f(xi, yi));
+ }
+ }*/
+ if (tmp_line.size() == 2) {
+ double x1 = tmp_line[0].x;
+ double y1 = tmp_line[0].y;
+ double x2 = tmp_line[1].x;
+ double y2 = tmp_line[1].y;
+ tmp_line[1].x = (x1 + x2) / 2;
+ tmp_line[1].y = (y1 + y2) / 2;
+ tmp_line.push_back(Point2f(x2, y2));
+ }
+ if (tmp_line.size() > 2) {
+        vector<Func> tmp_func;
+ tmp_func = this->cal_fun(tmp_line);
+ double ystart = tmp_line[0].y;
+ double yend = tmp_line[tmp_line.size() - 1].y;
+ bool down;
+ if (ystart < yend) down = 1;
+ else down = 0;
+ if (tmp_func.empty()) {
+ cerr << "in splineInterpStep: cal_fun failed" << endl;
+ }
+
+ for(int j = 0; j < tmp_func.size(); j++)
+ {
+ for(double t1 = 0; t1 < tmp_func[j].h; t1 += step)
+ {
+ double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
+ double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
+ res.push_back(Point2f(x1, y1));
+ }
+ }
+ res.push_back(tmp_line[tmp_line.size() - 1]);
+ }
+ else {
+ cerr << "in splineInterpStep: not enough points" << endl;
+ }
+ return res;
+}
+
+vector<Func> Spline::cal_fun(const vector<Point2f> &point_v)
+{
+    vector<Func> func_v;
+ int n = point_v.size();
+ if(n<=2) {
+ cout << "in cal_fun: point number less than 3" << endl;
+ return func_v;
+ }
+
+ func_v.resize(point_v.size()-1);
+
+    vector<double> Mx(n);
+    vector<double> My(n);
+    vector<double> A(n-2);
+    vector<double> B(n-2);
+    vector<double> C(n-2);
+    vector<double> Dx(n-2);
+    vector<double> Dy(n-2);
+    vector<double> h(n-1);
+    //vector<Func> func_v(n-1);
+
+ for(int i = 0; i < n-1; i++)
+ {
+ h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2));
+ }
+
+ for(int i = 0; i < n-2; i++)
+ {
+ A[i] = h[i];
+ B[i] = 2*(h[i]+h[i+1]);
+ C[i] = h[i+1];
+
+ Dx[i] = 6*( (point_v[i+2].x - point_v[i+1].x)/h[i+1] - (point_v[i+1].x - point_v[i].x)/h[i] );
+ Dy[i] = 6*( (point_v[i+2].y - point_v[i+1].y)/h[i+1] - (point_v[i+1].y - point_v[i].y)/h[i] );
+ }
+
+ //TDMA
+ C[0] = C[0] / B[0];
+ Dx[0] = Dx[0] / B[0];
+ Dy[0] = Dy[0] / B[0];
+ for(int i = 1; i < n-2; i++)
+ {
+ double tmp = B[i] - A[i]*C[i-1];
+ C[i] = C[i] / tmp;
+ Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp;
+ Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp;
+ }
+ Mx[n-2] = Dx[n-3];
+ My[n-2] = Dy[n-3];
+ for(int i = n-4; i >= 0; i--)
+ {
+ Mx[i+1] = Dx[i] - C[i]*Mx[i+2];
+ My[i+1] = Dy[i] - C[i]*My[i+2];
+ }
+
+ Mx[0] = 0;
+ Mx[n-1] = 0;
+ My[0] = 0;
+ My[n-1] = 0;
+
+ for(int i = 0; i < n-1; i++)
+ {
+ func_v[i].a_x = point_v[i].x;
+ func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6;
+ func_v[i].c_x = Mx[i]/2;
+ func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]);
+
+ func_v[i].a_y = point_v[i].y;
+ func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6;
+ func_v[i].c_y = My[i]/2;
+ func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]);
+
+ func_v[i].h = h[i];
+ }
+ return func_v;
+}
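`cal_fun` fits a natural cubic spline (zero second derivative at both ends) parameterized by chord length and solves the tridiagonal system with the Thomas algorithm (TDMA). A SciPy-based sketch of the same interpolation, handy for sanity-checking the C++ output (SciPy is not a repo dependency, and boundary handling may differ slightly from the hand-rolled solver):

```python
import numpy as np
from scipy.interpolate import CubicSpline

pts = np.array([(100.0, 580.0), (220.0, 470.0), (360.0, 380.0), (520.0, 300.0)])
h = np.sqrt(((pts[1:] - pts[:-1]) ** 2).sum(axis=1))   # chord lengths, as in cal_fun
t = np.concatenate(([0.0], np.cumsum(h)))              # cumulative parameter
cs_x = CubicSpline(t, pts[:, 0], bc_type='natural')    # natural ends: M0 = Mn = 0
cs_y = CubicSpline(t, pts[:, 1], bc_type='natural')
ti = np.linspace(0.0, t[-1], 50 * (len(pts) - 1) + 1)  # ~"times" samples per segment
curve = np.stack([cs_x(ti), cs_y(ti)], axis=1)
print(curve[:3])
```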
diff --git a/runner/evaluator/culane/prob2lines.py b/runner/evaluator/culane/prob2lines.py
new file mode 100644
index 0000000..f23caed
--- /dev/null
+++ b/runner/evaluator/culane/prob2lines.py
@@ -0,0 +1,51 @@
+import os
+import argparse
+import numpy as np
+import pandas as pd
+from PIL import Image
+import tqdm
+
+
+def getLane(probmap, pts, cfg = None):
+ thr = 0.3
+ coordinate = np.zeros(pts)
+ cut_height = 0
+ if cfg.cut_height:
+ cut_height = cfg.cut_height
+ for i in range(pts):
+ line = probmap[round(cfg.img_height-i*20/(590-cut_height)*cfg.img_height)-1]
+ if np.max(line)/255 > thr:
+ coordinate[i] = np.argmax(line)+1
+ if np.sum(coordinate > 0) < 2:
+ coordinate = np.zeros(pts)
+ return coordinate
+
+
+def prob2lines(prob_dir, out_dir, list_file, cfg = None):
+ lists = pd.read_csv(list_file, sep=' ', header=None,
+ names=('img', 'probmap', 'label1', 'label2', 'label3', 'label4'))
+ pts = 18
+
+ for k, im in enumerate(lists['img'], 1):
+ existPath = prob_dir + im[:-4] + '.exist.txt'
+ outname = out_dir + im[:-4] + '.lines.txt'
+ prefix = '/'.join(outname.split('/')[:-1])
+ if not os.path.exists(prefix):
+ os.makedirs(prefix)
+ f = open(outname, 'w')
+
+ labels = list(pd.read_csv(existPath, sep=' ', header=None).iloc[0])
+ coordinates = np.zeros((4, pts))
+ for i in range(4):
+ if labels[i] == 1:
+ probfile = prob_dir + im[:-4] + '_{0}_avg.png'.format(i+1)
+ probmap = np.array(Image.open(probfile))
+ coordinates[i] = getLane(probmap, pts, cfg)
+
+ if np.sum(coordinates[i] > 0) > 1:
+ for idx, value in enumerate(coordinates[i]):
+ if value > 0:
+ f.write('%d %d ' % (
+ round(value*1640/cfg.img_width)-1, round(590-idx*20)-1))
+ f.write('\n')
+ f.close()
diff --git a/runner/evaluator/tusimple/getLane.py b/runner/evaluator/tusimple/getLane.py
new file mode 100644
index 0000000..8026127
--- /dev/null
+++ b/runner/evaluator/tusimple/getLane.py
@@ -0,0 +1,115 @@
+import cv2
+import numpy as np
+
+def isShort(lane):
+ start = [i for i, x in enumerate(lane) if x > 0]
+ if not start:
+ return 1
+ else:
+ return 0
+
+def fixGap(coordinate):
+ if any(x > 0 for x in coordinate):
+ start = [i for i, x in enumerate(coordinate) if x > 0][0]
+ end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
+ lane = coordinate[start:end+1]
+ if any(x < 0 for x in lane):
+ gap_start = [i for i, x in enumerate(
+ lane[:-1]) if x > 0 and lane[i+1] < 0]
+ gap_end = [i+1 for i,
+ x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
+ gap_id = [i for i, x in enumerate(lane) if x < 0]
+ if len(gap_start) == 0 or len(gap_end) == 0:
+ return coordinate
+ for id in gap_id:
+ for i in range(len(gap_start)):
+ if i >= len(gap_end):
+ return coordinate
+ if id > gap_start[i] and id < gap_end[i]:
+ gap_width = float(gap_end[i] - gap_start[i])
+ lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
+ gap_end[i] - id) / gap_width * lane[gap_start[i]])
+ if not all(x > 0 for x in lane):
+ print("Gaps still exist!")
+ coordinate[start:end+1] = lane
+ return coordinate
+
+def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None):
+ """
+ Arguments:
+ ----------
+ prob_map: prob map for single lane, np array size (h, w)
+ resize_shape: reshape size target, (H, W)
+
+ Return:
+ ----------
+ coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
+ """
+ if resize_shape is None:
+ resize_shape = prob_map.shape
+ h, w = prob_map.shape
+ H, W = resize_shape
+ H -= cfg.cut_height
+
+ coords = np.zeros(pts)
+ coords[:] = -1.0
+ for i in range(pts):
+ y = int((H - 10 - i * y_px_gap) * h / H)
+ if y < 0:
+ break
+ line = prob_map[y, :]
+ id = np.argmax(line)
+ if line[id] > thresh:
+ coords[i] = int(id / w * W)
+ if (coords > 0).sum() < 2:
+ coords = np.zeros(pts)
+ fixGap(coords)
+ return coords
+
+
+def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3, cfg=None):
+ """
+ Arguments:
+ ----------
+ seg_pred: np.array size (5, h, w)
+ resize_shape: reshape size target, (H, W)
+ exist: list of existence, e.g. [0, 1, 1, 0]
+ smooth: whether to smooth the probability or not
+ y_px_gap: y pixel gap for sampling
+ pts: how many points for one lane
+ thresh: probability threshold
+
+ Return:
+ ----------
+ coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
+ """
+ if resize_shape is None:
+ resize_shape = seg_pred.shape[1:] # seg_pred (5, h, w)
+ _, h, w = seg_pred.shape
+ H, W = resize_shape
+ coordinates = []
+
+ if pts is None:
+ pts = round(H / 2 / y_px_gap)
+
+ seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
+ for i in range(cfg.num_classes - 1):
+ prob_map = seg_pred[..., i + 1]
+ if smooth:
+ prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
+ coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg)
+ if isShort(coords):
+ continue
+ coordinates.append(
+ [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
+ range(pts)])
+
+
+ if len(coordinates) == 0:
+ coords = np.zeros(pts)
+ coordinates.append(
+ [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
+ range(pts)])
+
+
+ return coordinates
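A minimal way to exercise `prob2lines_tusimple` on dummy data, assuming the repo root is on `PYTHONPATH` (the cfg object is mocked with only the two fields this file reads, and the tensor shapes are illustrative):

```python
import numpy as np
from types import SimpleNamespace
from runner.evaluator.tusimple.getLane import prob2lines_tusimple

cfg = SimpleNamespace(cut_height=0, num_classes=5)         # assumed config values
seg_pred = np.random.rand(5, 288, 800).astype(np.float32)  # (classes, h, w)
exist = [1, 1, 0, 0]
coords = prob2lines_tusimple(seg_pred, exist, resize_shape=(720, 1280), cfg=cfg)
print(len(coords), coords[0][:3])  # number of lanes and the first few (x, y) pairs
```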
diff --git a/runner/evaluator/tusimple/lane.py b/runner/evaluator/tusimple/lane.py
new file mode 100644
index 0000000..44abf70
--- /dev/null
+++ b/runner/evaluator/tusimple/lane.py
@@ -0,0 +1,108 @@
+import numpy as np
+from sklearn.linear_model import LinearRegression
+import json as json
+
+
+class LaneEval(object):
+ lr = LinearRegression()
+ pixel_thresh = 20
+ pt_thresh = 0.85
+
+ @staticmethod
+ def get_angle(xs, y_samples):
+ xs, ys = xs[xs >= 0], y_samples[xs >= 0]
+ if len(xs) > 1:
+ LaneEval.lr.fit(ys[:, None], xs)
+ k = LaneEval.lr.coef_[0]
+ theta = np.arctan(k)
+ else:
+ theta = 0
+ return theta
+
+ @staticmethod
+ def line_accuracy(pred, gt, thresh):
+ pred = np.array([p if p >= 0 else -100 for p in pred])
+ gt = np.array([g if g >= 0 else -100 for g in gt])
+ return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)
+
+ @staticmethod
+ def bench(pred, gt, y_samples, running_time):
+ if any(len(p) != len(y_samples) for p in pred):
+ raise Exception('Format of lanes error.')
+ if running_time > 200 or len(gt) + 2 < len(pred):
+ return 0., 0., 1.
+ angles = [LaneEval.get_angle(
+ np.array(x_gts), np.array(y_samples)) for x_gts in gt]
+ threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
+ line_accs = []
+ fp, fn = 0., 0.
+ matched = 0.
+ for x_gts, thresh in zip(gt, threshs):
+ accs = [LaneEval.line_accuracy(
+ np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
+ max_acc = np.max(accs) if len(accs) > 0 else 0.
+ if max_acc < LaneEval.pt_thresh:
+ fn += 1
+ else:
+ matched += 1
+ line_accs.append(max_acc)
+ fp = len(pred) - matched
+ if len(gt) > 4 and fn > 0:
+ fn -= 1
+ s = sum(line_accs)
+ if len(gt) > 4:
+ s -= min(line_accs)
+ return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)
+
+ @staticmethod
+ def bench_one_submit(pred_file, gt_file):
+ try:
+ json_pred = [json.loads(line)
+ for line in open(pred_file).readlines()]
+ except BaseException as e:
+ raise Exception('Fail to load json file of the prediction.')
+ json_gt = [json.loads(line) for line in open(gt_file).readlines()]
+ if len(json_gt) != len(json_pred):
+ raise Exception(
+ 'We do not get the predictions of all the test tasks')
+ gts = {l['raw_file']: l for l in json_gt}
+ accuracy, fp, fn = 0., 0., 0.
+ for pred in json_pred:
+ if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
+ raise Exception(
+ 'raw_file or lanes or run_time not in some predictions.')
+ raw_file = pred['raw_file']
+ pred_lanes = pred['lanes']
+ run_time = pred['run_time']
+ if raw_file not in gts:
+ raise Exception(
+ 'Some raw_file from your predictions do not exist in the test tasks.')
+ gt = gts[raw_file]
+ gt_lanes = gt['lanes']
+ y_samples = gt['h_samples']
+ try:
+ a, p, n = LaneEval.bench(
+ pred_lanes, gt_lanes, y_samples, run_time)
+ except BaseException as e:
+ raise Exception('Format of lanes error.')
+ accuracy += a
+ fp += p
+ fn += n
+ num = len(gts)
+ # the first return parameter is the default ranking parameter
+ return json.dumps([
+ {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
+ {'name': 'FP', 'value': fp / num, 'order': 'asc'},
+ {'name': 'FN', 'value': fn / num, 'order': 'asc'}
+ ]), accuracy / num
+
+
+if __name__ == '__main__':
+ import sys
+ try:
+ if len(sys.argv) != 3:
+ raise Exception('Invalid input arguments')
+ print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
+ except Exception as e:
+        print(e)
+        sys.exit(str(e))
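`LaneEval.bench` scores each ground-truth lane by the fraction of sampled points whose predicted x falls within a slope-adjusted 20 px threshold; a lane counts as matched when that fraction reaches 0.85, and FP/FN are derived from the unmatched predictions and annotations. A toy check (lanes invented by hand, run time in ms):

```python
from runner.evaluator.tusimple.lane import LaneEval

y_samples = list(range(160, 720, 10))            # 56 sampled rows
gt = [[300 + i for i in range(len(y_samples))]]  # one ground-truth lane
pred = [[x + 5 for x in gt[0]]]                  # prediction off by 5 px everywhere
acc, fp, fn = LaneEval.bench(pred, gt, y_samples, 20)
print(acc, fp, fn)  # 1.0 0.0 0.0, since 5 px is well under the ~20 px threshold
```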
diff --git a/runner/evaluator/tusimple/tusimple.py b/runner/evaluator/tusimple/tusimple.py
new file mode 100644
index 0000000..12d8a99
--- /dev/null
+++ b/runner/evaluator/tusimple/tusimple.py
@@ -0,0 +1,111 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from runner.logger import get_logger
+
+from runner.registry import EVALUATOR
+import json
+import os
+import cv2
+
+from .lane import LaneEval
+
+def split_path(path):
+ """split path tree into list"""
+ folders = []
+ while True:
+ path, folder = os.path.split(path)
+ if folder != "":
+ folders.insert(0, folder)
+ else:
+ if path != "":
+ folders.insert(0, path)
+ break
+ return folders
+
+
+@EVALUATOR.register_module
+class Tusimple(nn.Module):
+ def __init__(self, cfg):
+ super(Tusimple, self).__init__()
+ self.cfg = cfg
+ exp_dir = os.path.join(self.cfg.work_dir, "output")
+ if not os.path.exists(exp_dir):
+ os.mkdir(exp_dir)
+ self.out_path = os.path.join(exp_dir, "coord_output")
+ if not os.path.exists(self.out_path):
+ os.mkdir(self.out_path)
+ self.dump_to_json = []
+ self.thresh = cfg.evaluator.thresh
+ self.logger = get_logger('resa')
+ if cfg.view:
+ self.view_dir = os.path.join(self.cfg.work_dir, 'vis')
+
+ def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
+ img_name = batch['meta']['img_name']
+ img_path = batch['meta']['full_img_path']
+ for b in range(len(seg_pred)):
+ seg = seg_pred[b]
+ exist = [1 if exist_pred[b, i] >
+ 0.5 else 0 for i in range(self.cfg.num_classes-1)]
+ lane_coords = dataset.probmap2lane(seg, exist, thresh = self.thresh)
+ for i in range(len(lane_coords)):
+ lane_coords[i] = sorted(
+ lane_coords[i], key=lambda pair: pair[1])
+
+ path_tree = split_path(img_name[b])
+ save_dir, save_name = path_tree[-3:-1], path_tree[-1]
+ save_dir = os.path.join(self.out_path, *save_dir)
+ save_name = save_name[:-3] + "lines.txt"
+ save_name = os.path.join(save_dir, save_name)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+
+ with open(save_name, "w") as f:
+ for l in lane_coords:
+ for (x, y) in l:
+ print("{} {}".format(x, y), end=" ", file=f)
+ print(file=f)
+
+ json_dict = {}
+ json_dict['lanes'] = []
+ json_dict['h_sample'] = []
+ json_dict['raw_file'] = os.path.join(*path_tree[-4:])
+ json_dict['run_time'] = 0
+ for l in lane_coords:
+ if len(l) == 0:
+ continue
+ json_dict['lanes'].append([])
+ for (x, y) in l:
+ json_dict['lanes'][-1].append(int(x))
+ for (x, y) in lane_coords[0]:
+ json_dict['h_sample'].append(y)
+ self.dump_to_json.append(json.dumps(json_dict))
+ if self.cfg.view:
+ img = cv2.imread(img_path[b])
+ new_img_name = img_name[b].replace('/', '_')
+ save_dir = os.path.join(self.view_dir, new_img_name)
+ dataset.view(img, lane_coords, save_dir)
+
+
+ def evaluate(self, dataset, output, batch):
+ seg_pred, exist_pred = output['seg'], output['exist']
+ seg_pred = F.softmax(seg_pred, dim=1)
+ seg_pred = seg_pred.detach().cpu().numpy()
+ exist_pred = exist_pred.detach().cpu().numpy()
+ self.evaluate_pred(dataset, seg_pred, exist_pred, batch)
+
+ def summarize(self):
+ best_acc = 0
+ output_file = os.path.join(self.out_path, 'predict_test.json')
+ with open(output_file, "w+") as f:
+ for line in self.dump_to_json:
+ print(line, end="\n", file=f)
+
+ eval_result, acc = LaneEval.bench_one_submit(output_file,
+ self.cfg.test_json_file)
+
+ self.logger.info(eval_result)
+ self.dump_to_json = []
+ best_acc = max(acc, best_acc)
+ return best_acc
diff --git a/runner/logger.py b/runner/logger.py
new file mode 100644
index 0000000..189d353
--- /dev/null
+++ b/runner/logger.py
@@ -0,0 +1,50 @@
+import logging
+
+logger_initialized = {}
+
+def get_logger(name, log_file=None, log_level=logging.INFO):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if log_file is not None:
+ file_handler = logging.FileHandler(log_file, 'w')
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ logger.setLevel(log_level)
+
+ logger_initialized[name] = True
+
+ return logger
diff --git a/runner/net_utils.py b/runner/net_utils.py
new file mode 100644
index 0000000..0038fe6
--- /dev/null
+++ b/runner/net_utils.py
@@ -0,0 +1,43 @@
+import torch
+import os
+from torch import nn
+import numpy as np
+import torch.nn.functional
+from termcolor import colored
+from .logger import get_logger
+
+def save_model(net, optim, scheduler, recorder, is_best=False):
+ model_dir = os.path.join(recorder.work_dir, 'ckpt')
+ os.system('mkdir -p {}'.format(model_dir))
+ epoch = recorder.epoch
+ ckpt_name = 'best' if is_best else epoch
+ torch.save({
+ 'net': net.state_dict(),
+ 'optim': optim.state_dict(),
+ 'scheduler': scheduler.state_dict(),
+ 'recorder': recorder.state_dict(),
+ 'epoch': epoch
+ }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))
+
+
+def load_network_specified(net, model_dir, logger=None):
+ pretrained_net = torch.load(model_dir)['net']
+ net_state = net.state_dict()
+ state = {}
+ for k, v in pretrained_net.items():
+ if k not in net_state.keys() or v.size() != net_state[k].size():
+ if logger:
+ logger.info('skip weights: ' + k)
+ continue
+ state[k] = v
+ net.load_state_dict(state, strict=False)
+
+
+def load_network(net, model_dir, finetune_from=None, logger=None):
+ if finetune_from:
+ if logger:
+ logger.info('Finetune model from: ' + finetune_from)
+ load_network_specified(net, finetune_from, logger)
+ return
+ pretrained_model = torch.load(model_dir)
+ net.load_state_dict(pretrained_model['net'], strict=True)
diff --git a/runner/optimizer.py b/runner/optimizer.py
new file mode 100644
index 0000000..9f2f836
--- /dev/null
+++ b/runner/optimizer.py
@@ -0,0 +1,26 @@
+import torch
+
+
+_optimizer_factory = {
+ 'adam': torch.optim.Adam,
+ 'sgd': torch.optim.SGD
+}
+
+
+def build_optimizer(cfg, net):
+ params = []
+ lr = cfg.optimizer.lr
+ weight_decay = cfg.optimizer.weight_decay
+
+ for key, value in net.named_parameters():
+ if not value.requires_grad:
+ continue
+ params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
+
+ if 'adam' in cfg.optimizer.type:
+ optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
+ else:
+ optimizer = _optimizer_factory[cfg.optimizer.type](
+ params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)
+
+ return optimizer
diff --git a/runner/recorder.py b/runner/recorder.py
new file mode 100644
index 0000000..2ae345b
--- /dev/null
+++ b/runner/recorder.py
@@ -0,0 +1,100 @@
+from collections import deque, defaultdict
+import torch
+import os
+import datetime
+from .logger import get_logger
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20):
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+
+ def update(self, value):
+ self.deque.append(value)
+ self.count += 1
+ self.total += value
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque))
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+
+class Recorder(object):
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.work_dir = self.get_work_dir()
+ cfg.work_dir = self.work_dir
+ self.log_path = os.path.join(self.work_dir, 'log.txt')
+
+ self.logger = get_logger('resa', self.log_path)
+ self.logger.info('Config: \n' + cfg.text)
+
+ # scalars
+ self.epoch = 0
+ self.step = 0
+ self.loss_stats = defaultdict(SmoothedValue)
+ self.batch_time = SmoothedValue()
+ self.data_time = SmoothedValue()
+ self.max_iter = self.cfg.total_iter
+ self.lr = 0.
+
+ def get_work_dir(self):
+ now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+ hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
+ work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ return work_dir
+
+ def update_loss_stats(self, loss_dict):
+ for k, v in loss_dict.items():
+ self.loss_stats[k].update(v.detach().cpu())
+
+ def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
+ self.logger.info(self)
+ # self.write(str(self))
+
+ def write(self, content):
+ with open(self.log_path, 'a+') as f:
+ f.write(content)
+ f.write('\n')
+
+ def state_dict(self):
+ scalar_dict = {}
+ scalar_dict['step'] = self.step
+ return scalar_dict
+
+ def load_state_dict(self, scalar_dict):
+ self.step = scalar_dict['step']
+
+ def __str__(self):
+ loss_state = []
+ for k, v in self.loss_stats.items():
+ loss_state.append('{}: {:.4f}'.format(k, v.avg))
+ loss_state = ' '.join(loss_state)
+
+ recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
+ eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)
+
+
+def build_recorder(cfg):
+ return Recorder(cfg)
+
diff --git a/runner/registry.py b/runner/registry.py
new file mode 100644
index 0000000..c1c119b
--- /dev/null
+++ b/runner/registry.py
@@ -0,0 +1,19 @@
+from utils import Registry, build_from_cfg
+import torch.nn as nn
+
+TRAINER = Registry('trainer')
+EVALUATOR = Registry('evaluator')
+
+def build(cfg, registry, default_args=None):
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return nn.Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+def build_trainer(cfg):
+ return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))
+
+def build_evaluator(cfg):
+ return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
diff --git a/runner/resa_trainer.py b/runner/resa_trainer.py
new file mode 100644
index 0000000..7cdad78
--- /dev/null
+++ b/runner/resa_trainer.py
@@ -0,0 +1,58 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from runner.registry import TRAINER
+
+def dice_loss(input, target):
+ input = input.contiguous().view(input.size()[0], -1)
+ target = target.contiguous().view(target.size()[0], -1).float()
+
+ a = torch.sum(input * target, 1)
+ b = torch.sum(input * input, 1) + 0.001
+ c = torch.sum(target * target, 1) + 0.001
+ d = (2 * a) / (b + c)
+ return (1-d).mean()
+
+@TRAINER.register_module
+class RESA(nn.Module):
+ def __init__(self, cfg):
+ super(RESA, self).__init__()
+ self.cfg = cfg
+ self.loss_type = cfg.loss_type
+ if self.loss_type == 'cross_entropy':
+ weights = torch.ones(cfg.num_classes)
+ weights[0] = cfg.bg_weight
+ weights = weights.cuda()
+ self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
+ weight=weights).cuda()
+
+ self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()
+
+ def forward(self, net, batch):
+ output = net(batch['img'])
+
+ loss_stats = {}
+ loss = 0.
+
+ if self.loss_type == 'dice_loss':
+ target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
+ seg_loss = dice_loss(F.softmax(
+ output['seg'], dim=1)[:, 1:], target[:, 1:])
+ else:
+ seg_loss = self.criterion(F.log_softmax(
+ output['seg'], dim=1), batch['label'].long())
+
+ loss += seg_loss * self.cfg.seg_loss_weight
+
+ loss_stats.update({'seg_loss': seg_loss})
+
+ if 'exist' in output:
+ exist_loss = 0.1 * \
+ self.criterion_exist(output['exist'], batch['exist'].float())
+ loss += exist_loss
+ loss_stats.update({'exist_loss': exist_loss})
+
+ ret = {'loss': loss, 'loss_stats': loss_stats}
+
+ return ret
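The dice term here is d = 2*sum(p*t) / (sum(p*p) + sum(t*t)) per sample (with a 0.001 smoothing constant), and the loss is 1 - d averaged over the batch. A hand-checked toy example:

```python
import torch

pred   = torch.tensor([[0.9, 0.8, 0.1, 0.0]])    # soft foreground probabilities
target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
a = (pred * target).sum(1)             # 1.7
b = (pred * pred).sum(1) + 0.001       # 1.461
c = (target * target).sum(1) + 0.001   # 2.001
d = 2 * a / (b + c)                    # ~0.982
print((1 - d).mean())                  # dice loss ~0.018
```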
diff --git a/runner/runner.py b/runner/runner.py
new file mode 100644
index 0000000..3e8f9c4
--- /dev/null
+++ b/runner/runner.py
@@ -0,0 +1,112 @@
+import time
+import torch
+import numpy as np
+from tqdm import tqdm
+import pytorch_warmup as warmup
+
+from models.registry import build_net
+from .registry import build_trainer, build_evaluator
+from .optimizer import build_optimizer
+from .scheduler import build_scheduler
+from datasets import build_dataloader
+from .recorder import build_recorder
+from .net_utils import save_model, load_network
+
+
+class Runner(object):
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.recorder = build_recorder(self.cfg)
+ self.net = build_net(self.cfg)
+ self.net = torch.nn.parallel.DataParallel(
+ self.net, device_ids = range(self.cfg.gpus)).cuda()
+ self.recorder.logger.info('Network: \n' + str(self.net))
+ self.resume()
+ self.optimizer = build_optimizer(self.cfg, self.net)
+ self.scheduler = build_scheduler(self.cfg, self.optimizer)
+ self.evaluator = build_evaluator(self.cfg)
+ self.warmup_scheduler = warmup.LinearWarmup(
+ self.optimizer, warmup_period=5000)
+ self.metric = 0.
+
+ def resume(self):
+ if not self.cfg.load_from and not self.cfg.finetune_from:
+ return
+ load_network(self.net, self.cfg.load_from,
+ finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)
+
+ def to_cuda(self, batch):
+ for k in batch:
+ if k == 'meta':
+ continue
+ batch[k] = batch[k].cuda()
+ return batch
+
+ def train_epoch(self, epoch, train_loader):
+ self.net.train()
+ end = time.time()
+ max_iter = len(train_loader)
+ for i, data in enumerate(train_loader):
+ if self.recorder.step >= self.cfg.total_iter:
+ break
+ date_time = time.time() - end
+ self.recorder.step += 1
+ data = self.to_cuda(data)
+ output = self.trainer.forward(self.net, data)
+ self.optimizer.zero_grad()
+ loss = output['loss']
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+ self.warmup_scheduler.dampen()
+ batch_time = time.time() - end
+ end = time.time()
+ self.recorder.update_loss_stats(output['loss_stats'])
+ self.recorder.batch_time.update(batch_time)
+ self.recorder.data_time.update(date_time)
+
+ if i % self.cfg.log_interval == 0 or i == max_iter - 1:
+ lr = self.optimizer.param_groups[0]['lr']
+ self.recorder.lr = lr
+ self.recorder.record('train')
+
+ def train(self):
+ self.recorder.logger.info('start training...')
+ self.trainer = build_trainer(self.cfg)
+ train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
+ val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)
+
+ for epoch in range(self.cfg.epochs):
+ print('Epoch: [{}/{}]'.format(self.recorder.step, self.cfg.total_iter))
+ print('Epoch: [{}/{}]'.format(epoch, self.cfg.epochs))
+ self.recorder.epoch = epoch
+ self.train_epoch(epoch, train_loader)
+ if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
+ self.save_ckpt()
+ if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
+ self.validate(val_loader)
+ if self.recorder.step >= self.cfg.total_iter:
+ break
+
+ def validate(self, val_loader):
+ self.net.eval()
+ count = 10
+ for i, data in enumerate(tqdm(val_loader, desc=f'Validate')):
+ start_time = time.time()
+ data = self.to_cuda(data)
+ with torch.no_grad():
+ output = self.net(data['img'])
+ self.evaluator.evaluate(val_loader.dataset, output, data)
+            # print("image {} took {:.3f}s to detect".format(i, time.time()-start_time))
+
+ metric = self.evaluator.summarize()
+ if not metric:
+ return
+ if metric > self.metric:
+ self.metric = metric
+ self.save_ckpt(is_best=True)
+ self.recorder.logger.info('Best metric: ' + str(self.metric))
+
+ def save_ckpt(self, is_best=False):
+ save_model(self.net, self.optimizer, self.scheduler,
+ self.recorder, is_best)
diff --git a/runner/scheduler.py b/runner/scheduler.py
new file mode 100644
index 0000000..6843dc2
--- /dev/null
+++ b/runner/scheduler.py
@@ -0,0 +1,20 @@
+import torch
+import math
+
+
+_scheduler_factory = {
+ 'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
+}
+
+
+def build_scheduler(cfg, optimizer):
+
+ assert cfg.scheduler.type in _scheduler_factory
+
+ cfg_cp = cfg.scheduler.copy()
+ cfg_cp.pop('type')
+
+ scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)
+
+
+ return scheduler
diff --git a/tools/generate_seg_tusimple.py b/tools/generate_seg_tusimple.py
new file mode 100644
index 0000000..cf8273d
--- /dev/null
+++ b/tools/generate_seg_tusimple.py
@@ -0,0 +1,105 @@
+import json
+import numpy as np
+import cv2
+import os
+import argparse
+
+TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json']
+VAL_SET = ['label_data_0531.json']
+TRAIN_VAL_SET = TRAIN_SET + VAL_SET
+TEST_SET = ['test_label.json']
+
+def gen_label_for_json(args, image_set):
+ H, W = 720, 1280
+ SEG_WIDTH = 30
+ save_dir = args.savedir
+
+ os.makedirs(os.path.join(args.root, args.savedir, "list"), exist_ok=True)
+ list_f = open(os.path.join(args.root, args.savedir, "list", "{}_gt.txt".format(image_set)), "w")
+
+ json_path = os.path.join(args.root, args.savedir, "{}.json".format(image_set))
+ with open(json_path) as f:
+ for line in f:
+ label = json.loads(line)
+ # ---------- clean and sort lanes -------------
+ lanes = []
+ _lanes = []
+ slope = [] # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope
+ for i in range(len(label['lanes'])):
+ l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0]
+ if (len(l)>1):
+ _lanes.append(l)
+ slope.append(np.arctan2(l[-1][1]-l[0][1], l[0][0]-l[-1][0]) / np.pi * 180)
+ _lanes = [_lanes[i] for i in np.argsort(slope)]
+ slope = [slope[i] for i in np.argsort(slope)]
+
+ idx = [None for i in range(6)]
+ for i in range(len(slope)):
+ if slope[i] <= 90:
+ idx[2] = i
+ idx[1] = i-1 if i > 0 else None
+ idx[0] = i-2 if i > 1 else None
+ else:
+ idx[3] = i
+ idx[4] = i+1 if i+1 < len(slope) else None
+ idx[5] = i+2 if i+2 < len(slope) else None
+ break
+ for i in range(6):
+ lanes.append([] if idx[i] is None else _lanes[idx[i]])
+
+ # ---------------------------------------------
+
+ img_path = label['raw_file']
+ seg_img = np.zeros((H, W, 3))
+ list_str = [] # str to be written to list.txt
+ for i in range(len(lanes)):
+ coords = lanes[i]
+ if len(coords) < 4:
+ list_str.append('0')
+ continue
+ for j in range(len(coords)-1):
+ cv2.line(seg_img, coords[j], coords[j+1], (i+1, i+1, i+1), SEG_WIDTH//2)
+ list_str.append('1')
+
+ seg_path = img_path.split("/")
+ seg_path, img_name = os.path.join(args.root, args.savedir, seg_path[1], seg_path[2]), seg_path[3]
+ os.makedirs(seg_path, exist_ok=True)
+ seg_path = os.path.join(seg_path, img_name[:-3]+"png")
+ cv2.imwrite(seg_path, seg_img)
+
+ seg_path = "/".join([args.savedir, *img_path.split("/")[1:3], img_name[:-3]+"png"])
+ if seg_path[0] != '/':
+ seg_path = '/' + seg_path
+ if img_path[0] != '/':
+ img_path = '/' + img_path
+ list_str.insert(0, seg_path)
+ list_str.insert(0, img_path)
+ list_str = " ".join(list_str) + "\n"
+ list_f.write(list_str)
+
+
+def generate_json_file(save_dir, json_file, image_set):
+ with open(os.path.join(save_dir, json_file), "w") as outfile:
+ for json_name in (image_set):
+ with open(os.path.join(args.root, json_name)) as infile:
+ for line in infile:
+ outfile.write(line)
+
+def generate_label(args):
+ save_dir = os.path.join(args.root, args.savedir)
+ os.makedirs(save_dir, exist_ok=True)
+ generate_json_file(save_dir, "train_val.json", TRAIN_VAL_SET)
+ generate_json_file(save_dir, "test.json", TEST_SET)
+
+ print("generating train_val set...")
+ gen_label_for_json(args, 'train_val')
+ print("generating test set...")
+ gen_label_for_json(args, 'test')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--root', required=True, help='The root of the Tusimple dataset')
+ parser.add_argument('--savedir', type=str, default='seg_label', help='The root of the Tusimple dataset')
+ args = parser.parse_args()
+
+ generate_label(args)
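Example usage, run from `$RESA_ROOT` with `$TUSIMPLEROOT` set up as in INSTALL.md. With the default `--savedir`, per-frame segmentation PNGs plus `train_val_gt.txt` and `test_gt.txt` are written under `$TUSIMPLEROOT/seg_label/`:

```Shell
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
ls $TUSIMPLEROOT/seg_label/list/   # train_val_gt.txt  test_gt.txt
```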
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..eb99ab0
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,2 @@
+from .config import Config
+from .registry import Registry, build_from_cfg
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..42a0ff2
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,417 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import ast
+import os.path as osp
+import shutil
+import sys
+import tempfile
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+
+class ConfigDict(Dict):
+
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+ f"attribute '{name}'")
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=''):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument('--' + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument('--' + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument('--' + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument('--' + prefix + k, action='store_true')
+ elif isinstance(v, dict):
+ add_args(parser, v, prefix + k + '.')
+ elif isinstance(v, abc.Iterable):
+ parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
+ else:
+ print(f'cannot parse key {prefix + k} of type {type(v)}')
+ return parser
+
+
+class Config:
+ """A facility for config and config files.
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename) as f:
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError:
+ raise SyntaxError('There are syntax errors in config '
+ f'file {filename}')
+
+ @staticmethod
+ def _file2dict(filename):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ if filename.endswith('.py'):
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix='.py')
+ temp_config_name = osp.basename(temp_config_file.name)
+ shutil.copyfile(filename,
+ osp.join(temp_config_dir, temp_config_name))
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ # close temp file
+ temp_config_file.close()
+ elif filename.endswith(('.yml', '.yaml', '.json')):
+ import mmcv
+ cfg_dict = mmcv.load(filename)
+ else:
+ raise IOError('Only py/yml/yaml/json types are supported for now!')
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r') as f:
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError('Duplicate key is not allowed among bases')
+ base_cfg_dict.update(c)
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b):
+ # merge dict `a` into dict `b` (non-inplace). values in `a` will
+ # overwrite `b`.
+ # copy first to avoid inplace modification
+ b = b.copy()
+ for k, v in a.items():
+ if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
+ if not isinstance(b[k], dict):
+ raise TypeError(
+ f'{k}={v} in child config cannot inherit from base '
+ f'because {k} is a dict in the child config but is of '
+ f'type {type(b[k])} in base config. You may set '
+ f'`{DELETE_KEY}=True` to ignore the base config')
+ b[k] = Config._merge_a_into_b(v, b[k])
+ else:
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename):
+ cfg_dict, cfg_text = Config._file2dict(filename)
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)
+ """
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument('config', help='config file path')
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument('config', help='config file path')
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError('cfg_dict must be a dict, but '
+ f'got {type(cfg_dict)}')
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f'{key} is reserved for config file')
+
+ super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+ super(Config, self).__setattr__('_filename', filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = ''
+ super(Config, self).__setattr__('_text', text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split('\n')
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = '[\n'
+ v_str += '\n'.join(
+ f'dict({_indent(_format_dict(v_), indent)}),'
+ for v_ in v).rstrip(',')
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent) + ']'
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= \
+ (not str(key_name).isidentifier())
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ''
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += '{'
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = '' if outest_level or is_last else ','
+ if isinstance(v, dict):
+ v_str = '\n' + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: dict({v_str}'
+ else:
+ attr_str = f'{str(k)}=dict({v_str}'
+ attr_str = _indent(attr_str, indent) + ')' + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += '\n'.join(s)
+ if use_mapping:
+ r += '}'
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style='pep8',
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True)
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+ def __repr__(self):
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+ if self.filename.endswith('.py'):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, 'w') as f:
+ f.write(self.pretty_text)
+ else:
+ import mmcv
+ if file is None:
+ file_format = self.filename.split('.')[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options):
+ """Merge list into cfg_dict
+ Merge the dict parsed by MultipleKVAction into this cfg.
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+ Args:
+ options (dict): dict of configs to merge from.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split('.')
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ super(Config, self).__setattr__(
+ '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options should
+ be passed as comma separated values, i.e KEY=V1,V2,V3
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ return val
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ val = [self._parse_int_float_bool(v) for v in val.split(',')]
+ if len(val) == 1:
+ val = val[0]
+ options[key] = val
+ setattr(namespace, self.dest, options)
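For orientation, here is a minimal sketch (not part of this commit; the option names and help strings are illustrative) of how `Config.fromfile`, `DictAction`, and `merge_from_dict` are typically wired together in an entry point:

```python
# Hypothetical entry point: load a config file, then override entries from the CLI.
import argparse
from utils.config import Config, DictAction

parser = argparse.ArgumentParser()
parser.add_argument('config', help='config file path (.py/.yml/.yaml/.json)')
parser.add_argument('--options', nargs='+', action=DictAction,
                    help='override config entries, e.g. optimizer.lr=0.01')
args = parser.parse_args()

cfg = Config.fromfile(args.config)       # also resolves any _base_ config files
if args.options is not None:
    cfg.merge_from_dict(args.options)    # dotted keys become nested dicts
print(cfg.pretty_text)                   # yapf-formatted dump of the merged config
```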
diff --git a/utils/registry.py b/utils/registry.py
new file mode 100644
index 0000000..e850f5c
--- /dev/null
+++ b/utils/registry.py
@@ -0,0 +1,81 @@
+import inspect
+
+import six
+
+# borrowed from mmdetection
+
+def is_str(x):
+ """Whether the input is an string instance."""
+ return isinstance(x, six.string_types)
+
+class Registry(object):
+
+ def __init__(self, name):
+ self._name = name
+ self._module_dict = dict()
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + '(name={}, items={})'.format(
+ self._name, list(self._module_dict.keys()))
+ return format_str
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ def get(self, key):
+ return self._module_dict.get(key, None)
+
+ def _register_module(self, module_class):
+ """Register a module.
+
+ Args:
+ module_class (type): Class to be registered.
+ """
+ if not inspect.isclass(module_class):
+ raise TypeError('module must be a class, but got {}'.format(
+ type(module_class)))
+ module_name = module_class.__name__
+ if module_name in self._module_dict:
+ raise KeyError('{} is already registered in {}'.format(
+ module_name, self.name))
+ self._module_dict[module_name] = module_class
+
+ def register_module(self, cls):
+ self._register_module(cls)
+ return cls
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ obj: The constructed object.
+ """
+ assert isinstance(cfg, dict) and 'type' in cfg
+ assert isinstance(default_args, dict) or default_args is None
+ args = {}
+ obj_type = cfg.type
+ if is_str(obj_type):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError('{} is not in the {} registry'.format(
+ obj_type, registry.name))
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError('type must be a str or valid type, but got {}'.format(
+ type(obj_type)))
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+ return obj_cls(**args)
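Similarly, a hedged sketch (not part of this commit; `NETS` and `TinyNet` are made-up names) of the `Registry`/`build_from_cfg` flow. Note that with this implementation, constructor keyword arguments come from `default_args`; keys in `cfg` other than `type` are not forwarded to the constructor.

```python
from utils.config import ConfigDict
from utils.registry import Registry, build_from_cfg

NETS = Registry('nets')             # a hypothetical registry

@NETS.register_module               # registers the class under its __name__
class TinyNet:
    def __init__(self, channels=8):
        self.channels = channels

cfg = ConfigDict(dict(type='TinyNet'))    # cfg.type works via attribute access
net = build_from_cfg(cfg, NETS, default_args=dict(channels=16))
assert isinstance(net, TinyNet) and net.channels == 16
```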
diff --git a/utils/transforms.py b/utils/transforms.py
new file mode 100644
index 0000000..0960b68
--- /dev/null
+++ b/utils/transforms.py
@@ -0,0 +1,357 @@
+import random
+import cv2
+import numpy as np
+import numbers
+import collections.abc
+
+# copied from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py
+
+__all__ = ['GroupRandomCrop', 'GroupCenterCrop', 'GroupRandomPad', 'GroupCenterPad',
+ 'GroupRandomScale', 'GroupRandomHorizontalFlip', 'GroupNormalize']
+
+
+class SampleResize(object):
+ def __init__(self, size):
+ assert (isinstance(size, collections.abc.Iterable) and len(size) == 2)
+ self.size = size
+
+ def __call__(self, sample):
+ out = list()
+ out.append(cv2.resize(sample[0], self.size,
+ interpolation=cv2.INTER_CUBIC))
+ if len(sample) > 1:
+ out.append(cv2.resize(sample[1], self.size,
+ interpolation=cv2.INTER_NEAREST))
+ return out
+
+
+class GroupRandomCrop(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img_group):
+ h, w = img_group[0].shape[0:2]
+ th, tw = self.size
+
+ out_images = list()
+ h1 = random.randint(0, max(0, h - th))
+ w1 = random.randint(0, max(0, w - tw))
+ h2 = min(h1 + th, h)
+ w2 = min(w1 + tw, w)
+
+ for img in img_group:
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(img[h1:h2, w1:w2, ...])
+ return out_images
+
+
+class GroupRandomCropRatio(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img_group):
+ h, w = img_group[0].shape[0:2]
+ tw, th = self.size
+
+ out_images = list()
+ h1 = random.randint(0, max(0, h - th))
+ w1 = random.randint(0, max(0, w - tw))
+ h2 = min(h1 + th, h)
+ w2 = min(w1 + tw, w)
+
+ for img in img_group:
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(img[h1:h2, w1:w2, ...])
+ return out_images
+
+
+class GroupCenterCrop(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img_group):
+ h, w = img_group[0].shape[0:2]
+ th, tw = self.size
+
+ out_images = list()
+ h1 = max(0, int((h - th) / 2))
+ w1 = max(0, int((w - tw) / 2))
+ h2 = min(h1 + th, h)
+ w2 = min(w1 + tw, w)
+
+ for img in img_group:
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(img[h1:h2, w1:w2, ...])
+ return out_images
+
+
+class GroupRandomPad(object):
+ def __init__(self, size, padding):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+
+ def __call__(self, img_group):
+ assert (len(self.padding) == len(img_group))
+ h, w = img_group[0].shape[0:2]
+ th, tw = self.size
+
+ out_images = list()
+ h1 = random.randint(0, max(0, th - h))
+ w1 = random.randint(0, max(0, tw - w))
+ h2 = max(th - h - h1, 0)
+ w2 = max(tw - w - w1, 0)
+
+ for img, padding in zip(img_group, self.padding):
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(cv2.copyMakeBorder(
+ img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupCenterPad(object):
+ def __init__(self, size, padding):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+
+ def __call__(self, img_group):
+ assert (len(self.padding) == len(img_group))
+ h, w = img_group[0].shape[0:2]
+ th, tw = self.size
+
+ out_images = list()
+ h1 = max(0, int((th - h) / 2))
+ w1 = max(0, int((tw - w) / 2))
+ h2 = max(th - h - h1, 0)
+ w2 = max(tw - w - w1, 0)
+
+ for img, padding in zip(img_group, self.padding):
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(cv2.copyMakeBorder(
+ img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupConcerPad(object):  # pads only at the bottom/right, keeping the image anchored to the top-left corner
+ def __init__(self, size, padding):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+
+ def __call__(self, img_group):
+ assert (len(self.padding) == len(img_group))
+ h, w = img_group[0].shape[0:2]
+ th, tw = self.size
+
+ out_images = list()
+ h1 = 0
+ w1 = 0
+ h2 = max(th - h - h1, 0)
+ w2 = max(tw - w - w1, 0)
+
+ for img, padding in zip(img_group, self.padding):
+ assert (img.shape[0] == h and img.shape[1] == w)
+ out_images.append(cv2.copyMakeBorder(
+ img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupRandomScaleNew(object):
+ def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, img_group):
+ assert (len(self.interpolation) == len(img_group))
+ scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590
+ out_images = list()
+ for img, interpolation in zip(img_group, self.interpolation):
+ out_images.append(cv2.resize(img, None, fx=scale_w,
+ fy=scale_h, interpolation=interpolation))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupRandomScale(object):
+ def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, img_group):
+ assert (len(self.interpolation) == len(img_group))
+ scale = random.uniform(self.size[0], self.size[1])
+ out_images = list()
+ for img, interpolation in zip(img_group, self.interpolation):
+ out_images.append(cv2.resize(img, None, fx=scale,
+ fy=scale, interpolation=interpolation))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupRandomMultiScale(object):
+ def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, img_group):
+ assert (len(self.interpolation) == len(img_group))
+ scales = [0.5, 1.0, 1.5] # random.uniform(self.size[0], self.size[1])
+ out_images = list()
+ for scale in scales:
+ for img, interpolation in zip(img_group, self.interpolation):
+ out_images.append(cv2.resize(
+ img, None, fx=scale, fy=scale, interpolation=interpolation))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+
+
+class GroupRandomScaleRatio(object):
+ def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
+ self.size = size
+ self.interpolation = interpolation
+ self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784,
+ 1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832, 1122, 1224, 2295]
+
+ def __call__(self, img_group):
+ assert (len(self.interpolation) == len(img_group))
+ w_scale = random.randint(self.size[0], self.size[1])
+ h_scale = random.randint(self.size[2], self.size[3])
+ h, w, _ = img_group[0].shape
+ out_images = list()
+ out_images.append(cv2.resize(img_group[0], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h,
+ interpolation=self.interpolation[0])) # fx=w_scale*1.0/w, fy=h_scale*1.0/h
+ ### process label map ###
+ origin_label = cv2.resize(
+ img_group[1], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h, interpolation=self.interpolation[1])
+ origin_label = origin_label.astype(int)
+ label = origin_label[:, :, 0] * 5 + \
+ origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
+ new_label = np.ones(label.shape) * 100
+ new_label = new_label.astype(int)
+ for cnt in range(37):
+ new_label = (
+ label == self.origin_id[cnt]) * (cnt - 100) + new_label
+ new_label = (label == self.origin_id[37]) * (36 - 100) + new_label
+ assert(100 not in np.unique(new_label))
+ out_images.append(new_label)
+ return out_images
+
+
+class GroupRandomRotation(object):
+ def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None):
+ self.degree = degree
+ self.interpolation = interpolation
+ self.padding = padding
+ if self.padding is None:
+ self.padding = [0, 0]
+
+ def __call__(self, img_group):
+ assert (len(self.interpolation) == len(img_group))
+ v = random.random()
+ if v < 0.5:
+ degree = random.uniform(self.degree[0], self.degree[1])
+ h, w = img_group[0].shape[0:2]
+ center = (w / 2, h / 2)
+ map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
+ out_images = list()
+ for img, interpolation, padding in zip(img_group, self.interpolation, self.padding):
+ out_images.append(cv2.warpAffine(
+ img, map_matrix, (w, h), flags=interpolation, borderMode=cv2.BORDER_CONSTANT, borderValue=padding))
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+ else:
+ return img_group
+
+
+class GroupRandomBlur(object):
+ def __init__(self, applied):
+ self.applied = applied
+
+ def __call__(self, img_group):
+ assert (len(self.applied) == len(img_group))
+ v = random.random()
+ if v < 0.5:
+ out_images = []
+ for img, a in zip(img_group, self.applied):
+ if a:
+ img = cv2.GaussianBlur(
+ img, (5, 5), random.uniform(1e-6, 0.6))
+ out_images.append(img)
+ if len(img.shape) > len(out_images[-1].shape):
+ out_images[-1] = out_images[-1][...,
+ np.newaxis] # single channel image
+ return out_images
+ else:
+ return img_group
+
+
+class GroupRandomHorizontalFlip(object):
+ """Randomly horizontally flips the given numpy Image with a probability of 0.5
+ """
+
+ def __init__(self, is_flow=False):
+ self.is_flow = is_flow
+
+ def __call__(self, img_group, is_flow=False):
+ v = random.random()
+ if v < 0.5:
+ out_images = [np.fliplr(img) for img in img_group]
+ if self.is_flow:
+ for i in range(0, len(out_images), 2):
+ # invert flow pixel values when flipping
+ out_images[i] = -out_images[i]
+ return out_images
+ else:
+ return img_group
+
+
+class GroupNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, img_group):
+ out_images = list()
+ for img, m, s in zip(img_group, self.mean, self.std):
+ if len(m) == 1:
+ img = img - np.array(m) # single channel image
+ img = img / np.array(s)
+ else:
+ img = img - np.array(m)[np.newaxis, np.newaxis, ...]
+ img = img / np.array(s)[np.newaxis, np.newaxis, ...]
+ out_images.append(img)
+
+ return out_images
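Finally, a minimal sketch (not part of this commit; shapes and normalization constants are placeholders) of applying the group transforms to an (image, label) pair. Each transform receives the whole group so that the image and its segmentation mask stay geometrically aligned.

```python
import numpy as np
from utils.transforms import (GroupRandomScale, GroupRandomHorizontalFlip,
                              GroupNormalize)

img = np.zeros((288, 800, 3), dtype=np.uint8)     # dummy RGB frame
label = np.zeros((288, 800), dtype=np.uint8)      # dummy per-pixel lane mask

group = [img, label]
group = GroupRandomScale(size=(0.5, 1.5))(group)          # same random scale for both
group = GroupRandomHorizontalFlip()(group)                # flips both or neither
group = GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )),
                       std=([1., 1., 1.], (1, )))(group)  # mask values stay unchanged
img_t, label_t = group
```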