commit c8504c8e34
first commit
@@ -0,0 +1,22 @@
work_dirs/
predicts/
output/
data/
data

__pycache__/
*/*.un~
.*.swp



*.egg-info/
*.egg

output.txt
.vscode/*
.DS_Store
tmp.*
*.pt
*.pth
*.un~
@@ -0,0 +1,74 @@

# Install

1. Clone the RESA repository
```
git clone https://github.com/zjulearning/resa.git
```
We will refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

```Shell
conda create -n resa python=3.8 -y
conda activate resa
```

3. Install dependencies

```Shell
# Install PyTorch first; the cudatoolkit version should match the CUDA version on your system.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Or you can install via pip
pip install torch torchvision

# Install python packages
pip install -r requirements.txt
```

4. Data preparation

Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to the `data` directory.

```Shell
cd $RESA_ROOT
ln -s $CULANEROOT data/CULane
ln -s $TUSIMPLEROOT data/tusimple
```

For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the json annotation.

```Shell
python scripts/convert_tusimple.py --root $TUSIMPLEROOT
# this will generate segmentations and two list files: train_gt.txt and test.txt
```

For CULane, you should have a structure like this:
```
$RESA_ROOT/data/CULane/driver_xx_xxframe    # data folders x6
$RESA_ROOT/data/CULane/laneseg_label_w16    # lane segmentation labels
$RESA_ROOT/data/CULane/list                 # data lists
```

For Tusimple, you should have a structure like this:
```
$RESA_ROOT/data/tusimple/clips                   # data folders
$RESA_ROOT/data/tusimple/label_data_xxxx.json    # label json files x4
$RESA_ROOT/data/tusimple/test_tasks_0627.json    # test tasks json file
$RESA_ROOT/data/tusimple/test_label.json         # test label json file
```

5. Install CULane evaluation tools.

This tool requires OpenCV C++. Please follow [here](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or just install OpenCV with the command `sudo apt-get install libopencv-dev`.

Then compile the evaluation tool of CULane.
```Shell
cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
make
cd -
```

Note that the default `opencv` version is 3. If you use opencv2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 Tu Zheng

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -0,0 +1,148 @@
# RESA

PyTorch implementation of the paper "[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719)".

Our paper has been accepted by AAAI 2021.

**News**: We also release RESA on [LaneDet](https://github.com/Turoad/lanedet). It is also recommended for you to try LaneDet.

## Introduction
![intro](intro.png "intro")
- RESA shifts sliced feature maps recurrently in vertical and horizontal directions, enabling each pixel to gather global information (see the sketch below).
- RESA achieves SOTA results on the CULane and Tusimple datasets.
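To make the shift-and-aggregate idea concrete, here is a minimal, illustrative PyTorch sketch of a single (downward) pass. The class name and details are a simplification for exposition; the repo's actual module (configured by the `resa` dict in the configs via `iter`, `alpha`, `input_channel`, and `conv_stride`) also runs upward, leftward, and rightward passes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownwardShift(nn.Module):
    """Sketch of one RESA direction: at iteration k, every row receives
    information from the row H // 2**(iters - k) below it (wrapping around),
    passed through a 1x9 convolution and a ReLU, so information from
    progressively farther rows is aggregated into each pixel."""

    def __init__(self, channels=128, kernel_width=9, iters=4, alpha=2.0):
        super().__init__()
        self.iters, self.alpha = iters, alpha
        self.convs = nn.ModuleList([
            nn.Conv2d(channels, channels, (1, kernel_width),
                      padding=(0, kernel_width // 2), bias=False)
            for _ in range(iters)
        ])

    def forward(self, x):                       # x: (B, C, H, W)
        H = x.shape[2]
        for k in range(self.iters):
            # vertically roll the feature map, convolve each row slice,
            # and add the shifted information back in
            idx = (torch.arange(H, device=x.device)
                   + H // 2 ** (self.iters - k)) % H
            x = x + self.alpha * F.relu(self.convs[k](x[:, :, idx, :]))
        return x

feat = torch.randn(1, 128, 36, 100)             # e.g. a stride-8 feature map
print(DownwardShift()(feat).shape)              # torch.Size([1, 128, 36, 100])
```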

## Get started
1. Clone the RESA repository
```
git clone https://github.com/zjulearning/resa.git
```
We will refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

```Shell
conda create -n resa python=3.8 -y
conda activate resa
```

3. Install dependencies

```Shell
# Install PyTorch first; the cudatoolkit version should match the CUDA version on your system.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Or you can install via pip
pip install torch torchvision

# Install python packages
pip install -r requirements.txt
```

4. Data preparation

Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to the `data` directory.

```Shell
cd $RESA_ROOT
mkdir -p data
ln -s $CULANEROOT data/CULane
ln -s $TUSIMPLEROOT data/tusimple
```

For CULane, you should have a structure like this:
```
$CULANEROOT/driver_xx_xxframe    # data folders x6
$CULANEROOT/laneseg_label_w16    # lane segmentation labels
$CULANEROOT/list                 # data lists
```

For Tusimple, you should have a structure like this:
```
$TUSIMPLEROOT/clips                   # data folders
$TUSIMPLEROOT/label_data_xxxx.json    # label json files x4
$TUSIMPLEROOT/test_tasks_0627.json    # test tasks json file
$TUSIMPLEROOT/test_label.json         # test label json file
```

For Tusimple, the segmentation annotation is not provided, so we need to generate segmentation labels from the json annotation.

```Shell
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this will generate a seg_label directory
```

5. Install CULane evaluation tools.

This tool requires OpenCV C++. Please follow [here](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or just install OpenCV with the command `sudo apt-get install libopencv-dev`.

Then compile the evaluation tool of CULane.
```Shell
cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
make
cd -
```

Note that the default `opencv` version is 3. If you use opencv2, please change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
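For instance, a hypothetical one-liner for that edit (assuming GNU sed, run from the `lane_evaluation` directory):

```Shell
sed -i 's/OPENCV_VERSION := 3/OPENCV_VERSION := 2/' Makefile
```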

## Training

For training, run

```Shell
python main.py [configs/path_to_your_config] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --gpus 0 1 2 3
```

## Testing
For testing, run
```Shell
python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3

python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
```

We provide two trained ResNet models on CULane and Tusimple. You can download our best-performing models here (Tusimple: [GoogleDrive](https://drive.google.com/file/d/1M1xi82y0RoWUwYYG9LmZHXWSD2D60o0D/view?usp=sharing)/[BaiduDrive(code:s5ii)](https://pan.baidu.com/s/1CgJFrt9OHe-RUNooPpHRGA), CULane: [GoogleDrive](https://drive.google.com/file/d/1pcqq9lpJ4ixJgFVFndlPe42VgVsjgn0Q/view?usp=sharing)/[BaiduDrive(code:rlwj)](https://pan.baidu.com/s/1ODKAZxpKrZIPXyaNnxcV3g)).

## Visualization
Just add `--view`.

For example:
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3 --view
```
You will get the results in the directory `work_dirs/[DATASET]/xxx/vis`.

## Citation
If you use our method, please consider citing:
```BibTeX
@inproceedings{zheng2021resa,
  title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
  author={Zheng, Tu and Fang, Hao and Zhang, Yi and Tang, Wenjian and Yang, Zheng and Liu, Haifeng and Cai, Deng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3547--3554},
  year={2021}
}
```

<!-- ## Thanks

The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark). -->
@@ -0,0 +1,88 @@
import math

net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet50',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 12
batch_size = 8
# 88880 training images in CULane's train_gt.txt
total_iter = (88880 // batch_size) * epochs
scheduler = dict(
    type='LambdaLR',
    # polynomial ("poly") decay: lr is scaled by (1 - iter/total_iter)^0.9
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 6
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500
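For context, here is a minimal, hypothetical sketch of how the `optimizer` and `scheduler` dictionaries above could be turned into PyTorch objects; `build_optimizer` and `build_scheduler` are illustrative names, not this repo's actual API:

```python
import math
import torch

def build_optimizer(opt_cfg, model):
    # opt_cfg is the `optimizer` dict from the config above (illustrative)
    params = dict(opt_cfg)
    assert params.pop('type') == 'sgd'
    return torch.optim.SGD(model.parameters(), **params)

def build_scheduler(sched_cfg, optimizer):
    # sched_cfg is the `scheduler` dict; LambdaLR multiplies the base lr by
    # lr_lambda(step), so scheduler.step() would be called once per iteration
    params = dict(sched_cfg)
    assert params.pop('type') == 'LambdaLR'
    return torch.optim.lr_scheduler.LambdaLR(optimizer, **params)

model = torch.nn.Linear(8, 2)   # stand-in model for the sketch
opt = build_optimizer(dict(type='sgd', lr=0.025, weight_decay=1e-4, momentum=0.9), model)
# total_iter = (88880 // 8) * 12 = 133320 with the values above
sched = build_scheduler(dict(type='LambdaLR',
                             lr_lambda=lambda it: math.pow(1 - it / 133320, 0.9)), opt)
```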
@@ -0,0 +1,97 @@
import math

net = dict(
    type='RESANet',
)

# backbone = dict(
#     type='ResNetWrapper',
#     resnet='resnet50',
#     pretrained=True,
#     replace_stride_with_dilation=[False, True, True],
#     out_conv=True,
#     fea_stride=8,
# )

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, False, False],
    out_conv=False,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 20
batch_size = 8
total_iter = (88880 // batch_size) * epochs
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 1
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500
@@ -0,0 +1,93 @@
import math

net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=5,
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',
    thresh=0.60
)

optimizer = dict(
    type='sgd',
    lr=0.020,
    weight_decay=1e-4,
    momentum=0.9
)

total_iter = 181400
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)

loss_type = 'cross_entropy'
seg_loss_weight = 1.0

batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''
@@ -0,0 +1,4 @@
from .registry import build_dataset, build_dataloader

from .tusimple import TuSimple
from .culane import CULane
@@ -0,0 +1,86 @@
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)

        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    def transform_train(self):
        raise NotImplementedError()

    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        # draw each lane's points onto the image
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)

    def init(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        # crop away the top of the image (sky region) before resizing
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
@@ -0,0 +1,72 @@
import os
import os.path as osp
import numpy as np
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS
import cv2
import torch


@DATASETS.register_module
class CULane(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, cfg=cfg)
        self.ori_imgh = 590
        self.ori_imgw = 1640

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                # each line: img_path [label_path exist_flag x4]
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5])]))

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(degree=(-2, 2)),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def probmap2lane(self, probmaps, exists, pts=18):
        coords = []
        probmaps = probmaps[1:, ...]  # drop the background channel
        exists = exists > 0.5
        for probmap, exist in zip(probmaps, exists):
            if exist == 0:
                continue
            probmap = cv2.blur(probmap, (9, 9), borderType=cv2.BORDER_REPLICATE)
            thr = 0.3
            coordinate = np.zeros(pts)
            cut_height = self.cfg.cut_height
            for i in range(pts):
                # sample rows every 20 px at original scale, mapped into the resized map
                line = probmap[round(
                    self.cfg.img_height-i*20/(self.ori_imgh-cut_height)*self.cfg.img_height)-1]

                if np.max(line) > thr:
                    coordinate[i] = np.argmax(line)+1
            if np.sum(coordinate > 0) < 2:
                continue

            img_coord = np.zeros((pts, 2))
            img_coord[:, :] = -1
            for idx, value in enumerate(coordinate):
                if value > 0:
                    # map column index back to the original 1640x590 resolution
                    img_coord[idx][0] = round(value*self.ori_imgw/self.cfg.img_width-1)
                    img_coord[idx][1] = round(self.ori_imgh-idx*20-1)

            img_coord = img_coord.astype(int)
            coords.append(img_coord)

        return coords
@@ -0,0 +1,36 @@
from utils import Registry, build_from_cfg

import torch
import torch.nn as nn  # needed below for nn.Sequential

DATASETS = Registry('datasets')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    shuffle = is_train  # shuffle only during training

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle,
        num_workers=cfg.workers, pin_memory=False, drop_last=False)

    return data_loader
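As a usage sketch, mirroring how `main.py` below calls `build_dataloader` (and assuming a config file like `configs/culane.py`):

```python
from utils.config import Config
from datasets import build_dataloader

cfg = Config.fromfile('configs/culane.py')
train_loader = build_dataloader(cfg.dataset.train, cfg, is_train=True)
batch = next(iter(train_loader))
# with the CULane config: img (B, 3, 288, 800), label (B, 288, 800)
print(batch['img'].shape, batch['label'].shape)
```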
@@ -0,0 +1,150 @@
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    def transform_train(self):
        input_mean = self.cfg.img_norm['mean']
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                # each line: img_path [label_path exist_flag x6]
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5]),
                              int(line_split[6]), int(line_split[7])
                              ]))

    def fix_gap(self, coordinate):
        # linearly interpolate missing (negative) x values inside a lane
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate

    def is_short(self, lane):
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        resize_shape: reshape size target, (H, W)

        Return:
        ----------
        coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height

        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)

        return coords

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred: np.array size (5, h, w)
        resize_shape: reshape size target, (H, W)
        exist: list of existence, e.g. [0, 1, 1, 0]
        smooth: whether to smooth the probability or not
        y_px_gap: y pixel gap for sampling
        pts: how many points for one lane
        thresh: probability threshold

        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]], [[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []

        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        return coordinates
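With the defaults above (`pts=56`, `y_px_gap=10`, `H=720`), the sampled rows run from y=710 down to y=160, which matches TuSimple's `h_samples`. A hypothetical sketch of packing one image's `probmap2lane` output into a TuSimple-style prediction record (the field names follow the public benchmark, not anything defined in this file):

```python
def to_tusimple_record(coordinates, raw_file):
    """Hypothetical helper: convert probmap2lane output, where x = -1 marks a
    missing point, to the benchmark's format, where -2 marks a missing point."""
    h_samples = list(range(160, 720, 10))        # y = 160, 170, ..., 710
    lanes = []
    for lane in coordinates:
        xs = {int(y): int(x) for x, y in lane}   # y -> x for this lane
        lanes.append([xs[y] if xs.get(y, -1) > 0 else -2 for y in h_samples])
    return {"raw_file": raw_file, "lanes": lanes, "h_samples": h_samples}
```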
@@ -0,0 +1,73 @@
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader


def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)

    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default=None,
        help='the checkpoint file to resume from')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    # default must be a list, since nargs='+' yields a list of ints
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()
@@ -0,0 +1 @@
from .resa import *
@@ -0,0 +1,129 @@
from torch import nn
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate branch
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate


class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = self.output_conv(output)

        return output
@@ -0,0 +1,135 @@
from torch import nn
import torch.nn.functional as F
import torch

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate branch
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate


class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(32, num_classes)

    def forward(self, input):
        # input is a pair: a skip feature and the main feature map
        x = input[0]
        output = input[1]

        for i, layer in enumerate(self.layers):
            output = layer(output)
            if i == 0:
                # fuse the skip feature after the first upsampling stage
                output = torch.cat((x, output), dim=1)

        output = self.output_conv(output)

        return output
@@ -0,0 +1,143 @@
from torch import nn
import torch
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate branch
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate


class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))
        self.out1 = conv1x1(128, 64)
        self.out2 = conv1x1(64, 32)
        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        # input is a triple: two skip features plus the main feature map
        out1 = input[0]
        out2 = input[1]
        output = input[2]

        for i, layer in enumerate(self.layers):
            if i == 0:
                output = layer(output)
                output = torch.cat((out2, output), dim=1)
                output = self.out1(output)
            elif i == 1:
                output = layer(output)
                output = torch.cat((out1, output), dim=1)
                output = self.out2(output)
            else:
                output = layer(output)

        output = self.output_conv(output)

        return output
@ -0,0 +1,422 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
from torchvision.transforms._presets import ImageClassification
|
||||
from torchvision.utils import _log_api_usage_once
|
||||
from torchvision.models._api import Weights, WeightsEnum
|
||||
from torchvision.models._meta import _IMAGENET_CATEGORIES
|
||||
from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar
|
||||
import collections
|
||||
from itertools import repeat
|
||||
M = TypeVar("M", bound=nn.Module)
|
||||
|
||||
BUILTIN_MODELS = {}
|
||||
def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
|
||||
def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
|
||||
key = name if name is not None else fn.__name__
|
||||
if key in BUILTIN_MODELS:
|
||||
raise ValueError(f"An entry is already registered under the name '{key}'.")
|
||||
BUILTIN_MODELS[key] = fn
|
||||
return fn
|
||||
|
||||
return wrapper
|
||||
|
||||
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
|
||||
"""
|
||||
Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
|
||||
Otherwise, we will make a tuple of length n, all with value of x.
|
||||
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
|
||||
|
||||
Args:
|
||||
x (Any): input value
|
||||
n (int): length of the resulting tuple
|
||||
"""
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]


# necessary for backwards compatibility
class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): The dropout probability
        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 1],  # ** stride changed from 2 to 1 relative to the torchvision default
                [6, 96, 3, 1],
                [6, 160, 3, 1],  # ** stride changed from 2 to 1 relative to the torchvision default
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be a non-empty list of 4-element lists, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            Conv2dNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        # x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
_COMMON_META = {
    "num_params": 3504872,
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class MobileNet_V2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.555,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.598,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


# @register_model()
# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
    *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V2_Weights` below for
            more details, and possible values. By default, the ``IMAGENET1K_V1``
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
    """
    weights = MobileNet_V2_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV2(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class MobileNetv2Wrapper(nn.Module):
    def __init__(self):
        super(MobileNetv2Wrapper, self).__init__()
        weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)

        self.model = MobileNetV2()

        if weights is not None:
            self.model.load_state_dict(weights.get_state_dict(progress=True))
        # project the 1280-channel backbone output down to 128 channels
        self.out = conv1x1(1280, 128)

    def forward(self, x):
        # print(x.shape)
        x = self.model(x)
        # print(x.shape)
        if self.out:
            x = self.out(x)
        # print(x.shape)
        return x
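A quick smoke test for this wrapper, as a hedged sketch: with the modified `inverted_residual_setting` above (two stage strides set to 1), the backbone's overall stride works out to 8, and the 1x1 projection leaves a single 128-channel map. The 288x800 input size is an illustrative assumption, not a value taken from this file.

```python
import torch

# Hedged smoke test: the input size and printed shapes are illustrative
# assumptions based on the stride-8 setting above.
backbone = MobileNetv2Wrapper()   # downloads ImageNet weights on first use
backbone.eval()
with torch.no_grad():
    fea = backbone(torch.randn(1, 3, 288, 800))
print(fea.shape)  # expected: torch.Size([1, 128, 36, 100]) at stride 8
```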
@ -0,0 +1,436 @@
import collections
import warnings
from functools import partial
from itertools import repeat
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
from torch import nn, Tensor

from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface

M = TypeVar("M", bound=nn.Module)

BUILTIN_MODELS = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        key = name if name is not None else fn.__name__
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make an n-tuple from input x. If x is an iterable, convert it to a tuple;
    otherwise, build a tuple of length n with every element equal to x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]


# necessary for backwards compatibility
class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): The dropout probability
        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 1],  # stride changed from 2 to 1 relative to the torchvision default
                [6, 32, 3, 1],  # stride changed from 2 to 1 relative to the torchvision default
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be a non-empty list of 4-element lists, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            Conv2dNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # self.layer1 = nn.Sequential(*features[:])
        # self.layer2 = features[57:120]
        # self.layer3 = features[120:]

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> List[Tensor]:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        out_layers = []
        # run the feature stack layer by layer and collect the intermediate
        # maps at indices 0, 10 and 18 as skip connections for the decoder
        for i, layer in enumerate(self.features):
            x = layer(x)
            # print("layer {}, output size {}".format(i, x.shape))
            if i in [0, 10, 18]:
                out_layers.append(x)
        # x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        # x = self.classifier(x)
        return out_layers

    def forward(self, x: Tensor) -> List[Tensor]:
        return self._forward_impl(x)
_COMMON_META = {
    "num_params": 3504872,
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class MobileNet_V2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.555,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.598,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


# @register_model()
# @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
    *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V2_Weights` below for
            more details, and possible values. By default, the ``IMAGENET1K_V1``
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
    """
    weights = MobileNet_V2_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV2(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class MobileNetv2Wrapper(nn.Module):
    def __init__(self):
        super(MobileNetv2Wrapper, self).__init__()
        weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1)

        self.model = MobileNetV2()

        if weights is not None:
            self.model.load_state_dict(weights.get_state_dict(progress=True))
        # project the deepest (1280-channel) feature map down to 128 channels
        self.out3 = conv1x1(1280, 128)

    def forward(self, x):
        # print(x.shape)
        out_layers = self.model(x)

        # out_layers[0] = self.out1(out_layers[0])
        # out_layers[1] = self.out2(out_layers[1])
        out_layers[2] = self.out3(out_layers[2])
        return out_layers
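A hedged sketch of what this copy2 wrapper returns: three feature maps collected at `features` indices 0, 10 and 18 (strides 2, 4 and 8 with the setting above, the deepest projected to 128 channels by `out3`). The 288x800 input size is an illustrative assumption.

```python
import torch

# Hedged sketch: input size and shapes are illustrative assumptions derived
# from the stride arithmetic of the modified inverted_residual_setting.
backbone = MobileNetv2Wrapper()
backbone.eval()
with torch.no_grad():
    fea1, fea2, fea = backbone(torch.randn(1, 3, 288, 800))
print(fea1.shape)  # torch.Size([1, 32, 144, 400]),  stride 2
print(fea2.shape)  # torch.Size([1, 64, 72, 200]),   stride 4
print(fea.shape)   # torch.Size([1, 128, 36, 100]),  stride 8
```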
@ -0,0 +1,16 @@
import torch.nn as nn

from utils import Registry, build_from_cfg

NET = Registry('net')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))
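`Registry` and `build_from_cfg` come from this repo's `utils` package; assuming the usual mmcv-style semantics, the pattern reduces to a name-to-class map plus a `type`-keyed constructor call. A minimal self-contained sketch of that pattern (`MiniRegistry` and `Dummy` are illustrative stand-ins, not the real API):

```python
# Minimal self-contained sketch of the registry pattern; names here are
# illustrative stand-ins, not this repo's actual Registry API.
class MiniRegistry:
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, cls):
        # used as a class decorator, keyed by class name
        self._modules[cls.__name__] = cls
        return cls

    def build(self, cfg_dict, **default_args):
        cfg_dict = dict(cfg_dict)
        cls = self._modules[cfg_dict.pop('type')]
        return cls(**cfg_dict, **default_args)


MININET = MiniRegistry('net')


@MININET.register_module
class Dummy:
    def __init__(self, cfg=None):
        self.cfg = cfg


net = MININET.build(dict(type='Dummy'))
print(type(net).__name__)  # Dummy
```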
@ -0,0 +1,142 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
# from .resnet_copy import ResNetWrapper
# from .resnet import ResNetWrapper
from .decoder_copy2 import BUSD, PlainDecoder
# from .decoder import BUSD, PlainDecoder
# from .mobilenetv2 import MobileNetv2Wrapper
from .mobilenetv2_copy2 import MobileNetv2Wrapper


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
            # (1 x conv_stride) kernels used by the vertical (down/up) passes
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            # (conv_stride x 1) kernels used by the horizontal (right/left) passes
            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            # wrap-around shift indices: at iteration i, rows (or columns)
            # move by height (or width) // 2**(iter - i)
            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        # vertical passes: downward, then upward
        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        # horizontal passes: rightward, then leftward
        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x
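The index arithmetic above is easiest to see with numbers. A small worked example (height 36 and iter 4 are illustrative values, not this repo's config):

```python
import torch

# At iteration i, rows (or columns) are rolled by height // 2**(iter - i),
# so the hop roughly doubles each iteration and wraps around the feature map.
height, n_iter = 36, 4
for i in range(n_iter):
    shift = height // 2 ** (n_iter - i)
    idx_d = (torch.arange(height) + shift) % height
    print(i, shift, idx_d[:6].tolist())
# 0 2  [2, 3, 4, 5, 6, 7]
# 1 4  [4, 5, 6, 7, 8, 9]
# 2 9  [9, 10, 11, 12, 13, 14]
# 3 18 [18, 19, 20, 21, 22, 23]
```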
class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  # ???
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x
@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        # self.backbone = ResNetWrapper(resnet='resnet34', pretrained=True,
        #                               replace_stride_with_dilation=[False, False, False],
        #                               out_conv=False)
        self.backbone = MobileNetv2Wrapper()
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)  # 'BUSD' or 'PlainDecoder'
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        # previous ResNet-based path:
        # x1, fea, _, _ = self.backbone(batch)
        # fea = self.resa(fea)
        # seg = self.decoder([x1, fea])
        # exist = self.heads(fea)

        fea1, fea2, fea = self.backbone(batch)
        fea = self.resa(fea)
        seg = self.decoder([fea1, fea2, fea])
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output
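A hedged end-to-end sketch with a hand-built config. The attribute names follow what `RESA` and `ExistHead` read above; the concrete values (288x800 input, stride-8 backbone, 128 channels, 5 classes, iter=4, conv_stride=9) are illustrative assumptions, not this repo's config files.

```python
import torch


# cfg fields must support attribute access; a dict subclass keeps the
# sketch self-contained. All values below are illustrative assumptions.
class Cfg(dict):
    __getattr__ = dict.__getitem__


cfg = Cfg(img_height=288, img_width=800, num_classes=5,
          decoder='PlainDecoder',
          backbone=Cfg(fea_stride=8),
          resa=Cfg(iter=4, input_channel=128, alpha=2.0, conv_stride=9))
net = RESANet(cfg)
out = net(torch.randn(1, 3, cfg.img_height, cfg.img_width))
print(out['seg'].shape, out['exist'].shape)  # seg logits, per-lane existence
```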
@ -0,0 +1,377 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


# This code is borrowed from torchvision.

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class ResNetWrapper(nn.Module):

    def __init__(self, cfg):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = [64, 128, 256, 512]
        if 'in_channels' in cfg.backbone:
            self.in_channels = cfg.backbone.in_channels
        self.model = eval(cfg.backbone.resnet)(
            pretrained=cfg.backbone.pretrained,
            replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
        self.out = None
        if cfg.backbone.out_conv:
            out_channel = 512
            # pick the last stage that is actually built (a channel count
            # below zero means the stage is skipped)
            for chan in reversed(self.in_channels):
                if chan < 0:
                    continue
                out_channel = chan
                break
            self.out = conv1x1(
                out_channel * self.model.expansion, 128)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            x = self.out(x)
        return x
class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.in_channels[3] > 0:
            x = self.layer4(x)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
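A hedged sketch of driving this file's `ResNetWrapper` from a config object. `cfg.backbone` must support both attribute access and `in` membership tests, so a dict subclass is used; all values are illustrative assumptions. With `in_channels[3] < 0` the fourth stage is skipped and `out_conv` attaches a 1x1 projection down to 128 channels.

```python
import torch


# Illustrative config stand-in; the real repo uses its own Config object.
class Cfg(dict):
    __getattr__ = dict.__getitem__


cfg = Cfg(backbone=Cfg(resnet='resnet18', pretrained=False,
                       replace_stride_with_dilation=[False, False, False],
                       out_conv=True, in_channels=[64, 128, 256, -256]))
backbone = ResNetWrapper(cfg)
fea = backbone(torch.randn(1, 3, 288, 800))
print(fea.shape)  # torch.Size([1, 128, 18, 50]) at stride 16
```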
@ -0,0 +1,432 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


model_urls = {
    'resnet18':
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34':
    'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50':
    'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101':
    'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152':
    'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d':
    'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
    'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2':
    'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2':
    'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class ResNetWrapper(nn.Module):
    def __init__(self,
                 resnet='resnet18',
                 pretrained=True,
                 replace_stride_with_dilation=[False, False, False],
                 out_conv=False,
                 fea_stride=8,
                 out_channel=128,
                 in_channels=[64, 128, 256, 512],
                 cfg=None):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = in_channels

        self.model = eval(resnet)(
            pretrained=pretrained,
            replace_stride_with_dilation=replace_stride_with_dilation,
            in_channels=self.in_channels)
        self.out = None
        if out_conv:
            out_channel = 512
            # pick the last stage that is actually built (a channel count
            # below zero means the stage is skipped)
            for chan in reversed(self.in_channels):
                if chan < 0:
                    continue
                out_channel = chan
                break
            self.out = conv1x1(out_channel * self.model.expansion,
                               cfg.featuremap_out_channel)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            # this file's ResNet returns a list of stage outputs; project
            # only the deepest one
            x[-1] = self.out(x[-1])
        return x
class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block,
                                       in_channels[1],
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       in_channels[2],
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(
                block,
                in_channels[3],
                layers[3],
                stride=2,
                dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        out_layers = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # collect the output of every stage that is actually built
        for name in ['layer1', 'layer2', 'layer3', 'layer4']:
            if not hasattr(self, name):
                continue
            layer = getattr(self, name)
            x = layer(x)
            out_layers.append(x)

        return out_layers
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
print('pretrained model: ', model_urls[arch])
|
||||
# state_dict = torch.load(model_urls[arch])['net']
|
||||
state_dict = load_state_dict_from_url(model_urls[arch])
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
|
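As a quick sanity check of the backbone wrapper above, it can be exercised on a dummy batch. This is a minimal sketch, assuming `ResNetWrapper` and its helpers are in scope; the `DummyCfg` object is a hypothetical stand-in for the real config (only `featuremap_out_channel` is read here):

```python
# minimal sketch: run a dummy CULane-sized batch through the wrapper
import torch

class DummyCfg:  # hypothetical stand-in; not part of the repository
    featuremap_out_channel = 128

backbone = ResNetWrapper(resnet='resnet18', pretrained=False,
                         out_conv=True, cfg=DummyCfg())
feats = backbone(torch.randn(1, 3, 288, 800))
print([f.shape for f in feats])  # one map per residual stage; last remapped to 128 channels
```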
@@ -0,0 +1,8 @@
pandas
addict
scikit-learn
opencv-python
pytorch_warmup
scikit-image
tqdm
termcolor
@@ -0,0 +1,4 @@
from .evaluator import *
from .resa_trainer import *

from .registry import build_evaluator
@@ -0,0 +1,2 @@
from .tusimple.tusimple import Tusimple
from .culane.culane import CULane
@@ -0,0 +1,158 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR
import json
import os
import sys
import subprocess
from shutil import rmtree
import cv2
import numpy as np


def check():
    # make sure the compiled CULane evaluation tool exists and is runnable
    FNULL = open(os.devnull, 'w')
    result = subprocess.call(
        './runner/evaluator/culane/lane_evaluation/evaluate', stdout=FNULL, stderr=FNULL)
    if result > 1:
        print('There is something wrong with the evaluate tool, please compile it.')
        sys.exit()


def read_helper(path):
    # parse the "key: value key: value ..." summary written by the evaluate tool
    with open(path, 'r') as f:
        lines = f.readlines()[1:]
    lines = ' '.join(lines)
    values = lines.split(' ')[1::2]
    keys = lines.split(' ')[0::2]
    keys = [key[:-1] for key in keys]  # strip the trailing ':'
    res = {k: v for k, v in zip(keys, values)}
    return res


def call_culane_eval(data_dir, output_path='./output'):
    if data_dir[-1] != '/':
        data_dir = data_dir + '/'
    detect_dir = os.path.join(output_path, 'lines') + '/'

    w_lane = 30
    iou = 0.5  # Set iou to 0.3 or 0.5
    im_w = 1640
    im_h = 590
    frame = 1
    # the nine official CULane test splits, in list-file order
    splits = ['normal', 'crowd', 'hlight', 'shadow', 'noline',
              'arrow', 'curve', 'cross', 'night']
    if not os.path.exists(os.path.join(output_path, 'txt')):
        os.mkdir(os.path.join(output_path, 'txt'))

    eval_cmd = './runner/evaluator/culane/lane_evaluation/evaluate'

    res_all = {}
    for i, split in enumerate(splits):
        list_file = os.path.join(
            data_dir, 'list/test_split/test%d_%s.txt' % (i, split))
        out_file = os.path.join(output_path, 'txt', 'out%d_%s.txt' % (i, split))
        os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (
            eval_cmd, data_dir, detect_dir, data_dir, list_file,
            w_lane, iou, im_w, im_h, frame, out_file))
        res_all[split] = read_helper(out_file)
    return res_all


@EVALUATOR.register_module
class CULane(nn.Module):
    def __init__(self, cfg):
        super(CULane, self).__init__()
        # firstly, check the evaluation tool
        check()
        self.cfg = cfg
        self.blur = torch.nn.Conv2d(
            5, 5, 9, padding=4, bias=False, groups=5).cuda()
        torch.nn.init.constant_(self.blur.weight, 1 / 81)
        self.logger = get_logger('resa')
        self.out_dir = os.path.join(self.cfg.work_dir, 'lines')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate(self, dataset, output, batch):
        seg, exists = output['seg'], output['exist']
        predictmaps = F.softmax(seg, dim=1).cpu().numpy()
        exists = exists.cpu().numpy()
        batch_size = seg.size(0)
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for i in range(batch_size):
            coords = dataset.probmap2lane(predictmaps[i], exists[i])
            outname = self.out_dir + img_name[i][:-4] + '.lines.txt'
            outdir = os.path.dirname(outname)
            if not os.path.exists(outdir):
                os.makedirs(outdir)
            with open(outname, 'w') as f:
                for coord in coords:
                    for x, y in coord:
                        if x < 0 and y < 0:
                            continue
                        f.write('%d %d ' % (x, y))
                    f.write('\n')

            if self.cfg.view:
                img = cv2.imread(img_path[i]).astype(np.float32)
                dataset.view(img, coords, self.view_dir + img_name[i])

    def summarize(self):
        self.logger.info('summarize result...')
        eval_list_path = os.path.join(
            self.cfg.dataset_path, "list", self.cfg.dataset.val.data_list)
        # prob2lines(self.prob_dir, self.out_dir, eval_list_path, self.cfg)
        res = call_culane_eval(self.cfg.dataset_path, output_path=self.cfg.work_dir)
        TP, FP, FN = 0, 0, 0
        out_str = 'Copypaste: '
        for k, v in res.items():
            val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
            val_tp, val_fp, val_fn = int(v['tp']), int(v['fp']), int(v['fn'])
            val_p, val_r, val_f1 = float(v['precision']), float(v['recall']), float(v['Fmeasure'])
            TP += val_tp
            FP += val_fp
            FN += val_fn
            self.logger.info(k + ': ' + str(v))
            out_str += k
            for metric, value in v.items():
                out_str += ' ' + str(value).rstrip('\n')
            out_str += ' '
        P = TP * 1.0 / (TP + FP + 1e-9)
        R = TP * 1.0 / (TP + FN + 1e-9)
        F = 2 * P * R / (P + R + 1e-9)
        overall_result_str = ('Overall Precision: %f Recall: %f F1: %f' % (P, R, F))
        self.logger.info(overall_result_str)
        out_str = out_str + overall_result_str
        self.logger.info(out_str)

        # delete the tmp output
        rmtree(self.out_dir)
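For reference, `read_helper` assumes each summary line from the tool looks like `tp: <n> fp: <n> ...`; a toy illustration of the key/value split on a fabricated line (the numbers are made up):

```python
# toy illustration of read_helper's alternating key/value parse
line = 'tp: 1234 fp: 56 fn: 78'
tokens = line.split(' ')
keys, values = tokens[0::2], tokens[1::2]
print({k[:-1]: v for k, v in zip(keys, values)})
# {'tp': '1234', 'fp': '56', 'fn': '78'}
```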
@@ -0,0 +1,2 @@
build/
evaluate
@@ -0,0 +1,50 @@
PROJECT_NAME:= evaluate

# config ----------------------------------
OPENCV_VERSION := 3

INCLUDE_DIRS := include
LIBRARY_DIRS := lib /usr/local/lib

COMMON_FLAGS := -DCPU_ONLY
CXXFLAGS := -std=c++11 -fopenmp
LDFLAGS := -fopenmp -Wl,-rpath,./lib
BUILD_DIR := build


# make rules -------------------------------
CXX ?= g++
BUILD_DIR ?= ./build

LIBRARIES += opencv_core opencv_highgui opencv_imgproc
ifeq ($(OPENCV_VERSION), 3)
	LIBRARIES += opencv_imgcodecs
endif

CXXFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
LDFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(LIBRARY_DIRS),-L$(includedir)) $(foreach library,$(LIBRARIES),-l$(library))
SRC_DIRS += $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print)
CXX_SRCS += $(shell find src/ -name "*.cpp")
CXX_TARGETS:=$(patsubst %.cpp, $(BUILD_DIR)/%.o, $(CXX_SRCS))
ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix $(BUILD_DIR)/, $(SRC_DIRS)))

.PHONY: all
all: $(PROJECT_NAME)

.PHONY: $(ALL_BUILD_DIRS)
$(ALL_BUILD_DIRS):
	@mkdir -p $@

$(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS)
	@echo "CXX" $<
	@$(CXX) $(CXXFLAGS) -c -o $@ $<

$(PROJECT_NAME): $(CXX_TARGETS)
	@echo "CXX/LD" $@
	@$(CXX) -o $@ $^ $(LDFLAGS)

.PHONY: clean
clean:
	@rm -rf $(CXX_TARGETS)
	@rm -rf $(PROJECT_NAME)
	@rm -rf $(BUILD_DIR)
@@ -0,0 +1,47 @@
#ifndef COUNTER_HPP
#define COUNTER_HPP

#include "lane_compare.hpp"
#include "hungarianGraph.hpp"
#include <iostream>
#include <algorithm>
#include <tuple>
#include <vector>
#include <opencv2/core/core.hpp>

using namespace std;
using namespace cv;

// before using the functions of this class, the lanes should be resized to
// im_width and im_height with resize_lane() in lane_compare.hpp
class Counter
{
public:
	Counter(int _im_width, int _im_height, double _iou_threshold = 0.4, int _lane_width = 10) : tp(0), fp(0), fn(0) {
		im_width = _im_width;
		im_height = _im_height;
		sim_threshold = _iou_threshold;
		lane_compare = new LaneCompare(_im_width, _im_height, _lane_width, LaneCompare::IOU);
	}
	double get_precision(void);
	double get_recall(void);
	long getTP(void);
	long getFP(void);
	long getFN(void);
	void setTP(long);
	void setFP(long);
	void setFN(long);
	// directly add tp, fp, tn and fn
	// first match with hungarian
	tuple<vector<int>, long, long, long, long> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes);
	void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2);

private:
	double sim_threshold;
	int im_width;
	int im_height;
	long tp;
	long fp;
	long fn;
	LaneCompare *lane_compare;
};
#endif
@@ -0,0 +1,71 @@
#ifndef HUNGARIAN_GRAPH_HPP
#define HUNGARIAN_GRAPH_HPP
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;

struct pipartiteGraph {
	vector<vector<double> > mat;
	vector<bool> leftUsed, rightUsed;
	vector<double> leftWeight, rightWeight;
	vector<int> rightMatch, leftMatch;
	int leftNum, rightNum;
	bool matchDfs(int u) {
		leftUsed[u] = true;
		for (int v = 0; v < rightNum; v++) {
			if (!rightUsed[v] && fabs(leftWeight[u] + rightWeight[v] - mat[u][v]) < 1e-2) {
				rightUsed[v] = true;
				if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) {
					rightMatch[v] = u;
					leftMatch[u] = v;
					return true;
				}
			}
		}
		return false;
	}
	void resize(int leftNum, int rightNum) {
		this->leftNum = leftNum;
		this->rightNum = rightNum;
		leftMatch.resize(leftNum);
		rightMatch.resize(rightNum);
		leftUsed.resize(leftNum);
		rightUsed.resize(rightNum);
		leftWeight.resize(leftNum);
		rightWeight.resize(rightNum);
		mat.resize(leftNum);
		for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum);
	}
	void match() {
		for (int i = 0; i < leftNum; i++) leftMatch[i] = -1;
		for (int i = 0; i < rightNum; i++) rightMatch[i] = -1;
		for (int i = 0; i < rightNum; i++) rightWeight[i] = 0;
		for (int i = 0; i < leftNum; i++) {
			leftWeight[i] = -1e5;
			for (int j = 0; j < rightNum; j++) {
				if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j];
			}
		}

		for (int u = 0; u < leftNum; u++) {
			while (1) {
				for (int i = 0; i < leftNum; i++) leftUsed[i] = false;
				for (int i = 0; i < rightNum; i++) rightUsed[i] = false;
				if (matchDfs(u)) break;
				// no augmenting path found: relax the vertex labels by the minimal slack
				double d = 1e10;
				for (int i = 0; i < leftNum; i++) {
					if (leftUsed[i]) {
						for (int j = 0; j < rightNum; j++) {
							if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]);
						}
					}
				}
				if (d == 1e10) return;
				for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d;
				for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d;
			}
		}
	}
};


#endif // HUNGARIAN_GRAPH_HPP
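`pipartiteGraph::match()` above is the Kuhn-Munkres (Hungarian) algorithm run as a maximization over the lane-similarity matrix, using vertex labels and slack updates. As a hedged cross-check, the same assignment can be reproduced on a toy matrix with SciPy; SciPy is not a dependency of this repo and is used here only for illustration:

```python
# toy cross-check of the Hungarian matching with SciPy (assumed dependency)
import numpy as np
from scipy.optimize import linear_sum_assignment

sim = np.array([[0.9, 0.1],     # rows: annotated lanes
                [0.2, 0.8],     # cols: detected lanes
                [0.0, 0.3]])
rows, cols = linear_sum_assignment(sim, maximize=True)
print(list(zip(rows, cols)))    # [(0, 0), (1, 1)] -- the third lane stays unmatched
```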
@@ -0,0 +1,51 @@
#ifndef LANE_COMPARE_HPP
#define LANE_COMPARE_HPP

#include "spline.hpp"
#include <vector>
#include <iostream>
#include <opencv2/core/version.hpp>
#include <opencv2/core/core.hpp>

#if CV_VERSION_EPOCH == 2
#define OPENCV2
#elif CV_VERSION_MAJOR == 3
#define OPENCV3
#else
#error OpenCV version not supported
#endif

#ifdef OPENCV3
#include <opencv2/imgproc.hpp>
#elif defined(OPENCV2)
#include <opencv2/imgproc/imgproc.hpp>
#endif

using namespace std;
using namespace cv;

class LaneCompare{
public:
	enum CompareMode{
		IOU,
		Caltech
	};

	LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){
		im_width = _im_width;
		im_height = _im_height;
		compare_mode = _compare_mode;
		lane_width = _lane_width;
	}

	double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2);
	void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height);
private:
	CompareMode compare_mode;
	int im_width;
	int im_height;
	int lane_width;
	Spline splineSolver;
};

#endif
@@ -0,0 +1,28 @@
#ifndef SPLINE_HPP
#define SPLINE_HPP
#include <vector>
#include <cstdio>
#include <cmath>
#include <opencv2/core/core.hpp>

using namespace cv;
using namespace std;

// coefficients of one cubic segment: x(t) = a_x + b_x*t + c_x*t^2 + d_x*t^3
// (likewise for y), valid for t in [0, h]
struct Func {
	double a_x;
	double b_x;
	double c_x;
	double d_x;
	double a_y;
	double b_y;
	double c_y;
	double d_y;
	double h;
};
class Spline {
public:
	vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times);
	vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step);
	vector<Func> cal_fun(const vector<Point2f> &point_v);
};
#endif
@@ -0,0 +1,134 @@
/*************************************************************************
	> File Name: counter.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Thu Jul 14 20:23:08 2016
 ************************************************************************/

#include "counter.hpp"

double Counter::get_precision(void)
{
	cerr << "tp: " << tp << " fp: " << fp << " fn: " << fn << endl;
	if (tp + fp == 0)
	{
		cerr << "no positive detection" << endl;
		return -1;
	}
	return tp / double(tp + fp);
}

double Counter::get_recall(void)
{
	if (tp + fn == 0)
	{
		cerr << "no ground truth positive" << endl;
		return -1;
	}
	return tp / double(tp + fn);
}

long Counter::getTP(void)
{
	return tp;
}

long Counter::getFP(void)
{
	return fp;
}

long Counter::getFN(void)
{
	return fn;
}

void Counter::setTP(long value)
{
	tp = value;
}

void Counter::setFP(long value)
{
	fp = value;
}

void Counter::setFN(long value)
{
	fn = value;
}

tuple<vector<int>, long, long, long, long> Counter::count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes)
{
	vector<int> anno_match(anno_lanes.size(), -1);
	vector<int> detect_match;
	if (anno_lanes.empty())
	{
		return make_tuple(anno_match, 0, detect_lanes.size(), 0, 0);
	}

	if (detect_lanes.empty())
	{
		return make_tuple(anno_match, 0, 0, 0, anno_lanes.size());
	}
	// hungarian match first

	// first calc similarity matrix
	vector<vector<double> > similarity(anno_lanes.size(), vector<double>(detect_lanes.size(), 0));
	for (int i = 0; i < anno_lanes.size(); i++)
	{
		const vector<Point2f> &curr_anno_lane = anno_lanes[i];
		for (int j = 0; j < detect_lanes.size(); j++)
		{
			const vector<Point2f> &curr_detect_lane = detect_lanes[j];
			similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane);
		}
	}

	makeMatch(similarity, anno_match, detect_match);

	int curr_tp = 0;
	// count and add
	for (int i = 0; i < anno_lanes.size(); i++)
	{
		if (anno_match[i] >= 0 && similarity[i][anno_match[i]] > sim_threshold)
		{
			curr_tp++;
		}
		else
		{
			anno_match[i] = -1;
		}
	}
	int curr_fn = anno_lanes.size() - curr_tp;
	int curr_fp = detect_lanes.size() - curr_tp;
	return make_tuple(anno_match, curr_tp, curr_fp, 0, curr_fn);
}


void Counter::makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2) {
	int m = similarity.size();
	int n = similarity[0].size();
	pipartiteGraph gra;
	bool have_exchange = false;
	// the solver expects leftNum <= rightNum; transpose if necessary
	if (m > n) {
		have_exchange = true;
		swap(m, n);
	}
	gra.resize(m, n);
	for (int i = 0; i < gra.leftNum; i++) {
		for (int j = 0; j < gra.rightNum; j++) {
			if (have_exchange)
				gra.mat[i][j] = similarity[j][i];
			else
				gra.mat[i][j] = similarity[i][j];
		}
	}
	gra.match();
	match1 = gra.leftMatch;
	match2 = gra.rightMatch;
	if (have_exchange) swap(match1, match2);
}
@@ -0,0 +1,302 @@
/*************************************************************************
	> File Name: evaluate.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Thu Jul 14 18:28:45 2016
 ************************************************************************/

#include "counter.hpp"
#include "spline.hpp"
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <cstdlib>
#include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
using namespace std;
using namespace cv;

void help(void) {
	cout << "./evaluate [OPTIONS]" << endl;
	cout << "-h : print usage help" << endl;
	cout << "-a : directory for annotation files (default: /data/driving/eval_data/anno_label/)" << endl;
	cout << "-d : directory for detection files (default: /data/driving/eval_data/predict_label/)" << endl;
	cout << "-i : directory for image files (default: /data/driving/eval_data/img/)" << endl;
	cout << "-l : list of images used for evaluation (default: /data/driving/eval_data/img/all.txt)" << endl;
	cout << "-w : width of the lanes (default: 10)" << endl;
	cout << "-t : threshold of iou (default: 0.4)" << endl;
	cout << "-c : cols (max image width) (default: 1920)" << endl;
	cout << "-r : rows (max image height) (default: 1080)" << endl;
	cout << "-s : show visualization" << endl;
	cout << "-f : start frame in the test set (default: 1)" << endl;
	cout << "-o : output file (default: ./output.txt)" << endl;
	cout << "-p : path to save visualization images" << endl;
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes);
void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path = "");

int main(int argc, char **argv) {
	// process params
	string anno_dir = "/data/driving/eval_data/anno_label/";
	string detect_dir = "/data/driving/eval_data/predict_label/";
	string im_dir = "/data/driving/eval_data/img/";
	string list_im_file = "/data/driving/eval_data/img/all.txt";
	string output_file = "./output.txt";
	int width_lane = 10;
	double iou_threshold = 0.4;
	int im_width = 1920;
	int im_height = 1080;
	int oc;
	bool show = false;
	int frame = 1;
	string save_path = "";
	while ((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:p:")) != -1) {
		switch (oc) {
		case 'h':
			help();
			return 0;
		case 'a':
			anno_dir = optarg;
			break;
		case 'd':
			detect_dir = optarg;
			break;
		case 'i':
			im_dir = optarg;
			break;
		case 'l':
			list_im_file = optarg;
			break;
		case 'w':
			width_lane = atoi(optarg);
			break;
		case 't':
			iou_threshold = atof(optarg);
			break;
		case 'c':
			im_width = atoi(optarg);
			break;
		case 'r':
			im_height = atoi(optarg);
			break;
		case 's':
			show = true;
			break;
		case 'p':
			save_path = optarg;
			break;
		case 'f':
			frame = atoi(optarg);
			break;
		case 'o':
			output_file = optarg;
			break;
		}
	}

	cout << "------------Configuration---------" << endl;
	cout << "anno_dir: " << anno_dir << endl;
	cout << "detect_dir: " << detect_dir << endl;
	cout << "im_dir: " << im_dir << endl;
	cout << "list_im_file: " << list_im_file << endl;
	cout << "width_lane: " << width_lane << endl;
	cout << "iou_threshold: " << iou_threshold << endl;
	cout << "im_width: " << im_width << endl;
	cout << "im_height: " << im_height << endl;
	cout << "-----------------------------------" << endl;
	cout << "Evaluating the results..." << endl;
	// im_width and im_height are the max width and max height

	if (width_lane < 1) {
		cerr << "width_lane must be positive" << endl;
		help();
		return 1;
	}

	ifstream ifs_im_list(list_im_file, ios::in);
	if (ifs_im_list.fail()) {
		cerr << "Error: file " << list_im_file << " does not exist!" << endl;
		return 1;
	}

	Counter counter(im_width, im_height, iou_threshold, width_lane);

	vector<int> anno_match;
	string sub_im_name;
	// pre-load the file list
	vector<string> filelists;
	while (getline(ifs_im_list, sub_im_name)) {
		filelists.push_back(sub_im_name);
	}
	ifs_im_list.close();

	vector<tuple<vector<int>, long, long, long, long>> tuple_lists;
	tuple_lists.resize(filelists.size());

#pragma omp parallel for
	for (size_t i = 0; i < filelists.size(); i++) {
		auto sub_im_name = filelists[i];
		string full_im_name = im_dir + sub_im_name;
		string sub_txt_name =
		    sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt";
		string anno_file_name = anno_dir + sub_txt_name;
		string detect_file_name = detect_dir + sub_txt_name;
		vector<vector<Point2f>> anno_lanes;
		vector<vector<Point2f>> detect_lanes;
		read_lane_file(anno_file_name, anno_lanes);
		read_lane_file(detect_file_name, detect_lanes);
		// cerr<<count<<": "<<full_im_name<<endl;
		tuple_lists[i] = counter.count_im_pair(anno_lanes, detect_lanes);
		if (show) {
			auto anno_match = get<0>(tuple_lists[i]);
			visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane);
			waitKey(0);
		}
		if (save_path != "") {
			auto anno_match = get<0>(tuple_lists[i]);
			visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane,
			          save_path);
		}
	}

	long tp = 0, fp = 0, tn = 0, fn = 0;
	for (auto result : tuple_lists) {
		tp += get<1>(result);
		fp += get<2>(result);
		// tn = get<3>(result);
		fn += get<4>(result);
	}
	counter.setTP(tp);
	counter.setFP(fp);
	counter.setFN(fn);

	double precision = counter.get_precision();
	double recall = counter.get_recall();
	double F = 2 * precision * recall / (precision + recall);
	cerr << "finished processing files" << endl;
	cout << "precision: " << precision << endl;
	cout << "recall: " << recall << endl;
	cout << "Fmeasure: " << F << endl;
	cout << "----------------------------------" << endl;

	ofstream ofs_out_file;
	ofs_out_file.open(output_file, ios::out);
	ofs_out_file << "file: " << output_file << endl;
	ofs_out_file << "tp: " << counter.getTP() << " fp: " << counter.getFP()
	             << " fn: " << counter.getFN() << endl;
	ofs_out_file << "precision: " << precision << endl;
	ofs_out_file << "recall: " << recall << endl;
	ofs_out_file << "Fmeasure: " << F << endl << endl;
	ofs_out_file.close();
	return 0;
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes) {
	lanes.clear();
	ifstream ifs_lane(file_name, ios::in);
	if (ifs_lane.fail()) {
		return;
	}

	string str_line;
	while (getline(ifs_lane, str_line)) {
		vector<Point2f> curr_lane;
		stringstream ss;
		ss << str_line;
		double x, y;
		while (ss >> x >> y) {
			curr_lane.push_back(Point2f(x, y));
		}
		lanes.push_back(curr_lane);
	}

	ifs_lane.close();
}

void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path) {
	Mat img = imread(full_im_name, 1);
	Mat img2 = imread(full_im_name, 1);
	vector<Point2f> curr_lane;
	vector<Point2f> p_interp;
	Spline splineSolver;
	Scalar color_B = Scalar(255, 0, 0);
	Scalar color_G = Scalar(0, 255, 0);
	Scalar color_R = Scalar(0, 0, 255);
	Scalar color_P = Scalar(255, 0, 255);
	Scalar color;
	for (int i = 0; i < anno_lanes.size(); i++) {
		curr_lane = anno_lanes[i];
		if (curr_lane.size() == 2) {
			p_interp = curr_lane;
		} else {
			p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
		}
		if (anno_match[i] >= 0) {
			color = color_G;  // matched ground truth: green
		} else {
			color = color_P;  // unmatched ground truth: purple
		}
		for (int n = 0; n < p_interp.size() - 1; n++) {
			line(img, p_interp[n], p_interp[n + 1], color, width_lane);
			line(img2, p_interp[n], p_interp[n + 1], color, 2);
		}
	}
	bool detected;
	for (int i = 0; i < detect_lanes.size(); i++) {
		detected = false;
		curr_lane = detect_lanes[i];
		if (curr_lane.size() == 2) {
			p_interp = curr_lane;
		} else {
			p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
		}
		for (int n = 0; n < anno_lanes.size(); n++) {
			if (anno_match[n] == i) {
				detected = true;
				break;
			}
		}
		if (detected) {
			color = color_B;  // matched detection: blue
		} else {
			color = color_R;  // false positive: red
		}
		for (int n = 0; n < p_interp.size() - 1; n++) {
			line(img, p_interp[n], p_interp[n + 1], color, width_lane);
			line(img2, p_interp[n], p_interp[n + 1], color, 2);
		}
	}
	if (save_path != "") {
		size_t pos = 0;
		string s = full_im_name;
		std::string token;
		std::string delimiter = "/";
		vector<string> names;
		while ((pos = s.find(delimiter)) != std::string::npos) {
			token = s.substr(0, pos);
			names.emplace_back(token);
			s.erase(0, pos + delimiter.length());
		}
		names.emplace_back(s);
		string file_name = names[3] + '_' + names[4] + '_' + names[5];
		// cout << file_name << endl;
		imwrite(save_path + '/' + file_name, img);
	} else {
		namedWindow("visualize", 1);
		imshow("visualize", img);
		namedWindow("visualize2", 1);
		imshow("visualize2", img2);
	}
}
@@ -0,0 +1,73 @@
/*************************************************************************
	> File Name: lane_compare.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Fri Jul 15 10:26:32 2016
 ************************************************************************/

#include "lane_compare.hpp"

double LaneCompare::get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2)
{
	if (lane1.size() < 2 || lane2.size() < 2)
	{
		cerr << "lane size must be greater than or equal to 2" << endl;
		return 0;
	}
	Mat im1 = Mat::zeros(im_height, im_width, CV_8UC1);
	Mat im2 = Mat::zeros(im_height, im_width, CV_8UC1);
	// draw the lines on im1 and im2
	vector<Point2f> p_interp1;
	vector<Point2f> p_interp2;
	if (lane1.size() == 2)
	{
		p_interp1 = lane1;
	}
	else
	{
		p_interp1 = splineSolver.splineInterpTimes(lane1, 50);
	}

	if (lane2.size() == 2)
	{
		p_interp2 = lane2;
	}
	else
	{
		p_interp2 = splineSolver.splineInterpTimes(lane2, 50);
	}

	Scalar color_white = Scalar(1);
	for (int n = 0; n < p_interp1.size() - 1; n++)
	{
		line(im1, p_interp1[n], p_interp1[n + 1], color_white, lane_width);
	}
	for (int n = 0; n < p_interp2.size() - 1; n++)
	{
		line(im2, p_interp2[n], p_interp2[n + 1], color_white, lane_width);
	}

	double sum_1 = cv::sum(im1).val[0];
	double sum_2 = cv::sum(im2).val[0];
	double inter_sum = cv::sum(im1.mul(im2)).val[0];
	double union_sum = sum_1 + sum_2 - inter_sum;
	double iou = inter_sum / union_sum;
	return iou;
}


// resize the lane from Size(curr_width, curr_height) to Size(im_width, im_height)
void LaneCompare::resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height)
{
	if (curr_width == im_width && curr_height == im_height)
	{
		return;
	}
	double x_scale = im_width / (double)curr_width;
	double y_scale = im_height / (double)curr_height;
	for (int n = 0; n < curr_lane.size(); n++)
	{
		curr_lane[n] = Point2f(curr_lane[n].x * x_scale, curr_lane[n].y * y_scale);
	}
}
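The similarity above is a rasterized IoU: each polyline is drawn with the given width on a blank canvas and the intersection over union of the two masks is the score. A minimal NumPy/OpenCV sketch of the same idea, skipping the spline resampling step (the lane points and sizes below are made up):

```python
# rasterized lane IoU, mirroring LaneCompare::get_lane_similarity
import cv2
import numpy as np

def lane_iou(lane1, lane2, h, w, lane_width=30):
    m1 = np.zeros((h, w), np.uint8)
    m2 = np.zeros((h, w), np.uint8)
    cv2.polylines(m1, [np.int32(lane1)], False, 1, lane_width)
    cv2.polylines(m2, [np.int32(lane2)], False, 1, lane_width)
    inter = np.sum((m1 > 0) & (m2 > 0))
    union = np.sum((m1 > 0) | (m2 > 0))
    return inter / union if union else 0.0

print(lane_iou([(100, 580), (300, 300)], [(110, 580), (310, 300)], 590, 1640))
```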
@@ -0,0 +1,178 @@
#include <vector>
#include <iostream>
#include "spline.hpp"
using namespace std;
using namespace cv;

vector<Point2f> Spline::splineInterpTimes(const vector<Point2f>& tmp_line, int times) {
	vector<Point2f> res;

	if (tmp_line.size() == 2) {
		double x1 = tmp_line[0].x;
		double y1 = tmp_line[0].y;
		double x2 = tmp_line[1].x;
		double y2 = tmp_line[1].y;

		for (int k = 0; k <= times; k++) {
			double xi = x1 + double((x2 - x1) * k) / times;
			double yi = y1 + double((y2 - y1) * k) / times;
			res.push_back(Point2f(xi, yi));
		}
	}
	else if (tmp_line.size() > 2)
	{
		vector<Func> tmp_func;
		tmp_func = this->cal_fun(tmp_line);
		if (tmp_func.empty()) {
			cout << "in splineInterpTimes: cal_fun failed" << endl;
			return res;
		}
		for (int j = 0; j < tmp_func.size(); j++)
		{
			double delta = tmp_func[j].h / times;
			for (int k = 0; k < times; k++)
			{
				double t1 = delta * k;
				double x1 = tmp_func[j].a_x + tmp_func[j].b_x * t1 + tmp_func[j].c_x * pow(t1, 2) + tmp_func[j].d_x * pow(t1, 3);
				double y1 = tmp_func[j].a_y + tmp_func[j].b_y * t1 + tmp_func[j].c_y * pow(t1, 2) + tmp_func[j].d_y * pow(t1, 3);
				res.push_back(Point2f(x1, y1));
			}
		}
		res.push_back(tmp_line[tmp_line.size() - 1]);
	}
	else {
		cerr << "in splineInterpTimes: not enough points" << endl;
	}
	return res;
}

vector<Point2f> Spline::splineInterpStep(vector<Point2f> tmp_line, double step) {
	vector<Point2f> res;
	/*
	if (tmp_line.size() == 2) {
		double x1 = tmp_line[0].x;
		double y1 = tmp_line[0].y;
		double x2 = tmp_line[1].x;
		double y2 = tmp_line[1].y;

		for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) {
			double xi;
			if (yi == y1) xi = x1;
			else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1;
			res.push_back(Point2f(xi, yi));
		}
	}*/
	if (tmp_line.size() == 2) {
		// insert the midpoint so the segment can be treated as a 3-point spline
		double x1 = tmp_line[0].x;
		double y1 = tmp_line[0].y;
		double x2 = tmp_line[1].x;
		double y2 = tmp_line[1].y;
		tmp_line[1].x = (x1 + x2) / 2;
		tmp_line[1].y = (y1 + y2) / 2;
		tmp_line.push_back(Point2f(x2, y2));
	}
	if (tmp_line.size() > 2) {
		vector<Func> tmp_func;
		tmp_func = this->cal_fun(tmp_line);
		double ystart = tmp_line[0].y;
		double yend = tmp_line[tmp_line.size() - 1].y;
		bool down;
		if (ystart < yend) down = 1;
		else down = 0;
		if (tmp_func.empty()) {
			cerr << "in splineInterpStep: cal_fun failed" << endl;
		}

		for (int j = 0; j < tmp_func.size(); j++)
		{
			for (double t1 = 0; t1 < tmp_func[j].h; t1 += step)
			{
				double x1 = tmp_func[j].a_x + tmp_func[j].b_x * t1 + tmp_func[j].c_x * pow(t1, 2) + tmp_func[j].d_x * pow(t1, 3);
				double y1 = tmp_func[j].a_y + tmp_func[j].b_y * t1 + tmp_func[j].c_y * pow(t1, 2) + tmp_func[j].d_y * pow(t1, 3);
				res.push_back(Point2f(x1, y1));
			}
		}
		res.push_back(tmp_line[tmp_line.size() - 1]);
	}
	else {
		cerr << "in splineInterpStep: not enough points" << endl;
	}
	return res;
}

vector<Func> Spline::cal_fun(const vector<Point2f> &point_v)
{
	vector<Func> func_v;
	int n = point_v.size();
	if (n <= 2) {
		cout << "in cal_fun: point number less than 3" << endl;
		return func_v;
	}

	func_v.resize(point_v.size() - 1);

	vector<double> Mx(n);
	vector<double> My(n);
	vector<double> A(n - 2);
	vector<double> B(n - 2);
	vector<double> C(n - 2);
	vector<double> Dx(n - 2);
	vector<double> Dy(n - 2);
	vector<double> h(n - 1);
	//vector<func> func_v(n-1);

	// chord-length parameterization: h[i] is the distance between consecutive points
	for (int i = 0; i < n - 1; i++)
	{
		h[i] = sqrt(pow(point_v[i + 1].x - point_v[i].x, 2) + pow(point_v[i + 1].y - point_v[i].y, 2));
	}

	for (int i = 0; i < n - 2; i++)
	{
		A[i] = h[i];
		B[i] = 2 * (h[i] + h[i + 1]);
		C[i] = h[i + 1];

		Dx[i] = 6 * ((point_v[i + 2].x - point_v[i + 1].x) / h[i + 1] - (point_v[i + 1].x - point_v[i].x) / h[i]);
		Dy[i] = 6 * ((point_v[i + 2].y - point_v[i + 1].y) / h[i + 1] - (point_v[i + 1].y - point_v[i].y) / h[i]);
	}

	//TDMA
	C[0] = C[0] / B[0];
	Dx[0] = Dx[0] / B[0];
	Dy[0] = Dy[0] / B[0];
	for (int i = 1; i < n - 2; i++)
	{
		double tmp = B[i] - A[i] * C[i - 1];
		C[i] = C[i] / tmp;
		Dx[i] = (Dx[i] - A[i] * Dx[i - 1]) / tmp;
		Dy[i] = (Dy[i] - A[i] * Dy[i - 1]) / tmp;
	}
	Mx[n - 2] = Dx[n - 3];
	My[n - 2] = Dy[n - 3];
	for (int i = n - 4; i >= 0; i--)
	{
		Mx[i + 1] = Dx[i] - C[i] * Mx[i + 2];
		My[i + 1] = Dy[i] - C[i] * My[i + 2];
	}

	// natural spline: zero second derivative at both ends
	Mx[0] = 0;
	Mx[n - 1] = 0;
	My[0] = 0;
	My[n - 1] = 0;

	for (int i = 0; i < n - 1; i++)
	{
		func_v[i].a_x = point_v[i].x;
		func_v[i].b_x = (point_v[i + 1].x - point_v[i].x) / h[i] - (2 * h[i] * Mx[i] + h[i] * Mx[i + 1]) / 6;
		func_v[i].c_x = Mx[i] / 2;
		func_v[i].d_x = (Mx[i + 1] - Mx[i]) / (6 * h[i]);

		func_v[i].a_y = point_v[i].y;
		func_v[i].b_y = (point_v[i + 1].y - point_v[i].y) / h[i] - (2 * h[i] * My[i] + h[i] * My[i + 1]) / 6;
		func_v[i].c_y = My[i] / 2;
		func_v[i].d_y = (My[i + 1] - My[i]) / (6 * h[i]);

		func_v[i].h = h[i];
	}
	return func_v;
}
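`cal_fun` solves the tridiagonal system of a natural cubic spline parameterized by chord length. A small NumPy cross-check of the same second-derivative solve, with fabricated points (this is an illustration, not part of the repo):

```python
# natural cubic spline second derivatives via a dense solve --
# mirrors the TDMA step in Spline::cal_fun
import numpy as np

pts = np.array([[0., 0.], [1., 2.], [3., 3.], [4., 1.]])
h = np.linalg.norm(np.diff(pts, axis=0), axis=1)   # chord lengths between points
n = len(pts)
A = np.zeros((n - 2, n - 2))
for i in range(n - 2):
    A[i, i] = 2 * (h[i] + h[i + 1])
    if i > 0:
        A[i, i - 1] = h[i]
    if i < n - 3:
        A[i, i + 1] = h[i + 1]
D = 6 * (np.diff(pts[1:], axis=0).T / h[1:] - np.diff(pts[:-1], axis=0).T / h[:-1]).T
M = np.zeros((n, 2))
M[1:-1] = np.linalg.solve(A, D)   # natural ends: M[0] = M[-1] = 0
print(M)                          # second derivatives Mx, My at each knot
```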
@@ -0,0 +1,51 @@
import os

import numpy as np
import pandas as pd
from PIL import Image


def getLane(probmap, pts, cfg=None):
    thr = 0.3
    coordinate = np.zeros(pts)
    cut_height = 0
    if cfg.cut_height:
        cut_height = cfg.cut_height
    for i in range(pts):
        # sample the probability-map row corresponding to every 20 px of the
        # original (590 - cut_height) px image height
        line = probmap[round(cfg.img_height - i * 20 / (590 - cut_height) * cfg.img_height) - 1]
        if np.max(line) / 255 > thr:
            coordinate[i] = np.argmax(line) + 1
    if np.sum(coordinate > 0) < 2:
        coordinate = np.zeros(pts)
    return coordinate


def prob2lines(prob_dir, out_dir, list_file, cfg=None):
    lists = pd.read_csv(list_file, sep=' ', header=None,
                        names=('img', 'probmap', 'label1', 'label2', 'label3', 'label4'))
    pts = 18

    for k, im in enumerate(lists['img'], 1):
        existPath = prob_dir + im[:-4] + '.exist.txt'
        outname = out_dir + im[:-4] + '.lines.txt'
        prefix = '/'.join(outname.split('/')[:-1])
        if not os.path.exists(prefix):
            os.makedirs(prefix)

        labels = list(pd.read_csv(existPath, sep=' ', header=None).iloc[0])
        coordinates = np.zeros((4, pts))
        with open(outname, 'w') as f:
            for i in range(4):
                if labels[i] == 1:
                    probfile = prob_dir + im[:-4] + '_{0}_avg.png'.format(i + 1)
                    probmap = np.array(Image.open(probfile))
                    coordinates[i] = getLane(probmap, pts, cfg)

                    if np.sum(coordinates[i] > 0) > 1:
                        for idx, value in enumerate(coordinates[i]):
                            if value > 0:
                                f.write('%d %d ' % (
                                    round(value * 1640 / cfg.img_width) - 1, round(590 - idx * 20) - 1))
                        f.write('\n')
@@ -0,0 +1,115 @@
import cv2
import numpy as np


def isShort(lane):
    start = [i for i, x in enumerate(lane) if x > 0]
    if not start:
        return 1
    else:
        return 0


def fixGap(coordinate):
    # interpolate across negative gaps between the first and last valid points;
    # works in place on the array that is passed in
    if any(x > 0 for x in coordinate):
        start = [i for i, x in enumerate(coordinate) if x > 0][0]
        end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
        lane = coordinate[start:end + 1]
        if any(x < 0 for x in lane):
            gap_start = [i for i, x in enumerate(
                lane[:-1]) if x > 0 and lane[i + 1] < 0]
            gap_end = [i + 1 for i,
                       x in enumerate(lane[:-1]) if x < 0 and lane[i + 1] > 0]
            gap_id = [i for i, x in enumerate(lane) if x < 0]
            if len(gap_start) == 0 or len(gap_end) == 0:
                return coordinate
            for id in gap_id:
                for i in range(len(gap_start)):
                    if i >= len(gap_end):
                        return coordinate
                    if id > gap_start[i] and id < gap_end[i]:
                        gap_width = float(gap_end[i] - gap_start[i])
                        lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                            gap_end[i] - id) / gap_width * lane[gap_start[i]])
            if not all(x > 0 for x in lane):
                print("Gaps still exist!")
            coordinate[start:end + 1] = lane
    return coordinate


def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None):
    """
    Arguments:
    ----------
    prob_map: prob map for single lane, np array size (h, w)
    resize_shape: reshape size target, (H, W)

    Return:
    ----------
    coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape
    H -= cfg.cut_height

    coords = np.zeros(pts)
    coords[:] = -1.0
    for i in range(pts):
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        line = prob_map[y, :]
        id = np.argmax(line)
        if line[id] > thresh:
            coords[i] = int(id / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    fixGap(coords)  # fills gaps in place
    return coords


def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3, cfg=None):
    """
    Arguments:
    ----------
    seg_pred: np.array size (5, h, w)
    resize_shape: reshape size target, (H, W)
    exist: list of existence, e.g. [0, 1, 1, 0]
    smooth: whether to smooth the probability or not
    y_px_gap: y pixel gap for sampling
    pts: how many points for one lane
    thresh: probability threshold

    Return:
    ----------
    coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]], [[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    for i in range(cfg.num_classes - 1):
        prob_map = seg_pred[..., i + 1]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg)
        if isShort(coords):
            continue
        coordinates.append(
            [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
             range(pts)])

    if len(coordinates) == 0:
        coords = np.zeros(pts)
        coordinates.append(
            [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
             range(pts)])

    return coordinates
@@ -0,0 +1,108 @@
import json
import numpy as np
from sklearn.linear_model import LinearRegression


class LaneEval(object):
    lr = LinearRegression()
    pixel_thresh = 20
    pt_thresh = 0.85

    @staticmethod
    def get_angle(xs, y_samples):
        xs, ys = xs[xs >= 0], y_samples[xs >= 0]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time):
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        if running_time > 200 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        angles = [LaneEval.get_angle(
            np.array(x_gts), np.array(y_samples)) for x_gts in gt]
        # widen the pixel threshold for slanted lanes
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fp, fn = 0., 0.
        matched = 0.
        for x_gts, thresh in zip(gt, threshs):
            accs = [LaneEval.line_accuracy(
                np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                matched += 1
            line_accs.append(max_acc)
        fp = len(pred) - matched
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)

    @staticmethod
    def bench_one_submit(pred_file, gt_file):
        try:
            json_pred = [json.loads(line)
                         for line in open(pred_file).readlines()]
        except BaseException:
            raise Exception('Failed to load the json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception(
                'We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        accuracy, fp, fn = 0., 0., 0.
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception(
                    'raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            if raw_file not in gts:
                raise Exception(
                    'Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(
                    pred_lanes, gt_lanes, y_samples, run_time)
            except BaseException:
                raise Exception('Format of lanes error.')
            accuracy += a
            fp += p
            fn += n
        num = len(gts)
        # the first return parameter is the default ranking parameter
        return json.dumps([
            {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
            {'name': 'FP', 'value': fp / num, 'order': 'asc'},
            {'name': 'FN', 'value': fn / num, 'order': 'asc'}
        ]), accuracy / num


if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        print(e)
        sys.exit(str(e))
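`line_accuracy` is the fraction of sampled x positions that land within the (angle-widened) pixel threshold of the ground truth. A toy check, assuming the `LaneEval` class above is in scope (the coordinates are made up):

```python
# toy check of LaneEval.line_accuracy
pred = [105, 210, -2, 400]   # negative values mark "no lane at this y sample"
gt   = [100, 220, 330, 470]
print(LaneEval.line_accuracy(pred, gt, LaneEval.pixel_thresh))  # -> 0.5
```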
@@ -0,0 +1,111 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR
import json
import os
import cv2

from .lane import LaneEval


def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        if not os.path.exists(exp_dir):
            os.mkdir(exp_dir)
        self.out_path = os.path.join(exp_dir, "coord_output")
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)
        self.dump_to_json = []
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] >
                     0.5 else 0 for i in range(self.cfg.num_classes - 1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh=self.thresh)
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            for (x, y) in lane_coords[0]:
                json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))
            if self.cfg.view:
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)

    def evaluate(self, dataset, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        best_acc = 0
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(output_file,
                                                     self.cfg.test_json_file)

        self.logger.info(eval_result)
        self.dump_to_json = []
        best_acc = max(acc, best_acc)
        return best_acc
@@ -0,0 +1,50 @@
import logging

logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)

    logger_initialized[name] = True

    return logger
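A minimal usage sketch of `get_logger` as defined above (the log path is hypothetical). Calling it twice with the same name returns the already-initialized logger rather than adding duplicate handlers:

```python
from runner.logger import get_logger

logger = get_logger('resa', log_file='work_dirs/demo/log.txt')  # hypothetical path
logger.info('hello')                  # goes to the stream handler and the file
assert get_logger('resa') is logger   # second call reuses the initialized logger
```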
@@ -0,0 +1,43 @@
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger


def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.makedirs(model_dir, exist_ok=True)  # portable replacement for `mkdir -p`
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))


def load_network_specified(net, model_dir, logger=None):
    """Load only the weights whose names and shapes match the current net."""
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)


def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)
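A hedged sketch of how the two loaders above differ: `load_network` restores a full checkpoint strictly, while passing `finetune_from` routes through `load_network_specified`, which silently skips weights whose names or shapes do not match. All paths below are hypothetical:

```python
# resume from a full checkpoint (strict load)
load_network(net, 'work_dirs/ckpt/best.pth', logger=logger)

# initialize from another model, keeping only compatible tensors
load_network(net, model_dir=None, finetune_from='pretrained/culane.pth', logger=logger)
```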
@@ -0,0 +1,26 @@
import torch


_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}


def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay

    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
            params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)

    return optimizer
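A minimal config sketch that `build_optimizer` accepts, assuming the attribute-style `Config` defined later in this commit. `type` selects the `_optimizer_factory` entry, and `momentum` is only read on the SGD path (the values here are illustrative, not the repo's training settings):

```python
import torch.nn as nn
from utils import Config

net = nn.Conv2d(3, 8, 3)  # stand-in for the real model
cfg = Config(dict(optimizer=dict(type='sgd', lr=0.025, weight_decay=1e-4, momentum=0.9)))
optimizer = build_optimizer(cfg, net)
```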
@@ -0,0 +1,100 @@
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()
        cfg.work_dir = self.work_dir
        self.log_path = os.path.join(self.work_dir, 'log.txt')

        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)

        # scalars
        self.epoch = 0
        self.step = 0
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter
        self.lr = 0.

    def get_work_dir(self):
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = ' '.join(loss_state)

        recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}',
                                    'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state,
                                      self.data_time.avg, self.batch_time.avg, eta_string)


def build_recorder(cfg):
    return Recorder(cfg)
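A small numeric check of `SmoothedValue` above: with a window of 3, `avg` and `median` reflect only the last three updates, while `global_avg` covers the whole series.

```python
v = SmoothedValue(window_size=3)
for x in [1., 2., 3., 4.]:
    v.update(x)
print(v.avg)         # (2 + 3 + 4) / 3 = 3.0
print(v.global_avg)  # (1 + 2 + 3 + 4) / 4 = 2.5
print(v.median)      # 3.0
```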
@@ -0,0 +1,19 @@
import torch.nn as nn  # needed for nn.Sequential below

from utils import Registry, build_from_cfg

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
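A hedged sketch of the registration/build round trip these helpers enable. The class name is hypothetical; `build_trainer` resolves `cfg.trainer.type` against the `TRAINER` registry via `build_from_cfg` (defined in `utils` later in this commit):

```python
import torch.nn as nn
from utils import Config
from runner.registry import TRAINER, build_trainer

@TRAINER.register_module
class MyTrainer(nn.Module):   # hypothetical trainer
    def __init__(self, cfg):
        super(MyTrainer, self).__init__()
        self.cfg = cfg

cfg = Config(dict(trainer=dict(type='MyTrainer')))
trainer = build_trainer(cfg)  # instantiates MyTrainer(cfg=cfg)
```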
@@ -0,0 +1,58 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

from runner.registry import TRAINER


def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1 - d).mean()


@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight  # down-weight the background class
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()

        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])

        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight

        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            exist_loss = 0.1 * self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        ret = {'loss': loss, 'loss_stats': loss_stats}

        return ret
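A quick numeric sanity check of `dice_loss` above: for a perfect prediction the Dice coefficient d approaches 1 and the loss approaches 0, and the 0.001 terms keep the ratio finite on empty masks.

```python
import torch

pred = torch.tensor([[1., 0., 1., 0.]])
target = torch.tensor([[1., 0., 1., 0.]])
# a = 2, b = c = 2.001, d = 4 / 4.002 ~ 0.9995, so the loss is ~ 0.0005
print(dice_loss(pred, target).item())
```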
@@ -0,0 +1,112 @@
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup

from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network


class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)
        self.net = build_net(self.cfg)
        self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids=range(self.cfg.gpus)).cuda()
        self.recorder.logger.info('Network: \n' + str(self.net))
        self.resume()
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                     finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)

    def to_cuda(self, batch):
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch

    def train_epoch(self, epoch, train_loader):
        self.net.train()
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:
                break
            data_time = time.time() - end
            self.recorder.step += 1
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)
            self.optimizer.zero_grad()
            loss = output['loss']
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)
            self.recorder.data_time.update(data_time)

            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)

        for epoch in range(self.cfg.epochs):
            print('Iter: [{}/{}]'.format(self.recorder.step, self.cfg.total_iter))
            print('Epoch: [{}/{}]'.format(epoch, self.cfg.epochs))
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        for i, data in enumerate(tqdm(val_loader, desc='Validate')):
            start_time = time.time()
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)
            # print("image {} took {} seconds to detect".format(i, time.time() - start_time))

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                   self.recorder, is_best)
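For orientation, a hedged sketch of how `Runner` above is presumably driven from a training entry point. The module path and config file are assumptions based on this commit's layout; the actual entry script is not part of this excerpt:

```python
from utils import Config
from runner.runner import Runner  # module path assumed from this commit's layout

cfg = Config.fromfile('configs/tusimple.py')  # hypothetical config file
runner = Runner(cfg)
runner.train()  # trains for cfg.epochs, validating every cfg.eval_ep epochs
```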
@@ -0,0 +1,20 @@
import torch
import math


_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}


def build_scheduler(cfg, optimizer):

    assert cfg.scheduler.type in _scheduler_factory

    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')

    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)

    return scheduler
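A minimal config sketch for `build_scheduler`: every key besides `type` is forwarded verbatim to the underlying PyTorch scheduler, so a `LambdaLR` entry carries its `lr_lambda` directly. The decay schedule below is illustrative, not the repo's setting:

```python
import torch
from utils import Config

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.025)
cfg = Config(dict(scheduler=dict(
    type='LambdaLR',
    lr_lambda=lambda it: (1 - it / 80000) ** 0.9,  # hypothetical polynomial decay
)))
scheduler = build_scheduler(cfg, optimizer)
```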
@@ -0,0 +1,105 @@
import json
import numpy as np
import cv2
import os
import argparse

TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json']
VAL_SET = ['label_data_0531.json']
TRAIN_VAL_SET = TRAIN_SET + VAL_SET
TEST_SET = ['test_label.json']

def gen_label_for_json(args, image_set):
    H, W = 720, 1280
    SEG_WIDTH = 30
    save_dir = args.savedir

    os.makedirs(os.path.join(args.root, args.savedir, "list"), exist_ok=True)
    list_f = open(os.path.join(args.root, args.savedir, "list", "{}_gt.txt".format(image_set)), "w")

    json_path = os.path.join(args.root, args.savedir, "{}.json".format(image_set))
    with open(json_path) as f:
        for line in f:
            label = json.loads(line)
            # ---------- clean and sort lanes -------------
            lanes = []
            _lanes = []
            slope = []  # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope
            for i in range(len(label['lanes'])):
                l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0]
                if (len(l) > 1):
                    _lanes.append(l)
                    slope.append(np.arctan2(l[-1][1] - l[0][1], l[0][0] - l[-1][0]) / np.pi * 180)
            _lanes = [_lanes[i] for i in np.argsort(slope)]
            slope = [slope[i] for i in np.argsort(slope)]

            idx = [None for i in range(6)]
            for i in range(len(slope)):
                if slope[i] <= 90:
                    idx[2] = i
                    idx[1] = i - 1 if i > 0 else None
                    idx[0] = i - 2 if i > 1 else None
                else:
                    idx[3] = i
                    idx[4] = i + 1 if i + 1 < len(slope) else None
                    idx[5] = i + 2 if i + 2 < len(slope) else None
                    break
            for i in range(6):
                lanes.append([] if idx[i] is None else _lanes[idx[i]])

            # ---------------------------------------------

            img_path = label['raw_file']
            seg_img = np.zeros((H, W, 3))
            list_str = []  # str to be written to list.txt
            for i in range(len(lanes)):
                coords = lanes[i]
                if len(coords) < 4:
                    list_str.append('0')
                    continue
                for j in range(len(coords) - 1):
                    cv2.line(seg_img, coords[j], coords[j + 1], (i + 1, i + 1, i + 1), SEG_WIDTH // 2)
                list_str.append('1')

            seg_path = img_path.split("/")
            seg_path, img_name = os.path.join(args.root, args.savedir, seg_path[1], seg_path[2]), seg_path[3]
            os.makedirs(seg_path, exist_ok=True)
            seg_path = os.path.join(seg_path, img_name[:-3] + "png")
            cv2.imwrite(seg_path, seg_img)

            seg_path = "/".join([args.savedir, *img_path.split("/")[1:3], img_name[:-3] + "png"])
            if seg_path[0] != '/':
                seg_path = '/' + seg_path
            if img_path[0] != '/':
                img_path = '/' + img_path
            list_str.insert(0, seg_path)
            list_str.insert(0, img_path)
            list_str = " ".join(list_str) + "\n"
            list_f.write(list_str)

    list_f.close()


def generate_json_file(save_dir, json_file, image_set):
    with open(os.path.join(save_dir, json_file), "w") as outfile:
        for json_name in (image_set):
            with open(os.path.join(args.root, json_name)) as infile:
                for line in infile:
                    outfile.write(line)


def generate_label(args):
    save_dir = os.path.join(args.root, args.savedir)
    os.makedirs(save_dir, exist_ok=True)
    generate_json_file(save_dir, "train_val.json", TRAIN_VAL_SET)
    generate_json_file(save_dir, "test.json", TEST_SET)

    print("generating train_val set...")
    gen_label_for_json(args, 'train_val')
    print("generating test set...")
    gen_label_for_json(args, 'test')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', required=True, help='The root of the Tusimple dataset')
    parser.add_argument('--savedir', type=str, default='seg_label',
                        help='The directory (under root) to save generated segmentation labels')
    args = parser.parse_args()

    generate_label(args)
@@ -0,0 +1,2 @@
from .config import Config
from .registry import Registry, build_from_cfg
@@ -0,0 +1,417 @@
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode


BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


class ConfigDict(Dict):

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


def add_args(parser, cfg, prefix=''):
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument('--' + prefix + k)
        elif isinstance(v, int):
            parser.add_argument('--' + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument('--' + prefix + k, type=float)
        elif isinstance(v, bool):
            parser.add_argument('--' + prefix + k, action='store_true')
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser


class Config:
    """A facility for config and config files.
    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.
    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}')

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
                # close temp file
                temp_config_file.close()
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        # merge dict `a` into dict `b` (non-inplace). values in `a` will
        # overwrite `b`.
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k])
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        if self.filename.endswith('.py'):
            if file is None:
                return self.pretty_text
            else:
                with open(file, 'w') as f:
                    f.write(self.pretty_text)
        else:
            import mmcv
            if file is None:
                file_format = self.filename.split('.')[-1]
                return mmcv.dump(cfg_dict, file_format=file_format)
            else:
                mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options):
        """Merge list into cfg_dict
        Merge the dict parsed by MultipleKVAction into this cfg.
        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))
        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, i.e. KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return True if val.lower() == 'true' else False
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
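A short sketch of `DictAction` in an argparse-based tool, as the docstring above describes: values are parsed to int, float, or bool where possible, and comma-separated values become lists (the flag name `--cfg-options` is hypothetical):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--cfg-options', nargs='+', action=DictAction)
args = parser.parse_args(['--cfg-options', 'batch_size=8', 'gpus=0,1', 'view=true'])
print(args.cfg_options)  # {'batch_size': 8, 'gpus': [0, 1], 'view': True}
```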
@@ -0,0 +1,81 @@
import inspect

import six

# borrow from mmdetection


def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module_class (:obj:`type`): Class to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = {}
    obj_type = cfg.type
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
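One design note worth making explicit: unlike upstream mmdetection, this variant of `build_from_cfg` does not forward the remaining keys of `cfg` as constructor kwargs; only `default_args` reach the constructor, and `type` is used purely for lookup. That matches how `build_trainer` and `build_evaluator` pass the whole config as `cfg=cfg`. A small sketch with a hypothetical class:

```python
from addict import Dict
from utils import Registry, build_from_cfg

DEMO = Registry('demo')

@DEMO.register_module
class Foo(object):          # hypothetical module
    def __init__(self, cfg=None):
        self.cfg = cfg

obj = build_from_cfg(Dict(type='Foo'), DEMO, default_args=dict(cfg='x'))
print(type(obj).__name__, obj.cfg)  # Foo x
```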
@@ -0,0 +1,357 @@
import random
import cv2
import numpy as np
import numbers
import collections.abc

# copy from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py

__all__ = ['GroupRandomCrop', 'GroupCenterCrop', 'GroupRandomPad', 'GroupCenterPad',
           'GroupRandomScale', 'GroupRandomHorizontalFlip', 'GroupNormalize']


class SampleResize(object):
    def __init__(self, size):
        # collections.Iterable was removed in Python 3.10; use collections.abc
        assert (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size

    def __call__(self, sample):
        out = list()
        out.append(cv2.resize(sample[0], self.size,
                              interpolation=cv2.INTER_CUBIC))
        if len(sample) > 1:
            out.append(cv2.resize(sample[1], self.size,
                                  interpolation=cv2.INTER_NEAREST))
        return out


class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomCropRatio(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        tw, th = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupCenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((h - th) / 2))
        w1 = max(0, int((w - tw) / 2))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, th - h))
        w1 = random.randint(0, max(0, tw - w))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupCenterPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((th - h) / 2))
        w1 = max(0, int((tw - w) / 2))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupConcerPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = 0
        w1 = 0
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleNew(object):
    def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale_w,
                                         fy=scale_h, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale = random.uniform(self.size[0], self.size[1])
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale,
                                         fy=scale, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomMultiScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scales = [0.5, 1.0, 1.5]  # random.uniform(self.size[0], self.size[1])
        out_images = list()
        for scale in scales:
            for img, interpolation in zip(img_group, self.interpolation):
                out_images.append(cv2.resize(
                    img, None, fx=scale, fy=scale, interpolation=interpolation))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleRatio(object):
    def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation
        self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784,
                          1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832, 1122, 1224, 2295]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        w_scale = random.randint(self.size[0], self.size[1])
        h_scale = random.randint(self.size[2], self.size[3])
        h, w, _ = img_group[0].shape
        out_images = list()
        out_images.append(cv2.resize(img_group[0], None, fx=w_scale * 1.0 / w, fy=h_scale * 1.0 / h,
                                     interpolation=self.interpolation[0]))
        ### process label map ###
        origin_label = cv2.resize(
            img_group[1], None, fx=w_scale * 1.0 / w, fy=h_scale * 1.0 / h, interpolation=self.interpolation[1])
        origin_label = origin_label.astype(int)
        label = origin_label[:, :, 0] * 5 + \
            origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
        new_label = np.ones(label.shape) * 100
        new_label = new_label.astype(int)
        for cnt in range(37):
            new_label = (
                label == self.origin_id[cnt]) * (cnt - 100) + new_label
        new_label = (label == self.origin_id[37]) * (36 - 100) + new_label
        assert (100 not in np.unique(new_label))
        out_images.append(new_label)
        return out_images


class GroupRandomRotation(object):
    def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None):
        self.degree = degree
        self.interpolation = interpolation
        self.padding = padding
        if self.padding is None:
            self.padding = [0, 0]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        v = random.random()
        if v < 0.5:
            degree = random.uniform(self.degree[0], self.degree[1])
            h, w = img_group[0].shape[0:2]
            center = (w / 2, h / 2)
            map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
            out_images = list()
            for img, interpolation, padding in zip(img_group, self.interpolation, self.padding):
                out_images.append(cv2.warpAffine(
                    img, map_matrix, (w, h), flags=interpolation, borderMode=cv2.BORDER_CONSTANT, borderValue=padding))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


class GroupRandomBlur(object):
    def __init__(self, applied):
        self.applied = applied

    def __call__(self, img_group):
        assert (len(self.applied) == len(img_group))
        v = random.random()
        if v < 0.5:
            out_images = []
            for img, a in zip(img_group, self.applied):
                if a:
                    img = cv2.GaussianBlur(
                        img, (5, 5), random.uniform(1e-6, 0.6))
                out_images.append(img)
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given numpy Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            out_images = [np.fliplr(img) for img in img_group]
            if self.is_flow:
                for i in range(0, len(out_images), 2):
                    # invert flow pixel values when flipping
                    out_images[i] = -out_images[i]
            return out_images
        else:
            return img_group


class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_group):
        out_images = list()
        for img, m, s in zip(img_group, self.mean, self.std):
            if len(m) == 1:
                img = img - np.array(m)  # single channel image
                img = img / np.array(s)
            else:
                img = img - np.array(m)[np.newaxis, np.newaxis, ...]
                img = img / np.array(s)[np.newaxis, np.newaxis, ...]
            out_images.append(img)

        return out_images
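A hedged sketch of how these group transforms compose: each takes a list of arrays (image and label together) so that geometric changes stay aligned across both. The shapes below are illustrative (CULane-sized frames):

```python
import numpy as np

img = np.zeros((590, 1640, 3), dtype=np.uint8)   # input frame
label = np.zeros((590, 1640), dtype=np.uint8)    # per-pixel lane ids

group = [img, label]
group = GroupRandomScale(size=(0.5, 1.5))(group)   # same random scale for both
group = GroupRandomHorizontalFlip()(group)         # flips both or neither
img, label = group
```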