first commit
commit c8504c8e34
@@ -0,0 +1,22 @@
work_dirs/
predicts/
output/
data/
data

__pycache__/
*/*.un~
.*.swp



*.egg-info/
*.egg

output.txt
.vscode/*
.DS_Store
tmp.*
*.pt
*.pth
*.un~
@@ -0,0 +1,74 @@

# Install

1. Clone the RESA repository
```
git clone https://github.com/zjulearning/resa.git
```
We refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

```Shell
conda create -n resa python=3.8 -y
conda activate resa
```

3. Install dependencies

```Shell
# Install PyTorch first; the cudatoolkit version should match the CUDA version
# on your system. (You can also use pip to install pytorch and torchvision.)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Or you can install via pip
pip install torch torchvision

# Install python packages
pip install -r requirements.txt
```

4. Data preparation

Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to the `data` directory.

```Shell
cd $RESA_ROOT
mkdir -p data
ln -s $CULANEROOT data/CULane
ln -s $TUSIMPLEROOT data/tusimple
```

For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the json annotation.

```Shell
python scripts/convert_tusimple.py --root $TUSIMPLEROOT
# this will generate segmentations and two list files: train_gt.txt and test.txt
```

For CULane, you should have a structure like this:
```
$RESA_ROOT/data/CULane/driver_xx_xxframe    # data folders x6
$RESA_ROOT/data/CULane/laneseg_label_w16    # lane segmentation labels
$RESA_ROOT/data/CULane/list                 # data lists
```

For Tusimple, you should have a structure like this:
```
$RESA_ROOT/data/tusimple/clips                   # data folders
$RESA_ROOT/data/tusimple/label_data_xxxx.json    # label json files x4
$RESA_ROOT/data/tusimple/test_tasks_0627.json    # test tasks json file
$RESA_ROOT/data/tusimple/test_label.json         # test label json file
```

5. Install CULane evaluation tools.

This tool requires OpenCV C++. Please follow [the OpenCV install guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or just install opencv with `sudo apt-get install libopencv-dev`.

Then compile the evaluation tool of CULane.
```Shell
cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
make
cd -
```

Note that the default `opencv` version is 3. If you use opencv 2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2021 Tu Zheng

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@@ -0,0 +1,148 @@
# RESA
PyTorch implementation of the paper "[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719)".

Our paper has been accepted by AAAI2021.

**News**: We also release RESA on [LaneDet](https://github.com/Turoad/lanedet). We recommend you try LaneDet as well.

## Introduction
![intro](intro.png "intro")
- RESA shifts a sliced feature map recurrently in vertical and horizontal directions, enabling each pixel to gather global information.
- RESA achieves SOTA results on the CULane and Tusimple datasets.

## Get started
1. Clone the RESA repository
```
git clone https://github.com/zjulearning/resa.git
```
We refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

```Shell
conda create -n resa python=3.8 -y
conda activate resa
```

3. Install dependencies

```Shell
# Install PyTorch first; the cudatoolkit version should match the CUDA version
# on your system. (You can also use pip to install pytorch and torchvision.)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Or you can install via pip
pip install torch torchvision

# Install python packages
pip install -r requirements.txt
```

4. Data preparation

Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to the `data` directory.

```Shell
cd $RESA_ROOT
mkdir -p data
ln -s $CULANEROOT data/CULane
ln -s $TUSIMPLEROOT data/tusimple
```

For CULane, you should have a structure like this:
```
$CULANEROOT/driver_xx_xxframe    # data folders x6
$CULANEROOT/laneseg_label_w16    # lane segmentation labels
$CULANEROOT/list                 # data lists
```

For Tusimple, you should have a structure like this:
```
$TUSIMPLEROOT/clips                   # data folders
$TUSIMPLEROOT/label_data_xxxx.json    # label json files x4
$TUSIMPLEROOT/test_tasks_0627.json    # test tasks json file
$TUSIMPLEROOT/test_label.json         # test label json file
```

For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the json annotation.

```Shell
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this will generate the seg_label directory
```

5. Install CULane evaluation tools.

This tool requires OpenCV C++. Please follow [the OpenCV install guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or just install opencv with `sudo apt-get install libopencv-dev`.

Then compile the evaluation tool of CULane.
```Shell
cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
make
cd -
```

Note that the default `opencv` version is 3. If you use opencv 2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.


## Training

For training, run

```Shell
python main.py [configs/path_to_your_config] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --gpus 0 1 2 3
```

## Testing
For testing, run
```Shell
python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3

python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
```

We provide two trained ResNet models on CULane and Tusimple; download our best-performing models (Tusimple: [GoogleDrive](https://drive.google.com/file/d/1M1xi82y0RoWUwYYG9LmZHXWSD2D60o0D/view?usp=sharing)/[BaiduDrive(code:s5ii)](https://pan.baidu.com/s/1CgJFrt9OHe-RUNooPpHRGA), CULane: [GoogleDrive](https://drive.google.com/file/d/1pcqq9lpJ4ixJgFVFndlPe42VgVsjgn0Q/view?usp=sharing)/[BaiduDrive(code:rlwj)](https://pan.baidu.com/s/1ODKAZxpKrZIPXyaNnxcV3g))

## Visualization
Just add `--view`.

For example:
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3 --view
```
You will get the result in the directory: `work_dirs/[DATASET]/xxx/vis`.

## Citation
If you use our method, please consider citing:
```BibTeX
@inproceedings{zheng2021resa,
  title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
  author={Zheng, Tu and Fang, Hao and Zhang, Yi and Tang, Wenjian and Yang, Zheng and Liu, Haifeng and Cai, Deng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3547--3554},
  year={2021}
}
```

<!-- ## Thanks

The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark). -->
@@ -0,0 +1,88 @@
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet50',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 12
batch_size = 8
total_iter = (88880 // batch_size) * epochs
import math
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 6
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500
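As a sanity check on the schedule above (a hypothetical snippet, not part of the repo): `LambdaLR` multiplies the base learning rate by the lambda's value, so the factor decays polynomially from 1.0 toward 0.0 over `total_iter` steps.

```Python
# Hypothetical check of the polynomial LR decay used by the config above.
import math

epochs, batch_size = 12, 8
total_iter = (88880 // batch_size) * epochs      # 133,320 iterations
decay = lambda it: math.pow(1 - it / total_iter, 0.9)

for it in (0, total_iter // 2, total_iter - 1):
    print(it, 0.025 * decay(it))                 # base lr is 0.025
```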
@@ -0,0 +1,97 @@
net = dict(
    type='RESANet',
)

# backbone = dict(
#     type='ResNetWrapper',
#     resnet='resnet50',
#     pretrained=True,
#     replace_stride_with_dilation=[False, True, True],
#     out_conv=True,
#     fea_stride=8,
# )

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, False, False],
    out_conv=False,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 20
batch_size = 8
total_iter = (88880 // batch_size) * epochs
import math
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 1
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500
@@ -0,0 +1,93 @@
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=5,
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',
    thresh=0.60
)

optimizer = dict(
    type='sgd',
    lr=0.020,
    weight_decay=1e-4,
    momentum=0.9
)

total_iter = 181400
import math
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)

loss_type = 'cross_entropy'
seg_loss_weight = 1.0

batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''
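Both configs use the same `img_norm`: the well-known Caffe-style ImageNet per-channel means in BGR order (matching images loaded by cv2) with a unit std. A hypothetical illustration of what that normalization does to a pixel:

```Python
# Hypothetical illustration of the img_norm settings above: subtract the
# Caffe-style ImageNet BGR channel means, divide by a unit std.
import numpy as np

img = np.full((2, 2, 3), 128, dtype=np.float32)                # fake BGR image
mean = np.array([103.939, 116.779, 123.68], dtype=np.float32)
std = np.array([1., 1., 1.], dtype=np.float32)
print((img - mean) / std)
```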
@@ -0,0 +1,4 @@
from .registry import build_dataset, build_dataloader

from .tusimple import TuSimple
from .culane import CULane
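Importing this package runs the `@DATASETS.register_module` decorators in `culane.py` and `tusimple.py`; that side effect is what lets a config string like `type='CULane'` resolve later. A hypothetical check, run from inside `$RESA_ROOT` (it assumes the registry's repr exposes its registered modules):

```Python
# Hypothetical check, run from $RESA_ROOT: importing the package should
# register both dataset classes with the DATASETS registry.
import datasets                      # executes the register_module decorators
from datasets.registry import DATASETS

print(DATASETS)                      # expected to mention CULane and TuSimple
```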
@@ -0,0 +1,86 @@
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)

        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    def transform_train(self):
        raise NotImplementedError()

    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        # draw the predicted lane points onto the image
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)

    def init(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        # crop the top cut_height rows before the transforms
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
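To make the tensor layout concrete, here is a hypothetical shape walk-through of `__getitem__` under the CULane config (590x1640 frames, `cut_height=240`, resized to 288x800); the arrays are zero stand-ins, not real data.

```Python
# Hypothetical shape walk-through of BaseDataset.__getitem__ (zero stand-ins).
import numpy as np
import torch

img = np.zeros((590, 1640, 3), dtype=np.float32)   # cv2.imread output (H, W, C)
img = img[240:, :, :]                               # cut_height crop -> (350, 1640, 3)
img = np.zeros((288, 800, 3), dtype=np.float32)     # stand-in for SampleResize output
img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
print(img.shape)                                    # torch.Size([3, 288, 800])
```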
@@ -0,0 +1,72 @@
import os
import os.path as osp
import numpy as np
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS
import cv2
import torch


@DATASETS.register_module
class CULane(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, cfg=cfg)
        self.ori_imgh = 590
        self.ori_imgw = 1640

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5])]))

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(degree=(-2, 2)),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def probmap2lane(self, probmaps, exists, pts=18):
        coords = []
        probmaps = probmaps[1:, ...]          # drop the background channel
        exists = exists > 0.5
        for probmap, exist in zip(probmaps, exists):
            if exist == 0:
                continue
            probmap = cv2.blur(probmap, (9, 9), borderType=cv2.BORDER_REPLICATE)
            thr = 0.3
            coordinate = np.zeros(pts)
            cut_height = self.cfg.cut_height
            for i in range(pts):
                line = probmap[round(
                    self.cfg.img_height - i * 20 / (self.ori_imgh - cut_height) * self.cfg.img_height) - 1]

                if np.max(line) > thr:
                    coordinate[i] = np.argmax(line) + 1
            if np.sum(coordinate > 0) < 2:
                continue

            img_coord = np.zeros((pts, 2))
            img_coord[:, :] = -1
            for idx, value in enumerate(coordinate):
                if value > 0:
                    img_coord[idx][0] = round(value * self.ori_imgw / self.cfg.img_width - 1)
                    img_coord[idx][1] = round(self.ori_imgh - idx * 20 - 1)

            img_coord = img_coord.astype(int)
            coords.append(img_coord)

        return coords
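Each line of a CULane list file packs the image path, the segmentation-label path, and four per-lane existence flags, which is exactly what `init()` above parses. A hypothetical example (the paths are illustrative, not taken from the dataset):

```Python
# Hypothetical example of one train_gt.txt line and how init() splits it.
line = "/driver_23_30frame/05151649_0422.MP4/00000.jpg " \
       "/laneseg_label_w16/driver_23_30frame/05151649_0422.MP4/00000.png 1 1 1 1"
parts = line.strip().split(" ")
img_rel, label_rel = parts[0], parts[1]
exist = [int(v) for v in parts[2:6]]      # one flag per lane, left to right
print(img_rel, label_rel, exist)
```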
@@ -0,0 +1,36 @@
from utils import Registry, build_from_cfg

import torch
import torch.nn as nn

DATASETS = Registry('datasets')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    shuffle = is_train

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle,
        num_workers=cfg.workers, pin_memory=False, drop_last=False)

    return data_loader
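`Registry` and `build_from_cfg` come from the repo's `utils` package, which is not part of this commit. The following standalone sketch (all names hypothetical) mimics the usual contract so the flow in `build_dataset` is easier to follow: a config dict's `type` key selects a registered class, and the remaining keys become constructor arguments.

```Python
# Standalone, hypothetical mimic of the Registry/build_from_cfg contract;
# the repo's utils package may differ in details.
class Registry:
    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def register_module(self, cls):
        self.module_dict[cls.__name__] = cls
        return cls

def build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)
    cls = registry.module_dict[args.pop('type')]   # 'type' picks the class
    args.update(default_args or {})                # remaining keys are kwargs
    return cls(**args)

DATASETS = Registry('datasets')

@DATASETS.register_module
class DummySet:
    def __init__(self, img_path, data_list, cfg=None):
        self.img_path, self.data_list = img_path, data_list

ds = build_from_cfg(dict(type='DummySet', img_path='./data',
                         data_list='train_gt.txt'), DATASETS)
print(type(ds).__name__)   # DummySet
```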
@@ -0,0 +1,150 @@
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    def transform_train(self):
        input_mean = self.cfg.img_norm['mean']
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5]),
                              int(line_split[6]), int(line_split[7])
                              ]))

    def fix_gap(self, coordinate):
        # linearly interpolate missing x-coordinates (-1) that lie strictly
        # inside a lane; leading and trailing gaps are left untouched
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate

    def is_short(self, lane):
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for a single lane, np array of size (h, w)
        resize_shape: reshape size target, (H, W)

        Return:
        ----------
        coords: x coords bottom-up every y_px_gap px, 0 for non-existent, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height

        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)

        return coords

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred: np.array of size (5, h, w)
        resize_shape: reshape size target, (H, W)
        exist: list of existence, e.g. [0, 1, 1, 0]
        smooth: whether to smooth the probability or not
        y_px_gap: y pixel gap for sampling
        pts: how many points for one lane
        thresh: probability threshold

        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]], [[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]    # seg_pred (5, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []

        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        return coordinates
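A hypothetical numeric check of the sampling rule in `get_lane` above: lanes are read bottom-up every `y_px_gap` pixels of the (crop-adjusted) target frame, and each sampled row is mapped back into prob-map coordinates.

```Python
# Hypothetical check of get_lane's row sampling under the Tusimple config
# (img_height 368, resize_shape (720, 1280), cut_height 160).
h = 368                       # prob-map height (network output resolution)
H = 720 - 160                 # target height after the cut_height adjustment
y_px_gap, pts = 10, 56
rows = [int((H - 10 - i * y_px_gap) * h / H) for i in range(pts)]
rows = [y for y in rows if y >= 0]
print(rows[:3], rows[-3:], len(rows))   # sampled prob-map rows, bottom-up
```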
@@ -0,0 +1,73 @@
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader


def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)

    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default=None,
        help='the checkpoint file to resume from')
    parser.add_argument(
        '--finetune_from', default=None,
        help='the checkpoint file to finetune from')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()


    return args


if __name__ == '__main__':
    main()
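A small hypothetical check of the GPU plumbing above: the `--gpus` values (defaulting to the list `[0]` rather than the string `'0'`, which the original default would have produced) are joined into `CUDA_VISIBLE_DEVICES` before anything touches CUDA.

```Python
# Hypothetical check: --gpus 0 1 2 3 becomes CUDA_VISIBLE_DEVICES="0,1,2,3".
gpus = [0, 1, 2, 3]
print(','.join(str(g) for g in gpus))   # 0,1,2,3
```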
@@ -0,0 +1 @@
from .resa import *
@@ -0,0 +1,129 @@
from torch import nn
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate

class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = self.output_conv(output)

        return output
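A hypothetical, standalone shape check for the `UpsamplerBlock` above: `ConvTranspose2d(stride=2, output_padding=1)` doubles the spatial size, while the parallel `F.interpolate` branch is pinned to the exact (`up_height`, `up_width`) target, so the two paths always sum cleanly.

```Python
# Hypothetical shape check for UpsamplerBlock's two branches (CULane config:
# 288x800 input, so the stride-8 feature map is 36x100).
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 128, 36, 100)
deconv = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
print(deconv(x).shape)            # torch.Size([1, 64, 72, 200])

skip = nn.Conv2d(128, 64, kernel_size=1, bias=False)(x)
skip = F.interpolate(skip, size=[72, 200], mode='bilinear', align_corners=False)
print(skip.shape)                 # torch.Size([1, 64, 72, 200])
```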
@@ -0,0 +1,135 @@
from torch import nn
import torch.nn.functional as F
import torch

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate

class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(32, num_classes)

    def forward(self, input):
        # input is a pair: a higher-resolution feature (input[0]) that is
        # concatenated back in after stage 0, and the deepest feature map (input[1])
        x = input[0]
        output = input[1]

        for i, layer in enumerate(self.layers):
            output = layer(output)
            if i == 0:
                output = torch.cat((x, output), dim=1)

        output = self.output_conv(output)

        return output
@ -0,0 +1,143 @@
from torch import nn
import torch
import torch.nn.functional as F


class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate


class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))
        self.out1 = conv1x1(128, 64)
        self.out2 = conv1x1(64, 32)
        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        out1 = input[0]
        out2 = input[1]
        output = input[2]

        for i, layer in enumerate(self.layers):
            if i == 0:
                output = layer(output)
                output = torch.cat((out2, output), dim=1)
                output = self.out1(output)
            elif i == 1:
                output = layer(output)
                output = torch.cat((out1, output), dim=1)
                output = self.out2(output)
            else:
                output = layer(output)

        output = self.output_conv(output)

        return output
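A minimal shape check for the three-input `BUSD` above, as it is wired up by `RESANet` further down. This is a sketch, not part of the repo: the 288×800 input size, the 5 classes, and the `SimpleNamespace` config stub are assumptions; the channel/stride layout matches the feature taps of the MobileNetV2 wrapper below.

```python
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(img_height=288, img_width=800, num_classes=5)  # hypothetical stub
decoder = BUSD(cfg)

fea1 = torch.randn(1, 32, 288 // 2, 800 // 2)   # stride-2 skip (32 channels)
fea2 = torch.randn(1, 64, 288 // 4, 800 // 4)   # stride-4 skip (64 channels)
fea = torch.randn(1, 128, 288 // 8, 800 // 8)   # stride-8 backbone output
seg = decoder([fea1, fea2, fea])
print(seg.shape)  # torch.Size([1, 5, 288, 800])
```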
@ -0,0 +1,422 @@
from functools import partial
from typing import Any, Callable, List, Optional

import torch
from torch import nn, Tensor


from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar
import collections
from itertools import repeat

M = TypeVar("M", bound=nn.Module)

BUILTIN_MODELS = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        key = name if name is not None else fn.__name__
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
    Otherwise, we will make a tuple of length n, all with value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))


class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )


class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )


__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]


# necessary for backwards compatibility
class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): The dropout probability

        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 1],  # **
                [6, 96, 3, 1],
                [6, 160, 3, 1],  # **
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            Conv2dNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        # x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


_COMMON_META = {
    "num_params": 3504872,
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class MobileNet_V2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.555,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.598,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


# @register_model()
# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
    *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
    """
    weights = MobileNet_V2_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV2(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class MobileNetv2Wrapper(nn.Module):
    def __init__(self):
        super(MobileNetv2Wrapper, self).__init__()
        weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)

        self.model = MobileNetV2()

        if weights is not None:
            self.model.load_state_dict(weights.get_state_dict(progress=True))
        self.out = conv1x1(1280, 128)

    def forward(self, x):
        # print(x.shape)
        x = self.model(x)
        # print(x.shape)
        if self.out:
            x = self.out(x)
        # print(x.shape)
        return x
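A quick sanity check for this wrapper (a sketch; the 288×800 input is an assumption, any size divisible by 8 works). With the two stride changes marked `# **` above, the backbone downsamples by 8 instead of the stock 32, and the final 1×1 conv squeezes the 1280 channels down to the 128 that RESA expects.

```python
import torch

backbone = MobileNetv2Wrapper()  # downloads ImageNet weights on first use
x = backbone(torch.randn(1, 3, 288, 800))
print(x.shape)  # torch.Size([1, 128, 36, 100]) -- stride 8 with the modified settings
```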
@ -0,0 +1,436 @@
from functools import partial
from typing import Any, Callable, List, Optional

import torch
from torch import nn, Tensor


from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar
import collections
from itertools import repeat

M = TypeVar("M", bound=nn.Module)

BUILTIN_MODELS = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        key = name if name is not None else fn.__name__
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
    Otherwise, we will make a tuple of length n, all with value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))


class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )


class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )


__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]


# necessary for backwards compatibility
class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): The dropout probability

        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 1],
                [6, 32, 3, 1],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            Conv2dNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # self.layer1 = nn.Sequential(*features[:])
        # self.layer2 = features[57:120]
        # self.layer3 = features[120:]

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> List[Tensor]:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        out_layers = []
        # named_modules() yields ('', self.features) first; walk its children once
        # and tap the feature maps at indices 0, 10 and 18 (strides 2, 4 and 8).
        for layer in self.features.named_modules():
            for i, layer1 in enumerate(layer[1]):
                # print(layer1)
                x = layer1(x)
                # print("layer {}, output size {}".format(i, x.shape))
                if i in [0, 10, 18]:
                    out_layers.append(x)
            break
        # x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        # x = self.classifier(x)
        return out_layers

    def forward(self, x: Tensor) -> List[Tensor]:
        return self._forward_impl(x)


_COMMON_META = {
    "num_params": 3504872,
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class MobileNet_V2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.555,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.598,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


# @register_model()
# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
    *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
    """
    weights = MobileNet_V2_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV2(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class MobileNetv2Wrapper(nn.Module):
    def __init__(self):
        super(MobileNetv2Wrapper, self).__init__()
        weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)

        self.model = MobileNetV2()

        if weights is not None:
            self.model.load_state_dict(weights.get_state_dict(progress=True))
        self.out3 = conv1x1(1280, 128)

    def forward(self, x):
        # print(x.shape)
        out_layers = self.model(x)
        # print(x.shape)

        # out_layers[0] = self.out1(out_layers[0])
        # out_layers[1] = self.out2(out_layers[1])
        out_layers[2] = self.out3(out_layers[2])
        # print(x.shape)
        return out_layers
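This copy differs from the previous file in two ways: the stride-2 stages are placed differently (stem, the `c=64` stage and the `c=160` stage), and `_forward_impl` taps the outputs of feature indices 0, 10 and 18 instead of returning a single map. A shape probe for the wrapper (a sketch; the 288×800 input size is an assumption):

```python
import torch

backbone = MobileNetv2Wrapper()
fea1, fea2, fea = backbone(torch.randn(1, 3, 288, 800))
print(fea1.shape)  # torch.Size([1, 32, 144, 400])  -- index 0, stride 2
print(fea2.shape)  # torch.Size([1, 64, 72, 200])   -- index 10, stride 4
print(fea.shape)   # torch.Size([1, 128, 36, 100])  -- index 18 after out3, stride 8
```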
@ -0,0 +1,16 @@
import torch.nn as nn

from utils import Registry, build_from_cfg

NET = Registry('net')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))
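A sketch of how this registry is meant to be used. It assumes `utils.Registry` and `build_from_cfg` follow the common mmcv-style convention where `cfg['type']` names a registered class; the bare-decorator `register_module` usage is confirmed by `RESANet` in the next file, the rest is an assumption. `TinyHead` and `Cfg` are hypothetical.

```python
import torch.nn as nn
from models.registry import NET, build

# Hypothetical toy module, registered the same way RESANet is.
@NET.register_module
class TinyHead(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc = nn.Linear(128, cfg.num_classes)

class Cfg:  # hypothetical minimal config
    num_classes = 5

head = build(dict(type='TinyHead'), NET, default_args=dict(cfg=Cfg))
```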
@ -0,0 +1,142 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
# from .resnet_copy import ResNetWrapper
# from .resnet import ResNetWrapper
from .decoder_copy2 import BUSD, PlainDecoder
# from .decoder import BUSD, PlainDecoder
# from .mobilenetv2 import MobileNetv2Wrapper
from .mobilenetv2_copy2 import MobileNetv2Wrapper


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x


class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  # ???
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x


@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        # self.backbone = ResNetWrapper(resnet='resnet34', pretrained=True,
        #                               replace_stride_with_dilation=[False, False, False],
        #                               out_conv=False)
        self.backbone = MobileNetv2Wrapper()
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        # x1, fea, _, _ = self.backbone(batch)
        # fea = self.resa(fea)
        # # print(fea.shape)
        # seg = self.decoder([x1, fea])
        # # print(seg.shape)
        # exist = self.heads(fea)

        fea1, fea2, fea = self.backbone(batch)
        # print('fea1', fea1.shape)
        # print('fea2', fea2.shape)
        # print('fea', fea.shape)
        fea = self.resa(fea)
        # print(fea.shape)
        seg = self.decoder([fea1, fea2, fea])
        # print(seg.shape)
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output
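The RESA module above propagates information by repeatedly adding a convolved, cyclically shifted copy of the feature map to itself, once per direction and iteration. A 1-D sketch of the `idx_d` arithmetic (the `iter=4` and the stride-8 feature height of 36, i.e. a 288-pixel input, are assumptions matching the CULane setting):

```python
import torch

height, n_iter = 36, 4
for i in range(n_iter):
    shift = height // 2 ** (n_iter - i)          # 2, 4, 9, 18
    idx_d = (torch.arange(height) + shift) % height
    print(i, shift, idx_d[:5].tolist())
# Row k is updated with features from row (k + shift) % height; the shift
# roughly doubles each iteration, so information crosses the whole map in
# logarithmically many rounds.
```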
@ -0,0 +1,377 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


# This code is borrowed from torchvision.

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetWrapper(nn.Module):

    def __init__(self, cfg):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = [64, 128, 256, 512]
        if 'in_channels' in cfg.backbone:
            self.in_channels = cfg.backbone.in_channels
        self.model = eval(cfg.backbone.resnet)(
            pretrained=cfg.backbone.pretrained,
            replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
        self.out = None
        if cfg.backbone.out_conv:
            out_channel = 512
            for chan in reversed(self.in_channels):
                if chan < 0:
                    continue
                out_channel = chan
                break
            self.out = conv1x1(
                out_channel * self.model.expansion, 128)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            x = self.out(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.in_channels[3] > 0:
            x = self.layer4(x)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
|
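For reference, a minimal usage sketch of these factory functions (not part of the commit; it assumes the file is importable as `resnet` and that, as the forward pass above suggests, the constructor accepts an `in_channels` list whose last entry gates `layer4`):

```python
# Hypothetical usage sketch; `resnet` is an assumed import name for this file.
import torch
from resnet import resnet18  # assumed import path

backbone = resnet18(pretrained=False, in_channels=[64, 128, 256, 512])
feat = backbone(torch.randn(1, 3, 288, 800))  # CULane-sized input
print(feat.shape)  # single stride-32 feature map, e.g. (1, 512, 9, 25)
```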
@ -0,0 +1,432 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


model_urls = {
    'resnet18':
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34':
    'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50':
    'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101':
    'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152':
    'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d':
    'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
    'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2':
    'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2':
    'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetWrapper(nn.Module):
    def __init__(self,
                 resnet='resnet18',
                 pretrained=True,
                 replace_stride_with_dilation=[False, False, False],
                 out_conv=False,
                 fea_stride=8,
                 out_channel=128,
                 in_channels=[64, 128, 256, 512],
                 cfg=None):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = in_channels

        self.model = eval(resnet)(
            pretrained=pretrained,
            replace_stride_with_dilation=replace_stride_with_dilation,
            in_channels=self.in_channels)
        self.out = None
        if out_conv:
            out_channel = 512
            # pick the channel count of the last enabled stage
            for chan in reversed(self.in_channels):
                if chan < 0:
                    continue
                out_channel = chan
                break
            self.out = conv1x1(out_channel * self.model.expansion,
                               cfg.featuremap_out_channel)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            x[-1] = self.out(x[-1])
        return x


class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block,
                                       in_channels[1],
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       in_channels[2],
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(
                block,
                in_channels[3],
                layers[3],
                stride=2,
                dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        out_layers = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for name in ['layer1', 'layer2', 'layer3', 'layer4']:
            if not hasattr(self, name):
                continue
            layer = getattr(self, name)
            x = layer(x)
            out_layers.append(x)

        return out_layers


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        print('pretrained model: ', model_urls[arch])
        # state_dict = torch.load(model_urls[arch])['net']
        state_dict = load_state_dict_from_url(model_urls[arch])
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
                   progress, **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except that the number of channels in the
    bottleneck is twice as large in every block. The number of channels in
    outer 1x1 convolutions is the same, e.g. the last block in ResNet-50 has
    2048-512-2048 channels, and in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except that the number of channels in the
    bottleneck is twice as large in every block. The number of channels in
    outer 1x1 convolutions is the same, e.g. the last block in ResNet-50 has
    2048-512-2048 channels, and in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)
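Unlike the torchvision-style copy above, this variant's forward pass returns one feature map per residual stage. A minimal sketch of that contract (not part of the commit; the import path and the `cfg` stand-in are assumptions):

```python
# Hypothetical usage sketch for ResNetWrapper; `models.resnet` is an assumed
# import path and cfg is a stand-in for the project's config object.
import torch
from types import SimpleNamespace
from models.resnet import ResNetWrapper  # assumed import path

cfg = SimpleNamespace(featuremap_out_channel=128)  # hypothetical config field
net = ResNetWrapper(resnet='resnet18', pretrained=False, out_conv=True, cfg=cfg)
feats = net(torch.randn(1, 3, 288, 800))
print([tuple(f.shape) for f in feats])  # four tensors; the last has 128 channels
```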
@ -0,0 +1,8 @@
pandas
addict
sklearn
opencv-python
pytorch_warmup
scikit-image
tqdm
termcolor
@ -0,0 +1,4 @@
from .evaluator import *
from .resa_trainer import *

from .registry import build_evaluator
@ -0,0 +1,2 @@
from .tusimple.tusimple import Tusimple
from .culane.culane import CULane
@ -0,0 +1,158 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR
import json
import os
import subprocess
from shutil import rmtree
import cv2
import numpy as np


def check():
    import sys
    FNULL = open(os.devnull, 'w')
    result = subprocess.call(
        './runner/evaluator/culane/lane_evaluation/evaluate', stdout=FNULL, stderr=FNULL)
    if result > 1:
        print('There is something wrong with the evaluation tool, please compile it.')
        sys.exit()


def read_helper(path):
    # the output files are "key: value" pairs separated by spaces
    lines = open(path, 'r').readlines()[1:]
    lines = ' '.join(lines)
    values = lines.split(' ')[1::2]
    keys = lines.split(' ')[0::2]
    keys = [key[:-1] for key in keys]  # strip the trailing ':'
    res = {k: v for k, v in zip(keys, values)}
    return res


def call_culane_eval(data_dir, output_path='./output'):
    if data_dir[-1] != '/':
        data_dir = data_dir + '/'
    detect_dir = os.path.join(output_path, 'lines') + '/'

    w_lane = 30
    iou = 0.5  # set iou to 0.3 or 0.5
    im_w = 1640
    im_h = 590
    frame = 1
    list0 = os.path.join(data_dir, 'list/test_split/test0_normal.txt')
    list1 = os.path.join(data_dir, 'list/test_split/test1_crowd.txt')
    list2 = os.path.join(data_dir, 'list/test_split/test2_hlight.txt')
    list3 = os.path.join(data_dir, 'list/test_split/test3_shadow.txt')
    list4 = os.path.join(data_dir, 'list/test_split/test4_noline.txt')
    list5 = os.path.join(data_dir, 'list/test_split/test5_arrow.txt')
    list6 = os.path.join(data_dir, 'list/test_split/test6_curve.txt')
    list7 = os.path.join(data_dir, 'list/test_split/test7_cross.txt')
    list8 = os.path.join(data_dir, 'list/test_split/test8_night.txt')
    if not os.path.exists(os.path.join(output_path, 'txt')):
        os.mkdir(os.path.join(output_path, 'txt'))
    out0 = os.path.join(output_path, 'txt', 'out0_normal.txt')
    out1 = os.path.join(output_path, 'txt', 'out1_crowd.txt')
    out2 = os.path.join(output_path, 'txt', 'out2_hlight.txt')
    out3 = os.path.join(output_path, 'txt', 'out3_shadow.txt')
    out4 = os.path.join(output_path, 'txt', 'out4_noline.txt')
    out5 = os.path.join(output_path, 'txt', 'out5_arrow.txt')
    out6 = os.path.join(output_path, 'txt', 'out6_curve.txt')
    out7 = os.path.join(output_path, 'txt', 'out7_cross.txt')
    out8 = os.path.join(output_path, 'txt', 'out8_night.txt')

    eval_cmd = './runner/evaluator/culane/lane_evaluation/evaluate'

    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list0, w_lane, iou, im_w, im_h, frame, out0))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list1, w_lane, iou, im_w, im_h, frame, out1))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list2, w_lane, iou, im_w, im_h, frame, out2))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list3, w_lane, iou, im_w, im_h, frame, out3))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list4, w_lane, iou, im_w, im_h, frame, out4))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list5, w_lane, iou, im_w, im_h, frame, out5))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list6, w_lane, iou, im_w, im_h, frame, out6))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list7, w_lane, iou, im_w, im_h, frame, out7))
    os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (eval_cmd, data_dir, detect_dir, data_dir, list8, w_lane, iou, im_w, im_h, frame, out8))
    res_all = {}
    res_all['normal'] = read_helper(out0)
    res_all['crowd'] = read_helper(out1)
    res_all['night'] = read_helper(out8)
    res_all['noline'] = read_helper(out4)
    res_all['shadow'] = read_helper(out3)
    res_all['arrow'] = read_helper(out5)
    res_all['hlight'] = read_helper(out2)
    res_all['curve'] = read_helper(out6)
    res_all['cross'] = read_helper(out7)
    return res_all


@EVALUATOR.register_module
class CULane(nn.Module):
    def __init__(self, cfg):
        super(CULane, self).__init__()
        # First, check that the evaluation tool is compiled and runnable
        check()
        self.cfg = cfg
        self.blur = torch.nn.Conv2d(
            5, 5, 9, padding=4, bias=False, groups=5).cuda()
        torch.nn.init.constant_(self.blur.weight, 1 / 81)
        self.logger = get_logger('resa')
        self.out_dir = os.path.join(self.cfg.work_dir, 'lines')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate(self, dataset, output, batch):
        seg, exists = output['seg'], output['exist']
        predictmaps = F.softmax(seg, dim=1).cpu().numpy()
        exists = exists.cpu().numpy()
        batch_size = seg.size(0)
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for i in range(batch_size):
            coords = dataset.probmap2lane(predictmaps[i], exists[i])
            outname = self.out_dir + img_name[i][:-4] + '.lines.txt'
            outdir = os.path.dirname(outname)
            if not os.path.exists(outdir):
                os.makedirs(outdir)
            f = open(outname, 'w')
            for coord in coords:
                for x, y in coord:
                    if x < 0 and y < 0:
                        continue
                    f.write('%d %d ' % (x, y))
                f.write('\n')
            f.close()

            if self.cfg.view:
                img = cv2.imread(img_path[i]).astype(np.float32)
                dataset.view(img, coords, self.view_dir + img_name[i])

    def summarize(self):
        self.logger.info('summarize result...')
        eval_list_path = os.path.join(
            self.cfg.dataset_path, "list", self.cfg.dataset.val.data_list)
        # prob2lines(self.prob_dir, self.out_dir, eval_list_path, self.cfg)
        res = call_culane_eval(self.cfg.dataset_path, output_path=self.cfg.work_dir)
        TP, FP, FN = 0, 0, 0
        out_str = 'Copypaste: '
        for k, v in res.items():
            val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
            val_tp, val_fp, val_fn = int(v['tp']), int(v['fp']), int(v['fn'])
            val_p, val_r, val_f1 = float(v['precision']), float(v['recall']), float(v['Fmeasure'])
            TP += val_tp
            FP += val_fp
            FN += val_fn
            self.logger.info(k + ': ' + str(v))
            out_str += k
            for metric, value in v.items():
                out_str += ' ' + str(value).rstrip('\n')
            out_str += ' '
        P = TP * 1.0 / (TP + FP + 1e-9)
        R = TP * 1.0 / (TP + FN + 1e-9)
        F = 2 * P * R / (P + R + 1e-9)
        overall_result_str = ('Overall Precision: %f Recall: %f F1: %f' % (P, R, F))
        self.logger.info(overall_result_str)
        out_str = out_str + overall_result_str
        self.logger.info(out_str)

        # delete the tmp output
        rmtree(self.out_dir)
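For intuition, the aggregation that `summarize()` performs over the per-category results can be sketched in a few lines (illustration only; the counts below are made up, and the real code parses them as strings from the tool's output files):

```python
# Minimal sketch of the micro-averaged F1 computed in summarize().
res = {'normal': {'tp': 29000, 'fp': 2000, 'fn': 3000},   # fabricated numbers
       'crowd':  {'tp': 19000, 'fp': 6000, 'fn': 8000}}
TP = sum(v['tp'] for v in res.values())
FP = sum(v['fp'] for v in res.values())
FN = sum(v['fn'] for v in res.values())
P = TP / (TP + FP + 1e-9)
R = TP / (TP + FN + 1e-9)
F1 = 2 * P * R / (P + R + 1e-9)
print('Overall Precision: %f Recall: %f F1: %f' % (P, R, F1))
```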
@ -0,0 +1,2 @@
build/
evaluate
@ -0,0 +1,50 @@
PROJECT_NAME:= evaluate

# config ----------------------------------
OPENCV_VERSION := 3

INCLUDE_DIRS := include
LIBRARY_DIRS := lib /usr/local/lib

COMMON_FLAGS := -DCPU_ONLY
CXXFLAGS := -std=c++11 -fopenmp
LDFLAGS := -fopenmp -Wl,-rpath,./lib
BUILD_DIR := build


# make rules -------------------------------
CXX ?= g++
BUILD_DIR ?= ./build

LIBRARIES += opencv_core opencv_highgui opencv_imgproc
ifeq ($(OPENCV_VERSION), 3)
	LIBRARIES += opencv_imgcodecs
endif

CXXFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
LDFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(LIBRARY_DIRS),-L$(includedir)) $(foreach library,$(LIBRARIES),-l$(library))
SRC_DIRS += $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print)
CXX_SRCS += $(shell find src/ -name "*.cpp")
CXX_TARGETS:=$(patsubst %.cpp, $(BUILD_DIR)/%.o, $(CXX_SRCS))
ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix $(BUILD_DIR)/, $(SRC_DIRS)))

.PHONY: all
all: $(PROJECT_NAME)

.PHONY: $(ALL_BUILD_DIRS)
$(ALL_BUILD_DIRS):
	@mkdir -p $@

$(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS)
	@echo "CXX" $<
	@$(CXX) $(CXXFLAGS) -c -o $@ $<

$(PROJECT_NAME): $(CXX_TARGETS)
	@echo "CXX/LD" $@
	@$(CXX) -o $@ $^ $(LDFLAGS)

.PHONY: clean
clean:
	@rm -rf $(CXX_TARGETS)
	@rm -rf $(PROJECT_NAME)
	@rm -rf $(BUILD_DIR)
@ -0,0 +1,47 @@
#ifndef COUNTER_HPP
#define COUNTER_HPP

#include "lane_compare.hpp"
#include "hungarianGraph.hpp"
#include <iostream>
#include <algorithm>
#include <tuple>
#include <vector>
#include <opencv2/core/core.hpp>

using namespace std;
using namespace cv;

// Before using the functions of this class, the lanes should be resized to
// im_width and im_height with resize_lane() in lane_compare.hpp.
class Counter
{
    public:
        Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int _lane_width=10):tp(0),fp(0),fn(0){
            im_width = _im_width;
            im_height = _im_height;
            sim_threshold = _iou_threshold;
            lane_compare = new LaneCompare(_im_width, _im_height, _lane_width, LaneCompare::IOU);
        };
        double get_precision(void);
        double get_recall(void);
        long getTP(void);
        long getFP(void);
        long getFN(void);
        void setTP(long);
        void setFP(long);
        void setFN(long);
        // directly accumulate tp, fp, tn and fn;
        // matching is done with the Hungarian algorithm first
        tuple<vector<int>, long, long, long, long> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes);
        void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2);

    private:
        double sim_threshold;
        int im_width;
        int im_height;
        long tp;
        long fp;
        long fn;
        LaneCompare *lane_compare;
};
#endif
@ -0,0 +1,71 @@
#ifndef HUNGARIAN_GRAPH_HPP
#define HUNGARIAN_GRAPH_HPP
#include <cmath>
#include <vector>
using namespace std;

// Maximum-weight bipartite matching (Kuhn-Munkres style) over the
// similarity matrix `mat`; left vertices are annotations, right are detections.
struct pipartiteGraph {
    vector<vector<double> > mat;
    vector<bool> leftUsed, rightUsed;
    vector<double> leftWeight, rightWeight;
    vector<int> rightMatch, leftMatch;
    int leftNum, rightNum;

    bool matchDfs(int u) {
        leftUsed[u] = true;
        for (int v = 0; v < rightNum; v++) {
            if (!rightUsed[v] && fabs(leftWeight[u] + rightWeight[v] - mat[u][v]) < 1e-2) {
                rightUsed[v] = true;
                if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) {
                    rightMatch[v] = u;
                    leftMatch[u] = v;
                    return true;
                }
            }
        }
        return false;
    }

    void resize(int leftNum, int rightNum) {
        this->leftNum = leftNum;
        this->rightNum = rightNum;
        leftMatch.resize(leftNum);
        rightMatch.resize(rightNum);
        leftUsed.resize(leftNum);
        rightUsed.resize(rightNum);
        leftWeight.resize(leftNum);
        rightWeight.resize(rightNum);
        mat.resize(leftNum);
        for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum);
    }

    void match() {
        for (int i = 0; i < leftNum; i++) leftMatch[i] = -1;
        for (int i = 0; i < rightNum; i++) rightMatch[i] = -1;
        for (int i = 0; i < rightNum; i++) rightWeight[i] = 0;
        for (int i = 0; i < leftNum; i++) {
            leftWeight[i] = -1e5;
            for (int j = 0; j < rightNum; j++) {
                if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j];
            }
        }

        for (int u = 0; u < leftNum; u++) {
            while (1) {
                for (int i = 0; i < leftNum; i++) leftUsed[i] = false;
                for (int i = 0; i < rightNum; i++) rightUsed[i] = false;
                if (matchDfs(u)) break;
                // no augmenting path found: relax the vertex labels by the minimal slack d
                double d = 1e10;
                for (int i = 0; i < leftNum; i++) {
                    if (leftUsed[i]) {
                        for (int j = 0; j < rightNum; j++) {
                            if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]);
                        }
                    }
                }
                if (d == 1e10) return;
                for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d;
                for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d;
            }
        }
    }
};


#endif // HUNGARIAN_GRAPH_HPP
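For intuition (not used anywhere in the repo): `pipartiteGraph::match()` computes essentially the same assignment that SciPy's Hungarian solver produces when asked to maximize total similarity, which can be a handy cross-check:

```python
# Sketch of the matching problem solved above, via scipy for comparison.
import numpy as np
from scipy.optimize import linear_sum_assignment

similarity = np.array([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.0, 0.3]])  # rows: annotation lanes, cols: detections
rows, cols = linear_sum_assignment(similarity, maximize=True)
print(list(zip(rows, cols)))  # [(0, 0), (1, 1)]: total similarity 1.7
```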
@ -0,0 +1,51 @@
#ifndef LANE_COMPARE_HPP
#define LANE_COMPARE_HPP

#include "spline.hpp"
#include <vector>
#include <iostream>
#include <opencv2/core/version.hpp>
#include <opencv2/core/core.hpp>

#if CV_VERSION_EPOCH == 2
#define OPENCV2
#elif CV_VERSION_MAJOR == 3
#define OPENCV3
#else
#error Unsupported OpenCV version
#endif

#ifdef OPENCV3
#include <opencv2/imgproc.hpp>
#elif defined(OPENCV2)
#include <opencv2/imgproc/imgproc.hpp>
#endif

using namespace std;
using namespace cv;

class LaneCompare{
    public:
        enum CompareMode{
            IOU,
            Caltech
        };

        LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){
            im_width = _im_width;
            im_height = _im_height;
            compare_mode = _compare_mode;
            lane_width = _lane_width;
        }

        double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2);
        void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height);
    private:
        CompareMode compare_mode;
        int im_width;
        int im_height;
        int lane_width;
        Spline splineSolver;
};

#endif
@ -0,0 +1,28 @@
#ifndef SPLINE_HPP
#define SPLINE_HPP
#include <vector>
#include <cstdio>
#include <math.h>
#include <opencv2/core/core.hpp>

using namespace cv;
using namespace std;

// Cubic polynomial coefficients for x(t) and y(t) on one spline segment of length h.
struct Func {
    double a_x;
    double b_x;
    double c_x;
    double d_x;
    double a_y;
    double b_y;
    double c_y;
    double d_y;
    double h;
};

class Spline {
    public:
        vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times);
        vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step);
        vector<Func> cal_fun(const vector<Point2f> &point_v);
};
#endif
@ -0,0 +1,134 @@
/*************************************************************************
	> File Name: counter.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Thu Jul 14 20:23:08 2016
 ************************************************************************/

#include "counter.hpp"

double Counter::get_precision(void)
{
    cerr << "tp: " << tp << " fp: " << fp << " fn: " << fn << endl;
    if (tp + fp == 0)
    {
        cerr << "no positive detection" << endl;
        return -1;
    }
    return tp / double(tp + fp);
}

double Counter::get_recall(void)
{
    if (tp + fn == 0)
    {
        cerr << "no ground truth positive" << endl;
        return -1;
    }
    return tp / double(tp + fn);
}

long Counter::getTP(void)
{
    return tp;
}

long Counter::getFP(void)
{
    return fp;
}

long Counter::getFN(void)
{
    return fn;
}

void Counter::setTP(long value)
{
    tp = value;
}

void Counter::setFP(long value)
{
    fp = value;
}

void Counter::setFN(long value)
{
    fn = value;
}

tuple<vector<int>, long, long, long, long> Counter::count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes)
{
    vector<int> anno_match(anno_lanes.size(), -1);
    vector<int> detect_match;
    if (anno_lanes.empty())
    {
        return make_tuple(anno_match, 0, detect_lanes.size(), 0, 0);
    }

    if (detect_lanes.empty())
    {
        return make_tuple(anno_match, 0, 0, 0, anno_lanes.size());
    }
    // Hungarian match first

    // first compute the similarity matrix
    vector<vector<double> > similarity(anno_lanes.size(), vector<double>(detect_lanes.size(), 0));
    for (int i = 0; i < anno_lanes.size(); i++)
    {
        const vector<Point2f> &curr_anno_lane = anno_lanes[i];
        for (int j = 0; j < detect_lanes.size(); j++)
        {
            const vector<Point2f> &curr_detect_lane = detect_lanes[j];
            similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane);
        }
    }

    makeMatch(similarity, anno_match, detect_match);

    int curr_tp = 0;
    // count and add
    for (int i = 0; i < anno_lanes.size(); i++)
    {
        if (anno_match[i] >= 0 && similarity[i][anno_match[i]] > sim_threshold)
        {
            curr_tp++;
        }
        else
        {
            anno_match[i] = -1;
        }
    }
    int curr_fn = anno_lanes.size() - curr_tp;
    int curr_fp = detect_lanes.size() - curr_tp;
    return make_tuple(anno_match, curr_tp, curr_fp, 0, curr_fn);
}

void Counter::makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2) {
    int m = similarity.size();
    int n = similarity[0].size();
    pipartiteGraph gra;
    bool have_exchange = false;
    if (m > n) {
        have_exchange = true;
        swap(m, n);
    }
    gra.resize(m, n);
    for (int i = 0; i < gra.leftNum; i++) {
        for (int j = 0; j < gra.rightNum; j++) {
            if (have_exchange)
                gra.mat[i][j] = similarity[j][i];
            else
                gra.mat[i][j] = similarity[i][j];
        }
    }
    gra.match();
    match1 = gra.leftMatch;
    match2 = gra.rightMatch;
    if (have_exchange) swap(match1, match2);
}
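A minimal sketch of the counting rule in `Counter::count_im_pair` (illustration only, with fabricated similarities): after matching, a pair only counts as a true positive if its IoU clears the threshold; unmatched or low-IoU lanes become FP/FN.

```python
# Python restatement of the TP/FP/FN counting logic above.
def count_pair(similarity, anno_match, sim_threshold=0.4):
    n_anno = len(similarity)
    n_detect = len(similarity[0]) if similarity else 0
    tp = sum(1 for i, j in enumerate(anno_match)
             if j >= 0 and similarity[i][j] > sim_threshold)
    return tp, n_detect - tp, n_anno - tp  # tp, fp, fn

print(count_pair([[0.9, 0.1], [0.2, 0.3]], anno_match=[0, 1]))  # (1, 1, 1)
```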
@ -0,0 +1,302 @@
/*************************************************************************
	> File Name: evaluate.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Thu Jul 14 18:28:45 2016
 ************************************************************************/

#include "counter.hpp"
#include "spline.hpp"
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <cstdlib>
#include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
using namespace std;
using namespace cv;

void help(void) {
    cout << "./evaluate [OPTIONS]" << endl;
    cout << "-h : print usage help" << endl;
    cout << "-a : directory for annotation files (default: "
            "/data/driving/eval_data/anno_label/)" << endl;
    cout << "-d : directory for detection files (default: "
            "/data/driving/eval_data/predict_label/)" << endl;
    cout << "-i : directory for image files (default: "
            "/data/driving/eval_data/img/)" << endl;
    cout << "-l : list of images used for evaluation (default: "
            "/data/driving/eval_data/img/all.txt)" << endl;
    cout << "-w : width of the lanes (default: 10)" << endl;
    cout << "-t : threshold of iou (default: 0.4)" << endl;
    cout << "-c : cols (max image width) (default: 1920)" << endl;
    cout << "-r : rows (max image height) (default: 1080)" << endl;
    cout << "-s : show visualization" << endl;
    cout << "-f : start frame in the test set (default: 1)" << endl;
    cout << "-o : output file (default: ./output.txt)" << endl;
    cout << "-p : save visualizations to this path" << endl;
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes);
void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path = "");

int main(int argc, char **argv) {
    // process params
    string anno_dir = "/data/driving/eval_data/anno_label/";
    string detect_dir = "/data/driving/eval_data/predict_label/";
    string im_dir = "/data/driving/eval_data/img/";
    string list_im_file = "/data/driving/eval_data/img/all.txt";
    string output_file = "./output.txt";
    int width_lane = 10;
    double iou_threshold = 0.4;
    int im_width = 1920;
    int im_height = 1080;
    int oc;
    bool show = false;
    int frame = 1;
    string save_path = "";
    while ((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:p:")) != -1) {
        switch (oc) {
        case 'h':
            help();
            return 0;
        case 'a':
            anno_dir = optarg;
            break;
        case 'd':
            detect_dir = optarg;
            break;
        case 'i':
            im_dir = optarg;
            break;
        case 'l':
            list_im_file = optarg;
            break;
        case 'w':
            width_lane = atoi(optarg);
            break;
        case 't':
            iou_threshold = atof(optarg);
            break;
        case 'c':
            im_width = atoi(optarg);
            break;
        case 'r':
            im_height = atoi(optarg);
            break;
        case 's':
            show = true;
            break;
        case 'p':
            save_path = optarg;
            break;
        case 'f':
            frame = atoi(optarg);
            break;
        case 'o':
            output_file = optarg;
            break;
        }
    }

    cout << "------------Configuration---------" << endl;
    cout << "anno_dir: " << anno_dir << endl;
    cout << "detect_dir: " << detect_dir << endl;
    cout << "im_dir: " << im_dir << endl;
    cout << "list_im_file: " << list_im_file << endl;
    cout << "width_lane: " << width_lane << endl;
    cout << "iou_threshold: " << iou_threshold << endl;
    cout << "im_width: " << im_width << endl;
    cout << "im_height: " << im_height << endl;
    cout << "-----------------------------------" << endl;
    cout << "Evaluating the results..." << endl;
    // im_width and im_height are the max width and max height

    if (width_lane < 1) {
        cerr << "width_lane must be positive" << endl;
        help();
        return 1;
    }

    ifstream ifs_im_list(list_im_file, ios::in);
    if (ifs_im_list.fail()) {
        cerr << "Error: file " << list_im_file << " does not exist!" << endl;
        return 1;
    }

    Counter counter(im_width, im_height, iou_threshold, width_lane);

    vector<int> anno_match;
    string sub_im_name;
    // pre-load the file list
    vector<string> filelists;
    while (getline(ifs_im_list, sub_im_name)) {
        filelists.push_back(sub_im_name);
    }
    ifs_im_list.close();

    vector<tuple<vector<int>, long, long, long, long>> tuple_lists;
    tuple_lists.resize(filelists.size());

#pragma omp parallel for
    for (size_t i = 0; i < filelists.size(); i++) {
        auto sub_im_name = filelists[i];
        string full_im_name = im_dir + sub_im_name;
        string sub_txt_name =
            sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt";
        string anno_file_name = anno_dir + sub_txt_name;
        string detect_file_name = detect_dir + sub_txt_name;
        vector<vector<Point2f>> anno_lanes;
        vector<vector<Point2f>> detect_lanes;
        read_lane_file(anno_file_name, anno_lanes);
        read_lane_file(detect_file_name, detect_lanes);
        // cerr<<count<<": "<<full_im_name<<endl;
        tuple_lists[i] = counter.count_im_pair(anno_lanes, detect_lanes);
        if (show) {
            auto anno_match = get<0>(tuple_lists[i]);
            visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane);
            waitKey(0);
        }
        if (save_path != "") {
            auto anno_match = get<0>(tuple_lists[i]);
            visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane,
                      save_path);
        }
    }

    long tp = 0, fp = 0, tn = 0, fn = 0;
    for (auto result : tuple_lists) {
        tp += get<1>(result);
        fp += get<2>(result);
        // tn = get<3>(result);
        fn += get<4>(result);
    }
    counter.setTP(tp);
    counter.setFP(fp);
    counter.setFN(fn);

    double precision = counter.get_precision();
    double recall = counter.get_recall();
    double F = 2 * precision * recall / (precision + recall);
    cerr << "finished processing files" << endl;
    cout << "precision: " << precision << endl;
    cout << "recall: " << recall << endl;
    cout << "Fmeasure: " << F << endl;
    cout << "----------------------------------" << endl;

    ofstream ofs_out_file;
    ofs_out_file.open(output_file, ios::out);
    ofs_out_file << "file: " << output_file << endl;
    ofs_out_file << "tp: " << counter.getTP() << " fp: " << counter.getFP()
                 << " fn: " << counter.getFN() << endl;
    ofs_out_file << "precision: " << precision << endl;
    ofs_out_file << "recall: " << recall << endl;
    ofs_out_file << "Fmeasure: " << F << endl << endl;
    ofs_out_file.close();
    return 0;
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes) {
    lanes.clear();
    ifstream ifs_lane(file_name, ios::in);
    if (ifs_lane.fail()) {
        return;
    }

    string str_line;
    while (getline(ifs_lane, str_line)) {
        vector<Point2f> curr_lane;
        stringstream ss;
        ss << str_line;
        double x, y;
        while (ss >> x >> y) {
            curr_lane.push_back(Point2f(x, y));
        }
        lanes.push_back(curr_lane);
    }

    ifs_lane.close();
}

void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path) {
    Mat img = imread(full_im_name, 1);
    Mat img2 = imread(full_im_name, 1);
    vector<Point2f> curr_lane;
    vector<Point2f> p_interp;
    Spline splineSolver;
    Scalar color_B = Scalar(255, 0, 0);
    Scalar color_G = Scalar(0, 255, 0);
    Scalar color_R = Scalar(0, 0, 255);
    Scalar color_P = Scalar(255, 0, 255);
    Scalar color;
    for (int i = 0; i < anno_lanes.size(); i++) {
        curr_lane = anno_lanes[i];
        if (curr_lane.size() == 2) {
            p_interp = curr_lane;
        } else {
            p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
        }
        if (anno_match[i] >= 0) {
            color = color_G;
        } else {
            // note: unmatched annotations are currently drawn in the same green
            color = color_G;
        }
        for (int n = 0; n < p_interp.size() - 1; n++) {
            line(img, p_interp[n], p_interp[n + 1], color, width_lane);
            line(img2, p_interp[n], p_interp[n + 1], color, 2);
        }
    }
    bool detected;
    for (int i = 0; i < detect_lanes.size(); i++) {
        detected = false;
        curr_lane = detect_lanes[i];
        if (curr_lane.size() == 2) {
            p_interp = curr_lane;
        } else {
            p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
        }
        for (int n = 0; n < anno_lanes.size(); n++) {
            if (anno_match[n] == i) {
                detected = true;
                break;
            }
        }
        if (detected == true) {
            color = color_B;
        } else {
            color = color_R;
        }
        for (int n = 0; n < p_interp.size() - 1; n++) {
            line(img, p_interp[n], p_interp[n + 1], color, width_lane);
            line(img2, p_interp[n], p_interp[n + 1], color, 2);
        }
    }
    if (save_path != "") {
        size_t pos = 0;
        string s = full_im_name;
        std::string token;
        std::string delimiter = "/";
        vector<string> names;
        while ((pos = s.find(delimiter)) != std::string::npos) {
            token = s.substr(0, pos);
            names.emplace_back(token);
            s.erase(0, pos + delimiter.length());
        }
        names.emplace_back(s);
        string file_name = names[3] + '_' + names[4] + '_' + names[5];
        // cout << file_name << endl;
        imwrite(save_path + '/' + file_name, img);
    } else {
        namedWindow("visualize", 1);
        imshow("visualize", img);
        namedWindow("visualize2", 1);
        imshow("visualize2", img2);
    }
}
@ -0,0 +1,73 @@
/*************************************************************************
    > File Name: lane_compare.cpp
    > Author: Xingang Pan, Jun Li
    > Mail: px117@ie.cuhk.edu.hk
    > Created Time: Fri Jul 15 10:26:32 2016
 ************************************************************************/

#include "lane_compare.hpp"

double LaneCompare::get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2)
{
    if (lane1.size() < 2 || lane2.size() < 2)
    {
        cerr << "lane size must be greater or equal to 2" << endl;
        return 0;
    }
    Mat im1 = Mat::zeros(im_height, im_width, CV_8UC1);
    Mat im2 = Mat::zeros(im_height, im_width, CV_8UC1);
    // draw lines on im1 and im2
    vector<Point2f> p_interp1;
    vector<Point2f> p_interp2;
    if (lane1.size() == 2)
    {
        p_interp1 = lane1;
    }
    else
    {
        p_interp1 = splineSolver.splineInterpTimes(lane1, 50);
    }

    if (lane2.size() == 2)
    {
        p_interp2 = lane2;
    }
    else
    {
        p_interp2 = splineSolver.splineInterpTimes(lane2, 50);
    }

    Scalar color_white = Scalar(1);
    for (int n = 0; n < p_interp1.size() - 1; n++)
    {
        line(im1, p_interp1[n], p_interp1[n+1], color_white, lane_width);
    }
    for (int n = 0; n < p_interp2.size() - 1; n++)
    {
        line(im2, p_interp2[n], p_interp2[n+1], color_white, lane_width);
    }

    double sum_1 = cv::sum(im1).val[0];
    double sum_2 = cv::sum(im2).val[0];
    double inter_sum = cv::sum(im1.mul(im2)).val[0];
    double union_sum = sum_1 + sum_2 - inter_sum;
    double iou = inter_sum / union_sum;
    return iou;
}

// resize the lane from Size(curr_width, curr_height) to Size(im_width, im_height)
void LaneCompare::resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height)
{
    if (curr_width == im_width && curr_height == im_height)
    {
        return;
    }
    double x_scale = im_width / (double)curr_width;
    double y_scale = im_height / (double)curr_height;
    for (int n = 0; n < curr_lane.size(); n++)
    {
        curr_lane[n] = Point2f(curr_lane[n].x * x_scale, curr_lane[n].y * y_scale);
    }
}
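The similarity above is simply the IoU of two rasterized lane masks: each polyline is drawn `lane_width` pixels thick, the elementwise product gives the intersection, and inclusion-exclusion gives the union. A minimal NumPy/OpenCV sketch of the same idea (the helper name and the example lanes are ours, not part of the repo):

```python
import cv2
import numpy as np

def lane_iou(lane1, lane2, height, width, lane_width=30):
    """IoU of two polylines rasterized as thick lines, as in get_lane_similarity."""
    masks = []
    for lane in (lane1, lane2):
        im = np.zeros((height, width), dtype=np.uint8)
        pts = [tuple(int(round(v)) for v in p) for p in lane]
        for p, q in zip(pts[:-1], pts[1:]):
            cv2.line(im, p, q, color=1, thickness=lane_width)
        masks.append(im)
    inter = float(np.sum(masks[0] * masks[1]))                    # im1.mul(im2)
    union = float(np.sum(masks[0]) + np.sum(masks[1]) - inter)    # inclusion-exclusion
    return inter / union if union > 0 else 0.0

# two nearly coincident lanes on a CULane-sized canvas -> IoU close to 1
print(lane_iou([(100, 580), (300, 300)], [(105, 580), (305, 300)], 590, 1640))
```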
@ -0,0 +1,178 @@
#include <vector>
#include <iostream>
#include "spline.hpp"
using namespace std;
using namespace cv;

vector<Point2f> Spline::splineInterpTimes(const vector<Point2f>& tmp_line, int times) {
    vector<Point2f> res;

    if (tmp_line.size() == 2) {
        double x1 = tmp_line[0].x;
        double y1 = tmp_line[0].y;
        double x2 = tmp_line[1].x;
        double y2 = tmp_line[1].y;

        for (int k = 0; k <= times; k++) {
            double xi = x1 + double((x2 - x1) * k) / times;
            double yi = y1 + double((y2 - y1) * k) / times;
            res.push_back(Point2f(xi, yi));
        }
    }
    else if (tmp_line.size() > 2)
    {
        vector<Func> tmp_func;
        tmp_func = this->cal_fun(tmp_line);
        if (tmp_func.empty()) {
            cout << "in splineInterpTimes: cal_fun failed" << endl;
            return res;
        }
        for (int j = 0; j < tmp_func.size(); j++)
        {
            double delta = tmp_func[j].h / times;
            for (int k = 0; k < times; k++)
            {
                double t1 = delta * k;
                double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
                double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
                res.push_back(Point2f(x1, y1));
            }
        }
        res.push_back(tmp_line[tmp_line.size() - 1]);
    }
    else {
        cerr << "in splineInterpTimes: not enough points" << endl;
    }
    return res;
}

vector<Point2f> Spline::splineInterpStep(vector<Point2f> tmp_line, double step) {
    vector<Point2f> res;
    /*
    if (tmp_line.size() == 2) {
        double x1 = tmp_line[0].x;
        double y1 = tmp_line[0].y;
        double x2 = tmp_line[1].x;
        double y2 = tmp_line[1].y;

        for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) {
            double xi;
            if (yi == y1) xi = x1;
            else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1;
            res.push_back(Point2f(xi, yi));
        }
    }*/
    // a 2-point line is upgraded to 3 points (via its midpoint) so the spline fit applies
    if (tmp_line.size() == 2) {
        double x1 = tmp_line[0].x;
        double y1 = tmp_line[0].y;
        double x2 = tmp_line[1].x;
        double y2 = tmp_line[1].y;
        tmp_line[1].x = (x1 + x2) / 2;
        tmp_line[1].y = (y1 + y2) / 2;
        tmp_line.push_back(Point2f(x2, y2));
    }
    if (tmp_line.size() > 2) {
        vector<Func> tmp_func;
        tmp_func = this->cal_fun(tmp_line);
        double ystart = tmp_line[0].y;
        double yend = tmp_line[tmp_line.size() - 1].y;
        bool down;
        if (ystart < yend) down = 1;
        else down = 0;
        if (tmp_func.empty()) {
            cerr << "in splineInterpStep: cal_fun failed" << endl;
        }

        for (int j = 0; j < tmp_func.size(); j++)
        {
            for (double t1 = 0; t1 < tmp_func[j].h; t1 += step)
            {
                double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
                double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
                res.push_back(Point2f(x1, y1));
            }
        }
        res.push_back(tmp_line[tmp_line.size() - 1]);
    }
    else {
        cerr << "in splineInterpStep: not enough points" << endl;
    }
    return res;
}

vector<Func> Spline::cal_fun(const vector<Point2f> &point_v)
{
    vector<Func> func_v;
    int n = point_v.size();
    if (n <= 2) {
        cout << "in cal_fun: point number less than 3" << endl;
        return func_v;
    }

    func_v.resize(point_v.size() - 1);

    vector<double> Mx(n);
    vector<double> My(n);
    vector<double> A(n-2);
    vector<double> B(n-2);
    vector<double> C(n-2);
    vector<double> Dx(n-2);
    vector<double> Dy(n-2);
    vector<double> h(n-1);
    //vector<func> func_v(n-1);

    // chord-length parameterization: h[i] is the distance between consecutive points
    for (int i = 0; i < n-1; i++)
    {
        h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2));
    }

    for (int i = 0; i < n-2; i++)
    {
        A[i] = h[i];
        B[i] = 2 * (h[i] + h[i+1]);
        C[i] = h[i+1];

        Dx[i] = 6 * ((point_v[i+2].x - point_v[i+1].x) / h[i+1] - (point_v[i+1].x - point_v[i].x) / h[i]);
        Dy[i] = 6 * ((point_v[i+2].y - point_v[i+1].y) / h[i+1] - (point_v[i+1].y - point_v[i].y) / h[i]);
    }

    // TDMA (Thomas algorithm) for the tridiagonal system of second derivatives
    C[0] = C[0] / B[0];
    Dx[0] = Dx[0] / B[0];
    Dy[0] = Dy[0] / B[0];
    for (int i = 1; i < n-2; i++)
    {
        double tmp = B[i] - A[i] * C[i-1];
        C[i] = C[i] / tmp;
        Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp;
        Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp;
    }
    Mx[n-2] = Dx[n-3];
    My[n-2] = Dy[n-3];
    for (int i = n-4; i >= 0; i--)
    {
        Mx[i+1] = Dx[i] - C[i]*Mx[i+2];
        My[i+1] = Dy[i] - C[i]*My[i+2];
    }

    // natural spline boundary conditions
    Mx[0] = 0;
    Mx[n-1] = 0;
    My[0] = 0;
    My[n-1] = 0;

    for (int i = 0; i < n-1; i++)
    {
        func_v[i].a_x = point_v[i].x;
        func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6;
        func_v[i].c_x = Mx[i]/2;
        func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]);

        func_v[i].a_y = point_v[i].y;
        func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6;
        func_v[i].c_y = My[i]/2;
        func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]);

        func_v[i].h = h[i];
    }
    return func_v;
}
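cal_fun fits a natural cubic spline per coordinate, parameterized by chord length: the tridiagonal system is solved with the Thomas algorithm for the interior second derivatives M, and the natural boundary condition pins M[0] = M[n-1] = 0. A rough NumPy transcription of the same solve for one coordinate (illustrative only; the function name is ours, and like cal_fun it needs at least 3 points):

```python
import numpy as np

def natural_cubic_coeffs(t, v):
    """Segment coefficients a + b*s + c*s^2 + d*s^3 (s in [0, h_i]) of a natural
    cubic spline through (t, v); mirrors the solve in Spline::cal_fun."""
    n = len(t)
    h = np.diff(t).astype(float)
    # tridiagonal bands and right-hand side for the interior second derivatives
    A, B, C = h[:-1].copy(), 2.0 * (h[:-1] + h[1:]), h[1:].copy()
    D = 6.0 * (np.diff(v[1:]) / h[1:] - np.diff(v[:-1]) / h[:-1])
    # Thomas algorithm: forward elimination ...
    C[0] /= B[0]; D[0] /= B[0]
    for i in range(1, n - 2):
        m = B[i] - A[i] * C[i - 1]
        C[i] /= m
        D[i] = (D[i] - A[i] * D[i - 1]) / m
    # ... and back substitution; natural boundary keeps M[0] = M[-1] = 0
    M = np.zeros(n)
    M[n - 2] = D[n - 3]
    for i in range(n - 4, -1, -1):
        M[i + 1] = D[i] - C[i] * M[i + 2]
    a = np.asarray(v[:-1], dtype=float)
    b = np.diff(v) / h - (2 * h * M[:-1] + h * M[1:]) / 6
    c = M[:-1] / 2
    d = (M[1:] - M[:-1]) / (6 * h)
    return a, b, c, d, h

# chord-length parameterization of a polyline, as cal_fun uses
pts = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [3.0, 1.0]])
t = np.concatenate([[0.0], np.cumsum(np.linalg.norm(np.diff(pts, axis=0), axis=1))])
ax, bx, cx, dx, h = natural_cubic_coeffs(t, pts[:, 0])
```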
@ -0,0 +1,51 @@
import os
import argparse
import numpy as np
import pandas as pd
from PIL import Image
import tqdm


def getLane(probmap, pts, cfg=None):
    thr = 0.3
    coordinate = np.zeros(pts)
    cut_height = 0
    if cfg.cut_height:
        cut_height = cfg.cut_height
    for i in range(pts):
        line = probmap[round(cfg.img_height - i*20/(590-cut_height)*cfg.img_height) - 1]
        if np.max(line)/255 > thr:
            coordinate[i] = np.argmax(line) + 1
    if np.sum(coordinate > 0) < 2:
        coordinate = np.zeros(pts)
    return coordinate


def prob2lines(prob_dir, out_dir, list_file, cfg=None):
    lists = pd.read_csv(list_file, sep=' ', header=None,
                        names=('img', 'probmap', 'label1', 'label2', 'label3', 'label4'))
    pts = 18

    for k, im in enumerate(lists['img'], 1):
        existPath = prob_dir + im[:-4] + '.exist.txt'
        outname = out_dir + im[:-4] + '.lines.txt'
        prefix = '/'.join(outname.split('/')[:-1])
        if not os.path.exists(prefix):
            os.makedirs(prefix)
        f = open(outname, 'w')

        labels = list(pd.read_csv(existPath, sep=' ', header=None).iloc[0])
        coordinates = np.zeros((4, pts))
        for i in range(4):
            if labels[i] == 1:
                probfile = prob_dir + im[:-4] + '_{0}_avg.png'.format(i+1)
                probmap = np.array(Image.open(probfile))
                coordinates[i] = getLane(probmap, pts, cfg)

                if np.sum(coordinates[i] > 0) > 1:
                    for idx, value in enumerate(coordinates[i]):
                        if value > 0:
                            f.write('%d %d ' % (
                                round(value*1640/cfg.img_width)-1, round(590-idx*20)-1))
                    f.write('\n')
        f.close()
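getLane samples the probability map at `pts` fixed rows, each corresponding to every 20 px from the bottom of the original 590-px-tall CULane frame mapped into the network's output height, and keeps the per-row argmax column when it clears the threshold. A toy run with the getLane above and a stand-in config object (the field names mirror the cfg used here; the values are ours):

```python
import numpy as np
from types import SimpleNamespace

cfg = SimpleNamespace(img_height=288, img_width=800, cut_height=0)  # stand-in cfg
probmap = np.zeros((cfg.img_height, cfg.img_width), dtype=np.uint8)
probmap[:, 400] = 255  # fake vertical lane response at column 400

coords = getLane(probmap, pts=18, cfg=cfg)
print(coords)  # 401.0 at every sampled row (argmax column + 1)
```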
@ -0,0 +1,115 @@
import cv2
import numpy as np

def isShort(lane):
    start = [i for i, x in enumerate(lane) if x > 0]
    if not start:
        return 1
    else:
        return 0

def fixGap(coordinate):
    if any(x > 0 for x in coordinate):
        start = [i for i, x in enumerate(coordinate) if x > 0][0]
        end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
        lane = coordinate[start:end+1]
        if any(x < 0 for x in lane):
            gap_start = [i for i, x in enumerate(
                lane[:-1]) if x > 0 and lane[i+1] < 0]
            gap_end = [i+1 for i,
                       x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
            gap_id = [i for i, x in enumerate(lane) if x < 0]
            if len(gap_start) == 0 or len(gap_end) == 0:
                return coordinate
            for id in gap_id:
                for i in range(len(gap_start)):
                    if i >= len(gap_end):
                        return coordinate
                    if id > gap_start[i] and id < gap_end[i]:
                        gap_width = float(gap_end[i] - gap_start[i])
                        lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                            gap_end[i] - id) / gap_width * lane[gap_start[i]])
            if not all(x > 0 for x in lane):
                print("Gaps still exist!")
            coordinate[start:end+1] = lane
    return coordinate


def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None):
    """
    Arguments:
    ----------
    prob_map: prob map for single lane, np array size (h, w)
    resize_shape: reshape size target, (H, W)

    Return:
    ----------
    coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape
    H -= cfg.cut_height

    coords = np.zeros(pts)
    coords[:] = -1.0
    for i in range(pts):
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        line = prob_map[y, :]
        id = np.argmax(line)
        if line[id] > thresh:
            coords[i] = int(id / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    fixGap(coords)
    return coords


def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3, cfg=None):
    """
    Arguments:
    ----------
    seg_pred: np.array size (5, h, w)
    resize_shape: reshape size target, (H, W)
    exist: list of existence, e.g. [0, 1, 1, 0]
    smooth: whether to smooth the probability or not
    y_px_gap: y pixel gap for sampling
    pts: how many points for one lane
    thresh: probability threshold

    Return:
    ----------
    coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    for i in range(cfg.num_classes - 1):
        prob_map = seg_pred[..., i + 1]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg)
        if isShort(coords):
            continue
        coordinates.append(
            [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
             range(pts)])

    if len(coordinates) == 0:
        coords = np.zeros(pts)
        coordinates.append(
            [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
             range(pts)])

    return coordinates
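fixGap linearly interpolates the negative (missing) x-coordinates that fall between the first and last detected points of a lane, leaving the ends untouched, and isShort discards lanes with no positive coordinate at all. A quick check with the fixGap above:

```python
import numpy as np

coords = np.array([-1., 100., -1., -1., 130., 140., -1.])
print(fixGap(coords))
# [ -1. 100. 110. 120. 130. 140.  -1.]  -- interior gap filled linearly
```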
@ -0,0 +1,108 @@
import numpy as np
from sklearn.linear_model import LinearRegression
import json


class LaneEval(object):
    lr = LinearRegression()
    pixel_thresh = 20
    pt_thresh = 0.85

    @staticmethod
    def get_angle(xs, y_samples):
        xs, ys = xs[xs >= 0], y_samples[xs >= 0]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time):
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        if running_time > 200 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        angles = [LaneEval.get_angle(
            np.array(x_gts), np.array(y_samples)) for x_gts in gt]
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fp, fn = 0., 0.
        matched = 0.
        for x_gts, thresh in zip(gt, threshs):
            accs = [LaneEval.line_accuracy(
                np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                matched += 1
            line_accs.append(max_acc)
        fp = len(pred) - matched
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)

    @staticmethod
    def bench_one_submit(pred_file, gt_file):
        try:
            json_pred = [json.loads(line)
                         for line in open(pred_file).readlines()]
        except BaseException as e:
            raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception(
                'We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        accuracy, fp, fn = 0., 0., 0.
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception(
                    'raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            if raw_file not in gts:
                raise Exception(
                    'Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(
                    pred_lanes, gt_lanes, y_samples, run_time)
            except BaseException as e:
                raise Exception('Format of lanes error.')
            accuracy += a
            fp += p
            fn += n
        num = len(gts)
        # the first return parameter is the default ranking parameter
        return json.dumps([
            {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
            {'name': 'FP', 'value': fp / num, 'order': 'asc'},
            {'name': 'FN', 'value': fn / num, 'order': 'asc'}
        ]), accuracy / num


if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        # Exception.message does not exist in Python 3; use str(e)
        print(str(e))
        sys.exit(str(e))
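bench_one_submit expects one JSON object per line in both files, following the TuSimple submission format; the module can also be run directly as `python lane.py predict_test.json test_label.json`. A sketch of one prediction line (the path and numbers below are made up):

```python
import json

pred_line = json.dumps({
    'raw_file': 'clips/some_clip/20.jpg',  # hypothetical key matching the GT file
    'lanes': [[-2, -2, 602, 605, 608]],    # one x per h_sample; negative = absent
    'run_time': 10,                        # ms; > 200 scores the frame as missed
})
```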
@ -0,0 +1,111 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR
import json
import os
import cv2

from .lane import LaneEval

def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        if not os.path.exists(exp_dir):
            os.mkdir(exp_dir)
        self.out_path = os.path.join(exp_dir, "coord_output")
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)
        self.dump_to_json = []
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] >
                     0.5 else 0 for i in range(self.cfg.num_classes - 1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh=self.thresh)
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            for (x, y) in lane_coords[0]:
                json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))
            if self.cfg.view:
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)

    def evaluate(self, dataset, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        best_acc = 0
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(output_file,
                                                     self.cfg.test_json_file)

        self.logger.info(eval_result)
        self.dump_to_json = []
        best_acc = max(acc, best_acc)
        return best_acc
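split_path just walks os.path.split back to the root so the trailing components can be regrouped. For a TuSimple-style relative path (hypothetical example):

```python
print(split_path('clips/0601/1494452381594376146/20.jpg'))
# ['clips', '0601', '1494452381594376146', '20.jpg']
# evaluate_pred keeps the last three components for the output dir/file and the
# last four as the raw_file key that LaneEval matches against the ground truth
```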
@ -0,0 +1,50 @@
import logging

logger_initialized = {}

def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)

    logger_initialized[name] = True

    return logger
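Because initialization is memoized by name (and by prefix for hierarchical names), repeated calls are cheap and child loggers inherit the parent's handlers through normal propagation:

```python
from runner.logger import get_logger

logger = get_logger('resa')            # adds a StreamHandler once
assert get_logger('resa') is logger    # second call: returned as-is
child = get_logger('resa.evaluator')   # child of an initialized logger: no new handlers
logger.info('training started')
```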
@ -0,0 +1,43 @@
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger

def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.system('mkdir -p {}'.format(model_dir))
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))


def load_network_specified(net, model_dir, logger=None):
    # copy only the weights whose names and shapes match the current network
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)


def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)
@ -0,0 +1,26 @@
import torch


_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}


def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay

    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
            params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)

    return optimizer
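build_optimizer only needs cfg.optimizer.type/lr/weight_decay (plus momentum for SGD), and since the dispatch is a substring test on the type string, only the registered keys 'adam' and 'sgd' are safe values. A stand-in config using the build_optimizer above (the repo's Config is addict-based, so a plain addict Dict mimics it):

```python
import torch.nn as nn
from addict import Dict

cfg = Dict(optimizer=dict(type='sgd', lr=0.025, weight_decay=1e-4, momentum=0.9))
net = nn.Linear(8, 2)
optimizer = build_optimizer(cfg, net)  # -> torch.optim.SGD over trainable params
```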
@ -0,0 +1,100 @@
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()
        cfg.work_dir = self.work_dir
        self.log_path = os.path.join(self.work_dir, 'log.txt')

        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)

        # scalars
        self.epoch = 0
        self.step = 0
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter
        self.lr = 0.

    def get_work_dir(self):
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = ' '.join(loss_state)

        recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)


def build_recorder(cfg):
    return Recorder(cfg)
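SmoothedValue keeps the last `window_size` values for the windowed statistics and a running total for the global average, which is what the ETA estimate in `__str__` is built on:

```python
sv = SmoothedValue(window_size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    sv.update(v)
print(sv.avg)         # 3.0 -- mean of the last 3 values
print(sv.global_avg)  # 2.5 -- mean of all 4 values (used for the ETA)
print(sv.median)      # 3.0 -- median of the window
```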
@ -0,0 +1,19 @@
import torch.nn as nn  # needed by the list branch of build()

from utils import Registry, build_from_cfg

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
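So `cfg.trainer = dict(type='RESA')` resolves to whatever class was registered under that name, constructed with `cfg=cfg` injected via default_args. The same mechanics in miniature, with a throwaway registry (assuming build_from_cfg instantiates the class with the remaining config keys plus default_args, as its docstring in utils/registry.py states):

```python
from addict import Dict
from utils import Registry, build_from_cfg

TOYS = Registry('toys')  # throwaway registry for illustration

@TOYS.register_module
class Foo(object):
    def __init__(self, cfg=None):
        self.cfg = cfg

print(TOYS)  # Registry(name=toys, items=['Foo'])
obj = build_from_cfg(Dict(type='Foo'), TOYS, default_args=dict(cfg='the-config'))
```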
@ -0,0 +1,58 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

from runner.registry import TRAINER

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1 - d).mean()

@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()

        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])

        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight

        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            exist_loss = 0.1 * \
                self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        ret = {'loss': loss, 'loss_stats': loss_stats}

        return ret
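dice_loss flattens each sample and computes 1 - 2&lt;p,t&gt; / (&lt;p,p&gt; + &lt;t,t&gt;) with a 0.001 smoothing term, so identical masks score near 0 and disjoint masks near 1. A quick sanity check with the dice_loss above:

```python
import torch

t = torch.zeros(2, 1, 4, 4)
t[..., :2] = 1.0
print(dice_loss(t, t).item())      # ~0.0: perfect overlap
print(dice_loss(1 - t, t).item())  # ~1.0: disjoint masks
```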
@ -0,0 +1,112 @@
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup

from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network


class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)
        self.net = build_net(self.cfg)
        self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids=range(self.cfg.gpus)).cuda()
        self.recorder.logger.info('Network: \n' + str(self.net))
        self.resume()
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                     finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)

    def to_cuda(self, batch):
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch

    def train_epoch(self, epoch, train_loader):
        self.net.train()
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:
                break
            data_time = time.time() - end
            self.recorder.step += 1
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)
            self.optimizer.zero_grad()
            loss = output['loss']
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)
            self.recorder.data_time.update(data_time)

            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)

        for epoch in range(self.cfg.epochs):
            print('Iter: [{}/{}]'.format(self.recorder.step, self.cfg.total_iter))
            print('Epoch: [{}/{}]'.format(epoch, self.cfg.epochs))
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        for i, data in enumerate(tqdm(val_loader, desc='Validate')):
            start_time = time.time()
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)
            # print("image {} took {} seconds to detect".format(i, time.time() - start_time))

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                   self.recorder, is_best)
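The entry script that drives Runner is not part of this excerpt, but wiring it up is just config + Runner; a minimal sketch under those assumptions (the module path and config file name are hypothetical):

```python
from utils import Config
from runner.runner import Runner  # assumed module path for the Runner class above

cfg = Config.fromfile('configs/culane.py')  # hypothetical config file
cfg.gpus = 1
cfg.load_from = None
cfg.finetune_from = None

runner = Runner(cfg)
runner.train()
```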
@ -0,0 +1,20 @@
import torch
import math


_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}


def build_scheduler(cfg, optimizer):

    assert cfg.scheduler.type in _scheduler_factory

    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')

    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)

    return scheduler
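build_scheduler pops `type` and forwards everything else to the constructor, so the config carries the LambdaLR kwargs directly; like the warmup in Runner, the multiplier is stepped per iteration. A stand-in example with a polynomial decay (values are illustrative, not the repo's defaults):

```python
import torch
from addict import Dict

net = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.025)

total_iter = 80000  # illustrative
cfg = Dict(scheduler=dict(
    type='LambdaLR',
    lr_lambda=lambda it: (1 - it / total_iter) ** 0.9,  # poly decay per step
))
scheduler = build_scheduler(cfg, optimizer)
```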
@ -0,0 +1,105 @@
import json
import numpy as np
import cv2
import os
import argparse

TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json']
VAL_SET = ['label_data_0531.json']
TRAIN_VAL_SET = TRAIN_SET + VAL_SET
TEST_SET = ['test_label.json']

def gen_label_for_json(args, image_set):
    H, W = 720, 1280
    SEG_WIDTH = 30
    save_dir = args.savedir

    os.makedirs(os.path.join(args.root, args.savedir, "list"), exist_ok=True)
    list_f = open(os.path.join(args.root, args.savedir, "list", "{}_gt.txt".format(image_set)), "w")

    json_path = os.path.join(args.root, args.savedir, "{}.json".format(image_set))
    with open(json_path) as f:
        for line in f:
            label = json.loads(line)
            # ---------- clean and sort lanes -------------
            lanes = []
            _lanes = []
            slope = []  # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope
            for i in range(len(label['lanes'])):
                l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0]
                if (len(l) > 1):
                    _lanes.append(l)
                    slope.append(np.arctan2(l[-1][1]-l[0][1], l[0][0]-l[-1][0]) / np.pi * 180)
            _lanes = [_lanes[i] for i in np.argsort(slope)]
            slope = [slope[i] for i in np.argsort(slope)]

            idx = [None for i in range(6)]
            for i in range(len(slope)):
                if slope[i] <= 90:
                    idx[2] = i
                    idx[1] = i-1 if i > 0 else None
                    idx[0] = i-2 if i > 1 else None
                else:
                    idx[3] = i
                    idx[4] = i+1 if i+1 < len(slope) else None
                    idx[5] = i+2 if i+2 < len(slope) else None
                    break
            for i in range(6):
                lanes.append([] if idx[i] is None else _lanes[idx[i]])

            # ---------------------------------------------

            img_path = label['raw_file']
            seg_img = np.zeros((H, W, 3))
            list_str = []  # str to be written to list.txt
            for i in range(len(lanes)):
                coords = lanes[i]
                if len(coords) < 4:
                    list_str.append('0')
                    continue
                for j in range(len(coords)-1):
                    cv2.line(seg_img, coords[j], coords[j+1], (i+1, i+1, i+1), SEG_WIDTH//2)
                list_str.append('1')

            seg_path = img_path.split("/")
            seg_path, img_name = os.path.join(args.root, args.savedir, seg_path[1], seg_path[2]), seg_path[3]
            os.makedirs(seg_path, exist_ok=True)
            seg_path = os.path.join(seg_path, img_name[:-3]+"png")
            cv2.imwrite(seg_path, seg_img)

            seg_path = "/".join([args.savedir, *img_path.split("/")[1:3], img_name[:-3]+"png"])
            if seg_path[0] != '/':
                seg_path = '/' + seg_path
            if img_path[0] != '/':
                img_path = '/' + img_path
            list_str.insert(0, seg_path)
            list_str.insert(0, img_path)
            list_str = " ".join(list_str) + "\n"
            list_f.write(list_str)


def generate_json_file(save_dir, json_file, image_set):
    with open(os.path.join(save_dir, json_file), "w") as outfile:
        for json_name in (image_set):
            with open(os.path.join(args.root, json_name)) as infile:
                for line in infile:
                    outfile.write(line)


def generate_label(args):
    save_dir = os.path.join(args.root, args.savedir)
    os.makedirs(save_dir, exist_ok=True)
    generate_json_file(save_dir, "train_val.json", TRAIN_VAL_SET)
    generate_json_file(save_dir, "test.json", TEST_SET)

    print("generating train_val set...")
    gen_label_for_json(args, 'train_val')
    print("generating test set...")
    gen_label_for_json(args, 'test')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', required=True, help='The root of the Tusimple dataset')
    parser.add_argument('--savedir', type=str, default='seg_label', help='The directory name under root to save the generated segmentation labels')
    args = parser.parse_args()

    generate_label(args)
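The slope list holds the chord angle of each lane, measured from its first (top) to last (bottom) point; after sorting, angles at or below 90 degrees fill the left-of-ego slots 0-2 and the rest fill the right-of-ego slots 3-5, which keeps the per-pixel class ids in the mask stable across frames. The formula in isolation:

```python
import numpy as np

# same chord-angle formula as above: dy = bottom.y - top.y, dx = top.x - bottom.x
l = [(500, 300), (300, 500)]  # top point right of bottom point -> ego-left lane
angle = np.arctan2(l[-1][1] - l[0][1], l[0][0] - l[-1][0]) / np.pi * 180
print(angle)  # 45.0 -> <= 90, so this lane is assigned to the left-side slots
```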
@ -0,0 +1,2 @@
from .config import Config
from .registry import Registry, build_from_cfg
@ -0,0 +1,417 @@
|
||||||
|
# Copyright (c) Open-MMLab. All rights reserved.
|
||||||
|
import ast
|
||||||
|
import os.path as osp
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from argparse import Action, ArgumentParser
|
||||||
|
from collections import abc
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
from addict import Dict
|
||||||
|
from yapf.yapflib.yapf_api import FormatCode
|
||||||
|
|
||||||
|
|
||||||
|
BASE_KEY = '_base_'
|
||||||
|
DELETE_KEY = '_delete_'
|
||||||
|
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
|
||||||
|
|
||||||
|
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
|
||||||
|
if not osp.isfile(filename):
|
||||||
|
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigDict(Dict):
|
||||||
|
|
||||||
|
def __missing__(self, name):
|
||||||
|
raise KeyError(name)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
try:
|
||||||
|
value = super(ConfigDict, self).__getattr__(name)
|
||||||
|
except KeyError:
|
||||||
|
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
|
||||||
|
f"attribute '{name}'")
|
||||||
|
except Exception as e:
|
||||||
|
ex = e
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
|
||||||
|
def add_args(parser, cfg, prefix=''):
|
||||||
|
for k, v in cfg.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
parser.add_argument('--' + prefix + k)
|
||||||
|
elif isinstance(v, int):
|
||||||
|
parser.add_argument('--' + prefix + k, type=int)
|
||||||
|
elif isinstance(v, float):
|
||||||
|
parser.add_argument('--' + prefix + k, type=float)
|
||||||
|
elif isinstance(v, bool):
|
||||||
|
parser.add_argument('--' + prefix + k, action='store_true')
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
add_args(parser, v, prefix + k + '.')
|
||||||
|
elif isinstance(v, abc.Iterable):
|
||||||
|
parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
|
||||||
|
else:
|
||||||
|
print(f'cannot parse key {prefix + k} of type {type(v)}')
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""A facility for config and config files.
|
||||||
|
It supports common file formats as configs: python/json/yaml. The interface
|
||||||
|
is the same as a dict object and also allows access config values as
|
||||||
|
attributes.
|
||||||
|
Example:
|
||||||
|
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||||
|
>>> cfg.a
|
||||||
|
1
|
||||||
|
>>> cfg.b
|
||||||
|
{'b1': [0, 1]}
|
||||||
|
>>> cfg.b.b1
|
||||||
|
[0, 1]
|
||||||
|
>>> cfg = Config.fromfile('tests/data/config/a.py')
|
||||||
|
>>> cfg.filename
|
||||||
|
"/home/kchen/projects/mmcv/tests/data/config/a.py"
|
||||||
|
>>> cfg.item4
|
||||||
|
'test'
|
||||||
|
>>> cfg
|
||||||
|
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
|
||||||
|
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_py_syntax(filename):
|
||||||
|
with open(filename) as f:
|
||||||
|
content = f.read()
|
||||||
|
try:
|
||||||
|
ast.parse(content)
|
||||||
|
except SyntaxError:
|
||||||
|
raise SyntaxError('There are syntax errors in config '
|
||||||
|
f'file {filename}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _file2dict(filename):
|
||||||
|
filename = osp.abspath(osp.expanduser(filename))
|
||||||
|
check_file_exist(filename)
|
||||||
|
if filename.endswith('.py'):
|
||||||
|
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||||
|
temp_config_file = tempfile.NamedTemporaryFile(
|
||||||
|
dir=temp_config_dir, suffix='.py')
|
||||||
|
temp_config_name = osp.basename(temp_config_file.name)
|
||||||
|
shutil.copyfile(filename,
|
||||||
|
osp.join(temp_config_dir, temp_config_name))
|
||||||
|
temp_module_name = osp.splitext(temp_config_name)[0]
|
||||||
|
sys.path.insert(0, temp_config_dir)
|
||||||
|
Config._validate_py_syntax(filename)
|
||||||
|
mod = import_module(temp_module_name)
|
||||||
|
sys.path.pop(0)
|
||||||
|
cfg_dict = {
|
||||||
|
name: value
|
||||||
|
for name, value in mod.__dict__.items()
|
||||||
|
if not name.startswith('__')
|
||||||
|
}
|
||||||
|
# delete imported module
|
||||||
|
del sys.modules[temp_module_name]
|
||||||
|
# close temp file
|
||||||
|
temp_config_file.close()
|
||||||
|
elif filename.endswith(('.yml', '.yaml', '.json')):
|
||||||
|
import mmcv
|
||||||
|
cfg_dict = mmcv.load(filename)
|
||||||
|
else:
|
||||||
|
raise IOError('Only py/yml/yaml/json type are supported now!')
|
||||||
|
|
||||||
|
cfg_text = filename + '\n'
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
cfg_text += f.read()
|
||||||
|
|
||||||
|
if BASE_KEY in cfg_dict:
|
||||||
|
cfg_dir = osp.dirname(filename)
|
||||||
|
base_filename = cfg_dict.pop(BASE_KEY)
|
||||||
|
base_filename = base_filename if isinstance(
|
||||||
|
base_filename, list) else [base_filename]
|
||||||
|
|
||||||
|
cfg_dict_list = list()
|
||||||
|
cfg_text_list = list()
|
||||||
|
for f in base_filename:
|
||||||
|
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
|
||||||
|
cfg_dict_list.append(_cfg_dict)
|
||||||
|
cfg_text_list.append(_cfg_text)
|
||||||
|
|
||||||
|
base_cfg_dict = dict()
|
||||||
|
for c in cfg_dict_list:
|
||||||
|
if len(base_cfg_dict.keys() & c.keys()) > 0:
|
||||||
|
raise KeyError('Duplicate key is not allowed among bases')
|
||||||
|
base_cfg_dict.update(c)
|
||||||
|
|
||||||
|
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
|
||||||
|
cfg_dict = base_cfg_dict
|
||||||
|
|
||||||
|
# merge cfg_text
|
||||||
|
cfg_text_list.append(cfg_text)
|
||||||
|
cfg_text = '\n'.join(cfg_text_list)
|
||||||
|
|
||||||
|
return cfg_dict, cfg_text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_a_into_b(a, b):
|
||||||
|
# merge dict `a` into dict `b` (non-inplace). values in `a` will
|
||||||
|
# overwrite `b`.
|
||||||
|
# copy first to avoid inplace modification
|
||||||
|
b = b.copy()
|
||||||
|
for k, v in a.items():
|
||||||
|
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
|
||||||
|
if not isinstance(b[k], dict):
|
||||||
|
raise TypeError(
|
||||||
|
f'{k}={v} in child config cannot inherit from base '
|
||||||
|
f'because {k} is a dict in the child config but is of '
|
||||||
|
f'type {type(b[k])} in base config. You may set '
|
||||||
|
f'`{DELETE_KEY}=True` to ignore the base config')
|
||||||
|
b[k] = Config._merge_a_into_b(v, b[k])
|
||||||
|
else:
|
||||||
|
b[k] = v
|
||||||
|
return b
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fromfile(filename):
|
||||||
|
cfg_dict, cfg_text = Config._file2dict(filename)
|
||||||
|
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auto_argparser(description=None):
|
||||||
|
"""Generate argparser from config file automatically (experimental)
|
||||||
|
"""
|
||||||
|
partial_parser = ArgumentParser(description=description)
|
||||||
|
partial_parser.add_argument('config', help='config file path')
|
||||||
|
cfg_file = partial_parser.parse_known_args()[0].config
|
||||||
|
cfg = Config.fromfile(cfg_file)
|
||||||
|
parser = ArgumentParser(description=description)
|
||||||
|
parser.add_argument('config', help='config file path')
|
||||||
|
add_args(parser, cfg)
|
||||||
|
return parser, cfg
|
||||||
|
|
||||||
|
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
|
||||||
|
if cfg_dict is None:
|
||||||
|
cfg_dict = dict()
|
||||||
|
elif not isinstance(cfg_dict, dict):
|
||||||
|
raise TypeError('cfg_dict must be a dict, but '
|
||||||
|
f'got {type(cfg_dict)}')
|
||||||
|
for key in cfg_dict:
|
||||||
|
if key in RESERVED_KEYS:
|
||||||
|
raise KeyError(f'{key} is reserved for config file')
|
||||||
|
|
||||||
|
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
||||||
|
super(Config, self).__setattr__('_filename', filename)
|
||||||
|
if cfg_text:
|
||||||
|
text = cfg_text
|
||||||
|
elif filename:
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
text = f.read()
|
||||||
|
else:
|
||||||
|
text = ''
|
||||||
|
super(Config, self).__setattr__('_text', text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filename(self):
|
||||||
|
return self._filename
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return self._text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pretty_text(self):
|
||||||
|
|
||||||
|
indent = 4
|
||||||
|
|
||||||
|
def _indent(s_, num_spaces):
|
||||||
|
s = s_.split('\n')
|
||||||
|
if len(s) == 1:
|
||||||
|
return s_
|
||||||
|
first = s.pop(0)
|
||||||
|
s = [(num_spaces * ' ') + line for line in s]
|
||||||
|
s = '\n'.join(s)
|
||||||
|
s = first + '\n' + s
|
||||||
|
return s
|
||||||
|
|
||||||
|
def _format_basic_types(k, v, use_mapping=False):
|
||||||
|
if isinstance(v, str):
|
||||||
|
v_str = f"'{v}'"
|
||||||
|
else:
|
||||||
|
v_str = str(v)
|
||||||
|
|
||||||
|
if use_mapping:
|
||||||
|
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||||
|
attr_str = f'{k_str}: {v_str}'
|
||||||
|
else:
|
||||||
|
attr_str = f'{str(k)}={v_str}'
|
||||||
|
attr_str = _indent(attr_str, indent)
|
||||||
|
|
||||||
|
return attr_str
|
||||||
|
|
||||||
|
def _format_list(k, v, use_mapping=False):
|
||||||
|
# check if all items in the list are dict
|
||||||
|
if all(isinstance(_, dict) for _ in v):
|
||||||
|
v_str = '[\n'
|
||||||
|
v_str += '\n'.join(
|
||||||
|
f'dict({_indent(_format_dict(v_), indent)}),'
|
||||||
|
for v_ in v).rstrip(',')
|
||||||
|
if use_mapping:
|
||||||
|
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||||
|
attr_str = f'{k_str}: {v_str}'
|
||||||
|
else:
|
||||||
|
attr_str = f'{str(k)}={v_str}'
|
||||||
|
attr_str = _indent(attr_str, indent) + ']'
|
||||||
|
else:
|
||||||
|
attr_str = _format_basic_types(k, v, use_mapping)
|
||||||
|
return attr_str
|
||||||
|
|
||||||
|
def _contain_invalid_identifier(dict_str):
|
||||||
|
contain_invalid_identifier = False
|
||||||
|
for key_name in dict_str:
|
||||||
|
contain_invalid_identifier |= \
|
||||||
|
(not str(key_name).isidentifier())
|
||||||
|
return contain_invalid_identifier
|
||||||
|
|
||||||
|
def _format_dict(input_dict, outest_level=False):
|
||||||
|
r = ''
|
||||||
|
s = []
|
||||||
|
|
||||||
|
use_mapping = _contain_invalid_identifier(input_dict)
|
||||||
|
if use_mapping:
|
||||||
|
r += '{'
|
||||||
|
for idx, (k, v) in enumerate(input_dict.items()):
|
||||||
|
is_last = idx >= len(input_dict) - 1
|
||||||
|
end = '' if outest_level or is_last else ','
|
||||||
|
if isinstance(v, dict):
|
||||||
|
v_str = '\n' + _format_dict(v)
|
||||||
|
if use_mapping:
|
||||||
|
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||||
|
attr_str = f'{k_str}: dict({v_str}'
|
||||||
|
else:
|
||||||
|
attr_str = f'{str(k)}=dict({v_str}'
|
||||||
|
attr_str = _indent(attr_str, indent) + ')' + end
|
||||||
|
elif isinstance(v, list):
|
||||||
|
attr_str = _format_list(k, v, use_mapping) + end
|
||||||
|
else:
|
||||||
|
attr_str = _format_basic_types(k, v, use_mapping) + end
|
||||||
|
|
||||||
|
s.append(attr_str)
|
||||||
|
r += '\n'.join(s)
|
||||||
|
if use_mapping:
|
||||||
|
r += '}'
|
||||||
|
return r
|
||||||
|
|
||||||
|
cfg_dict = self._cfg_dict.to_dict()
|
||||||
|
text = _format_dict(cfg_dict, outest_level=True)
|
||||||
|
# copied from setup.cfg
|
||||||
|
yapf_style = dict(
|
||||||
|
based_on_style='pep8',
|
||||||
|
blank_line_before_nested_class_or_def=True,
|
||||||
|
split_before_expression_after_opening_paren=True)
|
||||||
|
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._cfg_dict)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._cfg_dict, name)
|
||||||
|
|
||||||
|
def __getitem__(self, name):
|
||||||
|
return self._cfg_dict.__getitem__(name)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = ConfigDict(value)
|
||||||
|
self._cfg_dict.__setattr__(name, value)
|
||||||
|
|
||||||
|
def __setitem__(self, name, value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = ConfigDict(value)
|
||||||
|
self._cfg_dict.__setitem__(name, value)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._cfg_dict)
|
||||||
|
|
||||||
|
def dump(self, file=None):
|
||||||
|
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
|
||||||
|
if self.filename.endswith('.py'):
|
||||||
|
if file is None:
|
||||||
|
return self.pretty_text
|
||||||
|
else:
|
||||||
|
with open(file, 'w') as f:
|
||||||
|
f.write(self.pretty_text)
|
||||||
|
else:
|
||||||
|
import mmcv
|
||||||
|
if file is None:
|
||||||
|
file_format = self.filename.split('.')[-1]
|
||||||
|
return mmcv.dump(cfg_dict, file_format=file_format)
|
||||||
|
else:
|
||||||
|
mmcv.dump(cfg_dict, file)
|
||||||
|
|
||||||
|
    def merge_from_dict(self, options):
        """Merge dict into cfg_dict.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))


class DictAction(Action):
    """Argparse action that splits each argument on the first '=' into a
    KEY=VALUE pair and appends it to a dictionary. List options should be
    passed as comma-separated values, e.g. KEY=V1,V2,V3.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
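

# A minimal usage sketch of DictAction (illustrative only; the `--options`
# flag name is an assumption, not part of this module):
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--options', nargs='+', action=DictAction)
    args = parser.parse_args(['--options', 'model.backbone.depth=50',
                              'gpus=0,1'])
    # values are parsed to int/float/bool where possible; comma-separated
    # values become lists
    assert args.options == {'model.backbone.depth': 50, 'gpus': [0, 1]}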
@ -0,0 +1,81 @@
import inspect

import six

# borrow from mmdetection


def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module_class (type): Class to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = {}
    # note: `cfg.type` is attribute access, so `cfg` must support it (e.g. a
    # ConfigDict); only `default_args` are forwarded to the constructor, the
    # remaining cfg fields are not
    obj_type = cfg.type
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
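

# A minimal usage sketch (illustrative only): `NETS`, `TinyNet`, and `_Cfg`
# are hypothetical names, not part of this module. `_Cfg` stands in for the
# attribute-style dict (e.g. a ConfigDict) that build_from_cfg expects.
if __name__ == '__main__':
    NETS = Registry('nets')

    @NETS.register_module
    class TinyNet(object):
        def __init__(self, cfg=None):
            self.cfg = cfg

    class _Cfg(dict):
        __getattr__ = dict.__getitem__  # attribute access for cfg.type

    net = build_from_cfg(_Cfg(type='TinyNet'), NETS,
                         default_args=dict(cfg=None))
    assert isinstance(net, TinyNet)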
@ -0,0 +1,357 @@
import random
import cv2
import numpy as np
import numbers
import collections.abc

# copy from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py

__all__ = ['GroupRandomCrop', 'GroupCenterCrop', 'GroupRandomPad', 'GroupCenterPad',
           'GroupRandomScale', 'GroupRandomHorizontalFlip', 'GroupNormalize']


class SampleResize(object):
    def __init__(self, size):
        assert (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size

    def __call__(self, sample):
        out = list()
        out.append(cv2.resize(sample[0], self.size,
                              interpolation=cv2.INTER_CUBIC))
        if len(sample) > 1:
            out.append(cv2.resize(sample[1], self.size,
                                  interpolation=cv2.INTER_NEAREST))
        return out


class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomCropRatio(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        tw, th = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupCenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((h - th) / 2))
        w1 = max(0, int((w - tw) / 2))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, th - h))
        w1 = random.randint(0, max(0, tw - w))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupCenterPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((th - h) / 2))
        w1 = max(0, int((tw - w) / 2))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupConcerPad(object):
    # pads only at the bottom/right, i.e. the image stays anchored at the
    # top-left corner
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = 0
        w1 = 0
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleNew(object):
    def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        # scale factors are computed against the CULane native resolution (1640x590)
        scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale_w,
                                         fy=scale_h, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale = random.uniform(self.size[0], self.size[1])
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale,
                                         fy=scale, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomMultiScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scales = [0.5, 1.0, 1.5]  # random.uniform(self.size[0], self.size[1])
        out_images = list()
        for scale in scales:
            for img, interpolation in zip(img_group, self.interpolation):
                out_images.append(cv2.resize(
                    img, None, fx=scale, fy=scale, interpolation=interpolation))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleRatio(object):
    def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation
        self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784,
                          1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832, 1122, 1224, 2295]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        w_scale = random.randint(self.size[0], self.size[1])
        h_scale = random.randint(self.size[2], self.size[3])
        h, w, _ = img_group[0].shape
        out_images = list()
        out_images.append(cv2.resize(img_group[0], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h,
                                     interpolation=self.interpolation[0]))
        ### process label map ###
        origin_label = cv2.resize(
            img_group[1], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h, interpolation=self.interpolation[1])
        origin_label = origin_label.astype(int)
        # hash each (R, G, B) annotation color into a single integer
        label = origin_label[:, :, 0] * 5 + \
            origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
        # remap each hashed color listed in origin_id to its index (the last
        # entry, origin_id[37], also maps to id 36); the sentinel 100 flags
        # unmapped colors and the assert below guarantees none remain
        new_label = np.ones(label.shape) * 100
        new_label = new_label.astype(int)
        for cnt in range(37):
            new_label = (
                label == self.origin_id[cnt]) * (cnt - 100) + new_label
        new_label = (label == self.origin_id[37]) * (36 - 100) + new_label
        assert (100 not in np.unique(new_label))
        out_images.append(new_label)
        return out_images
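

# A minimal worked example of the color-hash remapping above; the hashed
# values and id table are illustrative only, not the project's palette.
def _remap_sketch():
    toy_label = np.array([[0, 1360], [580, 0]])  # hashed R*5 + G*3 + B values
    toy_ids = [0, 1360, 580]
    remapped = np.ones(toy_label.shape, dtype=int) * 100
    for cnt in range(len(toy_ids)):
        remapped = (toy_label == toy_ids[cnt]) * (cnt - 100) + remapped
    assert 100 not in np.unique(remapped)  # every color was mapped
    return remapped  # [[0, 1], [2, 0]]

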
class GroupRandomRotation(object):
    def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None):
        self.degree = degree
        self.interpolation = interpolation
        self.padding = padding
        if self.padding is None:
            self.padding = [0, 0]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        v = random.random()
        if v < 0.5:
            degree = random.uniform(self.degree[0], self.degree[1])
            h, w = img_group[0].shape[0:2]
            center = (w / 2, h / 2)
            map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
            out_images = list()
            for img, interpolation, padding in zip(img_group, self.interpolation, self.padding):
                out_images.append(cv2.warpAffine(
                    img, map_matrix, (w, h), flags=interpolation, borderMode=cv2.BORDER_CONSTANT, borderValue=padding))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


class GroupRandomBlur(object):
    def __init__(self, applied):
        self.applied = applied

    def __call__(self, img_group):
        assert (len(self.applied) == len(img_group))
        v = random.random()
        if v < 0.5:
            out_images = []
            for img, a in zip(img_group, self.applied):
                if a:
                    img = cv2.GaussianBlur(
                        img, (5, 5), random.uniform(1e-6, 0.6))
                out_images.append(img)
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given group of numpy images with a
    probability of 0.5.
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            out_images = [np.fliplr(img) for img in img_group]
            if self.is_flow:
                for i in range(0, len(out_images), 2):
                    # invert flow pixel values when flipping
                    out_images[i] = -out_images[i]
            return out_images
        else:
            return img_group


class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_group):
        out_images = list()
        for img, m, s in zip(img_group, self.mean, self.std):
            if len(m) == 1:
                img = img - np.array(m)  # single channel image
                img = img / np.array(s)
            else:
                img = img - np.array(m)[np.newaxis, np.newaxis, ...]
                img = img / np.array(s)[np.newaxis, np.newaxis, ...]
            out_images.append(img)

        return out_images
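

# A minimal usage sketch of chaining these group transforms on an
# (image, label) pair; the shapes and normalization values below are
# placeholders, not the project's settings.
if __name__ == '__main__':
    img = np.zeros((590, 1640, 3), dtype=np.uint8)
    label = np.zeros((590, 1640), dtype=np.uint8)
    sample = [img, label]
    for t in [GroupRandomScale(size=(0.5, 1.5)),
              GroupRandomHorizontalFlip(),
              GroupNormalize(mean=((103.939, 116.779, 123.68), (0, )),
                             std=((1., 1., 1.), (1., )))]:
        sample = t(sample)
    # sample[0] is the augmented image, sample[1] the matching label map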