@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 liuyyy111

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,48 @@
# Introduction
This is the Bidirectional Correct Attention Network, the source code of Attend, Correct and Focus: [Bidirectional Correct Attention Network for Image-Text Matching (ICIP 2021)](https://ieeexplore.ieee.org/abstract/document/9506438) and BCAN++: Cross-Modal Retrieval with Improved Bidirectional Correct Attention Network.

It is built on top of [SCAN](https://github.com/kuanghuei/SCAN) in PyTorch.

# Requirements and Installation
We recommend the following dependencies.

- Python 3.7
- PyTorch 1.6+
- NumPy
- nltk

# Download data
Download the dataset files. We use the image features created by SCAN, available [here](https://github.com/kuanghuei/SCAN). All the data needed for reproducing the experiments in the paper, including image features and vocabularies, can be downloaded with:

```bash
wget https://scanproject.blob.core.windows.net/scan-data/data.zip
wget https://scanproject.blob.core.windows.net/scan-data/vocab.zip
```

# Training
- Train new BCAN models: run `train.py`:

```bash
python train.py --data_path "$DATA_PATH" --data_name "$DATA_NAME" --logger_name "$LOGGER_NAME" --model_name "$MODEL_NAME"
```

- Train new BCAN++ models: run `bcan++_train.py`:

```bash
python bcan++_train.py --data_path "$DATA_PATH" --data_name "$DATA_NAME" --logger_name "$LOGGER_NAME" --model_name "$MODEL_NAME"
```

The arguments used to train Flickr30K and MSCOCO models are similar to those of SCAN:

For Flickr30K:

| Method | Arguments |
|:-:|:-:|
| BCAN-equal | `--num_epochs=20 --lr_update=15 --correct_type=equal` |
| BCAN-prob | `--num_epochs=20 --lr_update=15 --correct_type=prob` |

For MSCOCO:

| Method | Arguments |
|:-:|:-:|
| BCAN-equal | `--num_epochs=15 --lr_update=8 --correct_type=equal` |
| BCAN-prob | `--num_epochs=15 --lr_update=8 --correct_type=prob` |

# Evaluation
```python
from vocab import Vocabulary
import evaluation
evaluation.evalrank("$RUN_PATH/coco_scan/model_best.pth.tar", data_path="$DATA_PATH", split="test")
```
@@ -0,0 +1,160 @@
from flask import Flask, request, jsonify
import os
import time
import numpy as np
import torch
import nltk
# load model and options
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.utils import secure_filename

from data import get_test_loader
from evaluation import AverageMeter, LogCollector, shard_xattn, i2t, t2i
# from extract_features import feature
from extract_features import feature, loadModel, imgFeature
from model import SCAN
from test_one import encode_img_caps, encode_image, encode_cap
from vocab import deserialize_vocab

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

app = Flask(__name__)

ALLOWED_IMG = set(['png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG'])
# Limit uploaded images to at most 10 MB
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
ALLOWED_FILE = set(['zip'])

model_path = "./runs/test/model_best.pth.tar"
data_path = "./data/"
image_path = "./image/ride.jpg"
checkpoint = torch.load(model_path)
opt = checkpoint['opt']
print(opt)

# Candidate captions, one per line
caps_list = []
with open("test_caps.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        caps_list.append(line)
print(len(caps_list))
# Image file names (the part before '#') for each caption line
image_list = []
with open("result.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        image_list.append(line.split("#")[0])
# print(len(image_list))
# print(image_list[:10])
# Image ids of the test split
id_list = []
with open("test_ids.txt", "r") as f:
    for line in f.readlines():
        line = line.strip("\n")
        id_list.append(line)
# print(len(id_list))
# print(id_list[:10])


if data_path is not None:
    opt.data_path = data_path

# load vocabulary used by the model
vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
word2idx = vocab.word2idx
opt.vocab_size = len(vocab)

model = SCAN(word2idx, opt)
model = torch.nn.DataParallel(model)
model.cuda()

# load model state
model.load_state_dict(checkpoint['model'])

print('Loading dataset')
loadTime = time.time()
data_loader = get_test_loader("test", opt.data_name, vocab,
                              opt.batch_size, 0, opt)
print('loadTime:{:.4f}'.format(time.time() - loadTime))
print('Computing results...')
# img_embs, img_means, cap_embs, cap_lens, cap_means, img_id = encode_data(model, data_loader)
img_embs, img_means, cap_embs, cap_lens, cap_means = encode_img_caps(model, data_loader)
print(img_embs.shape, cap_embs.shape)
# keep one image embedding per group of 5 captions
img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
featureModel, cfg = loadModel()


# Check the uploaded file extension against a set of allowed formats
def check_file_format(file_name, allowed_formats):
    if '.' in file_name:
        file_format = file_name.rsplit('.', 1)[-1]
        if file_format in allowed_formats:
            return True
    return False


# Check the image size; raise an exception if it exceeds 10 MB
def check_img_size(img_path):
    fsize = os.path.getsize(img_path)
    if fsize > ALLOWED_IMG_SIZE:
        raise RequestEntityTooLarge


def create_response(status, textList=None):
    # res is the top-level JSON structure returned to the client
    res = {}
    res['status'] = status
    # res['message'] = message
    if textList is not None:
        res['textList'] = textList
    return jsonify(res)


@app.route('/')
def hello():
    return 'Hello, World!'


@app.route('/text2image', methods=['POST'])
def text2image():
    query = request.form.get("string")
    cap_emb, cap_len, cap_mean = encode_cap(model, query, vocab)
    print(cap_emb.shape, len(cap_len), cap_mean.shape)
    sims = shard_xattn(model, img_embs, img_means, cap_emb, cap_len, cap_mean, opt, shard_size=1024)
    sims = sims.T
    print(sims.shape)

    # indices of the 10 best-matching images
    inds = np.argsort(sims[0])[::-1]
    print(inds[:10])
    imgList = []
    for i in inds[:10]:
        print(image_list[5 * int(id_list[5 * i])])
        # print(image_list[5 * int(id_list[i])])
        imgList.append(image_list[5 * int(id_list[5 * i])])
    return create_response(1, imgList)


@app.route('/image2text', methods=['POST'])
def image2text():
    try:
        f = request.files['file_name']
        if f and check_file_format(f.filename, ALLOWED_IMG):
            img_path = './image/' + secure_filename(f.filename)
            f.save(img_path)
            check_img_size(img_path)
            tic = time.time()
            image_feat = imgFeature(featureModel, cfg, img_path)
            print(image_feat.shape)
            img_emb, img_mean = encode_image(model, image_feat)
            print(img_emb.shape)
            sims = shard_xattn(model, img_emb, img_mean, cap_embs, cap_lens, cap_means, opt, shard_size=10240)
            print(sims.shape)
            # indices of the 10 best-matching captions
            inds = np.argsort(sims[0])[::-1]
            # inds = inds.astype("int32")
            print(inds[:10])
            print(inds.dtype)
            textList = []
            for i in inds[:10]:
                textList.append(caps_list[i])
                print(caps_list[i])
            print('time: {:.4f}'.format(time.time() - tic))
            return create_response(1, textList)
        else:
            return create_response('png jpg jpeg bmp are allowed')
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5001)
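For reference, a minimal client for the two endpoints above might look like the following sketch. It assumes the Flask app is running locally on the port set in `app.run` (5001) and uses the third-party `requests` library; the query string and image path are placeholders.

```python
# Hypothetical client sketch for the /text2image and /image2text endpoints above.
# Assumes the server is reachable at localhost:5001 and `pip install requests`.
import requests

BASE = "http://localhost:5001"

# text -> image retrieval: send a caption in the "string" form field
r = requests.post(f"{BASE}/text2image", data={"string": "a man riding a horse"})
print(r.json())  # {"status": 1, "textList": [<top-10 image file names>]}

# image -> text retrieval: upload an image file under the "file_name" field
with open("image/ride.jpg", "rb") as img:
    r = requests.post(f"{BASE}/image2text", files={"file_name": img})
print(r.json())  # {"status": 1, "textList": [<top-10 captions>]}
```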
@@ -0,0 +1,50 @@
MODEL:
  WEIGHTS: "bua-caffe-frcn-r101_with_attributes_fix36.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: True
    EXTRACT_FEATS: True
    RPN:
      CONV_OUT_CHANNELS: 512
    EXTRACTOR:
      MIN_BOXES: 36
      MAX_BOXES: 36
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 101
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 2
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,50 @@
MODEL:
  WEIGHTS: "bua-caffe-frcn-r101_with_attributes.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: True
    EXTRACT_FEATS: True
    RPN:
      CONV_OUT_CHANNELS: 512
    EXTRACTOR:
      MIN_BOXES: 10
      MAX_BOXES: 100
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 101
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 2
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,60 @@
OUTPUT_DIR: "./output_caffe152"
MODEL:
  WEIGHTS: "/home/zdz/bottom-up-attention/configs/bua-caffe/bua-caffe-frcn-r152_with_attributes.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [0, 0, 0]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: True
    EXTRACT_FEATS: True
    RESNET_VERSION: 2
    RPN:
      CONV_OUT_CHANNELS: 512
    EXTRACTOR:
      MIN_BOXES: 36
      MAX_BOXES: 36
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 152
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 1
    STRIDE_IN_1X1: False
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
DATASETS:
  TRAIN: ("visual_genome_train",)
  TEST: ("visual_genome_val",)
TEST:
  DETECTIONS_PER_IMAGE: 400
DATALOADER:
  NUM_WORKERS: 0
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,53 @@
MODEL:
  WEIGHTS: "bua-caffe-frcn-r101_with_attributes_fix36.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: False
    EXTRACT_FEATS: False
    RPN:
      CONV_OUT_CHANNELS: 512
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 101
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 2
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
DATASETS:
  TRAIN: ("visual_genome_train",)
  TEST: ("visual_genome_val",)
TEST:
  DETECTIONS_PER_IMAGE: 400
DATALOADER:
  NUM_WORKERS: 0
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,54 @@
MODEL:
  WEIGHTS: "bua-caffe-frcn-r101_with_attributes.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: False
    EXTRACT_FEATS: False
    RPN:
      CONV_OUT_CHANNELS: 512
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 101
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 2
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
DATASETS:
  TRAIN: ("visual_genome_train",)
  TEST: ("visual_genome_val",)
TEST:
  DETECTIONS_PER_IMAGE: 400
DATALOADER:
  NUM_WORKERS: 0
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,60 @@
OUTPUT_DIR: "./output_caffe152"
MODEL:
  WEIGHTS: "bua-caffe-frcn-r152_with_attributes.pth"
  META_ARCHITECTURE: "GeneralizedBUARCNN"
  PIXEL_MEAN: [0, 0, 0]
  ANCHOR_GENERATOR:
    SIZES: [[4, 8, 16, 32]]
  PROPOSAL_GENERATOR:
    NAME: "BUARPN"
    MIN_SIZE: 16
  BUA:
    ATTRIBUTE_ON: False
    EXTRACT_FEATS: False
    RESNET_VERSION: 2
    RPN:
      CONV_OUT_CHANNELS: 512
    EXTRACTOR:
      MIN_BOXES: 100
      MAX_BOXES: 100
    ATTRIBUTE:
      NUM_CLASSES: 401
  RESNETS:
    DEPTH: 152
    OUT_FEATURES: ["res4"]
    NORM: "BN"
    RES5_DILATION: 1
    STRIDE_IN_1X1: False
  BACKBONE:
    NAME: "build_bua_resnet_backbone"
    FREEZE_AT: 3
  RPN:
    HEAD_NAME: "StandardBUARPNHead"
    PRE_NMS_TOPK_TRAIN: 12000
    POST_NMS_TOPK_TRAIN: 2000
    POST_NMS_TOPK_TEST: 300
    PRE_NMS_TOPK_TEST: 6000
    BATCH_SIZE_PER_IMAGE: 64
  ROI_HEADS:
    NAME: "BUACaffeRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 64
    SCORE_THRESH_TEST: -1.0
    NMS_THRESH_TEST: 0.3
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1601
  ROI_BOX_HEAD:
    POOLER_TYPE: "ROIPool"
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
DATASETS:
  TRAIN: ("visual_genome_train",)
  TEST: ("visual_genome_val",)
TEST:
  DETECTIONS_PER_IMAGE: 400
DATALOADER:
  NUM_WORKERS: 0
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
@@ -0,0 +1,233 @@
from queue import Queue
from threading import Thread

import h5py
import nltk
import torch
import torch.utils.data as data
import os
import numpy as np
import json
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
from transformers import BertTokenizer


class DataLoaderX(DataLoader):

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


class PrecompShuffleDataset(data.Dataset):
    """
    Load precomputed captions and image features
    Possible options: f30k_precomp, coco_precomp
    """
    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        self.vocab = vocab
        loc = data_path + '/'
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Captions
        self.captions = []
        with open(loc + '%s_caps.txt' % data_split, 'rb') as f:
            for line in f:
                self.captions.append(line.strip().decode('utf-8'))

        self.data_split = data_split
        # if self.data_split == 'test':
        #     self.bbox = np.load(loc + '%s_ims_bbx.npy' % data_split)
        #     self.sizes = np.load(loc + '%s_ims_size.npy' % data_split, allow_pickle=True)

        # self.tags = []
        # with open(loc + '%s_tags_new.txt' % data_split, 'rb') as f:
        #     for line in f:
        #         self.tags.append(line.strip().decode('utf-8'))

        # Image features
        print('loading npy')
        self.images = np.load(loc + '%s_ims.npy' % data_split, mmap_mode='r')
        # print(len(self.images), len(self.captions))
        # self.images = np.load(loc + '%s_ims.npy' % data_split)
        print('done load npy')
        self.length = len(self.captions)
        # self.length = 10000
        # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
        if self.images.shape[0] != self.length:
            self.im_div = 5
        else:
            self.im_div = 1
        # the development set for coco is large and so validation would be slow
        if data_split == 'shuffle_dev':
            self.length = 20000
            self.im_div = 4
        if data_split == 'shuffle_train':
            # self.length = 20000
            self.im_div = 20

    def __getitem__(self, index):
        # handle the image redundancy
        img_id = int(index / self.im_div)
        image = torch.Tensor(self.images[img_id])
        caption = self.captions[index]
        vocab = self.vocab

        caption = caption.replace('.', '.[SEP]')[:-6]
        caption = self.tokenizer.encode(caption)
        target = torch.Tensor(caption)

        # Convert caption (string) to word ids.
        # tokens = nltk.tokenize.word_tokenize(
        #     caption.encode('utf-8').decode('utf-8'))
        # caption = []
        # caption.append(vocab('<start>'))
        # caption.extend([vocab(str(token).lower()) for token in tokens])
        # caption.append(vocab('<end>'))
        # # assert(len(caption) - 2 == len(new_tags))
        # target = torch.Tensor(caption)
        # # new_tags = torch.Tensor(new_tags)

        return image, target, index, img_id

    def __len__(self):
        return self.length


class PrecompDataset(data.Dataset):
    """
    Load precomputed captions and image features
    Possible options: f30k_precomp, coco_precomp
    """
    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        self.vocab = vocab
        loc = data_path + '/'
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Captions
        self.captions = []
        with open(loc + '%s_caps.txt' % data_split, 'rb') as f:
            for line in f:
                self.captions.append(line.strip().decode('utf-8'))

        self.data_split = data_split
        # if self.data_split == 'test':
        #     self.bbox = np.load(loc + '%s_ims_bbx.npy' % data_split)
        #     self.sizes = np.load(loc + '%s_ims_size.npy' % data_split, allow_pickle=True)

        # self.tags = []
        # with open(loc + '%s_tags_new.txt' % data_split, 'rb') as f:
        #     for line in f:
        #         self.tags.append(line.strip().decode('utf-8'))

        # Image features
        print('loading npy')
        self.images = np.load(loc + '%s_ims.npy' % data_split, mmap_mode='r')
        # self.images = np.load(loc + '%s_ims.npy' % data_split)
        print('done load npy')
        self.length = len(self.captions)
        # self.length = 10000
        # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
        if self.images.shape[0] != self.length:
            self.im_div = 5
        else:
            self.im_div = 1
        # the development set for coco is large and so validation would be slow
        if data_split == 'dev':
            self.length = 5000

    def __getitem__(self, index):
        # handle the image redundancy
        img_id = int(index / self.im_div)
        image = torch.Tensor(self.images[img_id])
        caption = self.captions[index]
        vocab = self.vocab

        caption = self.tokenizer.encode(caption)
        target = torch.Tensor(caption)

        # Convert caption (string) to word ids.
        # tokens = nltk.tokenize.word_tokenize(
        #     caption.encode('utf-8').decode('utf-8'))
        # caption = []
        # caption.append(vocab('<start>'))
        # caption.extend([vocab(str(token).lower()) for token in tokens])
        # caption.append(vocab('<end>'))
        # # assert(len(caption) - 2 == len(new_tags))
        # target = torch.Tensor(caption)
        # # new_tags = torch.Tensor(new_tags)

        return image, target, index, img_id

    def __len__(self):
        return self.length


def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption) tuples.
    Args:
        data: list of (image, caption) tuple.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        img_lengths: torch tensor; number of regions for each image.
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length
    data.sort(key=lambda x: len(x[1]), reverse=True)

    images, captions, ids, img_ids = zip(*data)

    # Merge images (convert tuple of 3D tensor to 4D tensor)
    images = torch.stack(images, 0)

    # Merge captions (convert tuple of 1D tensor to 2D tensor)
    lengths = torch.LongTensor([len(cap) for cap in captions])
    # print(lengths.dtype, lengths.shape)
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    img_lengths = [len(image) for image in images]
    img_lengths = torch.Tensor(img_lengths)
    return images, img_lengths, targets, lengths, ids


def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
                       shuffle=True, num_workers=0):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""

    dset = PrecompDataset(data_path, data_split, vocab)
    # dset = PrecompShuffleDataset(data_path, data_split, vocab)
    # train_sampler = torch.utils.data.distributed.DistributedSampler(dset)
    # if data_split == 'train':
    #     data_loader = DataLoader(dataset=dset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn, sampler=train_sampler)
    # else:
    print(num_workers)
    data_loader = DataLoaderX(dataset=dset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    return data_loader


def get_loaders(data_name, vocab, batch_size, workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
                                      batch_size, True, workers)
    val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
                                    batch_size, False, workers)
    # train_loader = get_precomp_loader(dpath, 'shuffle_train', vocab, opt,
    #                                   batch_size, True, workers)
    # val_loader = get_precomp_loader(dpath, 'shuffle_dev', vocab, opt,
    #                                 batch_size, False, workers)
    return train_loader, val_loader


def get_test_loader(split_name, data_name, vocab, batch_size,
                    workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
                                     batch_size, False, workers)
    return test_loader
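A minimal sketch of how these loaders might be constructed. The vocabulary path, the data layout under `./data/`, and the `opt` attributes shown (`data_path`, `vocab_path`, `data_name`, `batch_size`) are assumptions based on how they are used elsewhere in this diff (e.g. in the Flask app above); they are not a fixed interface.

```python
# Hypothetical usage sketch, assuming the SCAN-style precomp layout
# (data/f30k_precomp/train_caps.txt, train_ims.npy, ...) and an `opt`
# namespace carrying the fields read by get_precomp_loader.
from argparse import Namespace

from data import get_loaders, get_test_loader
from vocab import deserialize_vocab

opt = Namespace(data_path="./data/", vocab_path="./vocab/",
                data_name="f30k_precomp", batch_size=128)
vocab = deserialize_vocab("./vocab/f30k_precomp_vocab.json")
opt.vocab_size = len(vocab)

train_loader, val_loader = get_loaders(opt.data_name, vocab, opt.batch_size, 2, opt)
test_loader = get_test_loader("test", opt.data_name, vocab, opt.batch_size, 2, opt)

# Each batch follows collate_fn: images, image lengths, padded captions, caption lengths, ids.
images, img_lengths, targets, lengths, ids = next(iter(val_loader))
print(images.shape, targets.shape, lengths[:5])
```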
@@ -0,0 +1,151 @@
from queue import Queue
from threading import Thread

import h5py
import nltk
import torch
import torch.utils.data as data
import os
import numpy as np
import json
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
from transformers import BertTokenizer


class DataLoaderX(DataLoader):

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


class PrecompDataset(data.Dataset):
    """
    Load precomputed captions and image features
    Possible options: f30k_precomp, coco_precomp
    """
    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        self.vocab = vocab
        loc = data_path + '/'
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Captions
        self.captions = []
        with open(loc + '%s_caps.txt' % data_split, 'rb') as f:
            for line in f:
                self.captions.append(line.strip().decode('utf-8'))

        self.data_split = data_split
        # if self.data_split == 'test':
        #     self.bbox = np.load(loc + '%s_ims_bbx.npy' % data_split)
        #     self.sizes = np.load(loc + '%s_ims_size.npy' % data_split, allow_pickle=True)

        # self.tags = []
        # with open(loc + '%s_tags_new.txt' % data_split, 'rb') as f:
        #     for line in f:
        #         self.tags.append(line.strip().decode('utf-8'))

        # Image features
        print('loading npy')
        self.images = np.load(loc + '%s_ims.npy' % data_split, mmap_mode='r')
        # self.images = np.load(loc + '%s_ims.npy' % data_split)
        print('done load npy')
        self.length = len(self.captions)
        # self.length = 10000
        # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
        if self.images.shape[0] != self.length:
            self.im_div = 5
        else:
            self.im_div = 1
        # the development set for coco is large and so validation would be slow
        if data_split == 'dev':
            self.length = 5000

    def __getitem__(self, index):
        # handle the image redundancy
        img_id = int(index / self.im_div)
        image = torch.Tensor(self.images[img_id])
        caption = self.captions[index]
        vocab = self.vocab

        # caption = self.tokenizer.encode(caption)
        # target = torch.Tensor(caption)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(
            caption.encode('utf-8').decode('utf-8'))
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(str(token).lower()) for token in tokens])
        caption.append(vocab('<end>'))
        # assert(len(caption) - 2 == len(new_tags))
        target = torch.Tensor(caption)
        # new_tags = torch.Tensor(new_tags)

        return image, target, index, img_id

    def __len__(self):
        return self.length


def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption) tuples.
    Args:
        data: list of (image, caption) tuple.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length
    data.sort(key=lambda x: len(x[1]), reverse=True)

    images, captions, ids, img_ids = zip(*data)

    # Merge images (convert tuple of 3D tensor to 4D tensor)
    images = torch.stack(images, 0)

    # Merge captions (convert tuple of 1D tensor to 2D tensor)
    lengths = torch.LongTensor([len(cap) for cap in captions])
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    return images, targets, lengths, ids


def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
                       shuffle=True, num_workers=0):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""

    dset = PrecompDataset(data_path, data_split, vocab)
    # train_sampler = torch.utils.data.distributed.DistributedSampler(dset)
    # if data_split == 'train':
    #     data_loader = DataLoader(dataset=dset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn, sampler=train_sampler)
    # else:
    print(num_workers)
    data_loader = DataLoaderX(dataset=dset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    return data_loader


def get_loaders(data_name, vocab, batch_size, workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
                                      batch_size, True, workers)
    val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
                                    batch_size, False, workers)
    return train_loader, val_loader


def get_test_loader2(split_name, data_name, vocab, batch_size,
                     workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
                                     batch_size, False, workers)
    return test_loader
@@ -0,0 +1,4 @@
from . import dataset_vg
from .dataset_mapper import DatasetMapper

__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
@@ -0,0 +1,164 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
import cv2

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

from .transform_gen import ResizeShortestEdge
from .detection_utils import annotations_to_instances

"""
This file contains the default mapping that's applied to "dataset dicts".
"""

__all__ = ["DatasetMapper"]


def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.
    Now it includes resizing and flipping.

    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST

    logger = logging.getLogger(__name__)
    tfm_gens = []
    tfm_gens.append(ResizeShortestEdge(min_size, max_size, cfg.MODEL.PIXEL_MEAN))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens


class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic.

    The callable currently does the following:
    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format = cfg.INPUT.FORMAT
        self.mask_on = cfg.MODEL.MASK_ON
        self.mask_format = cfg.INPUT.MASK_FORMAT
        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.load_proposals:
            self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        # image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        image = cv2.imread(dataset_dict["file_name"])
        h, w = image.shape[:2]
        # utils.check_image_size(dataset_dict, image)

        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            # Crop around an instance if there are instances in the image.
            # USER: Remove if you don't use cropping
            if self.crop_gen:
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        dataset_dict["im_scale"] = float(image_shape[0]) / float(h)
        # Can use uint8 if it turns out to be slow some day

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        return dataset_dict
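As a rough, hypothetical sketch of how this mapper might be driven (not taken from the repository): the import path, the config file, and the registered split name are assumptions based on the other files in this diff.

```python
# Hypothetical sketch; package import paths and the config file are placeholders.
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, build_detection_test_loader
# from <this_package> import DatasetMapper  # exported by the package __init__ above

cfg = get_cfg()
# cfg.merge_from_file("<one of the BUA YAML configs in this diff>")  # placeholder path

mapper = DatasetMapper(cfg, is_train=False)            # test time: resize only, no crop
record = DatasetCatalog.get("visual_genome_val")[0]    # a dict produced by load_vg_json
sample = mapper(record)
print(sample["image"].shape, sample["im_scale"])       # CHW float32 tensor and its resize scale

# detectron2 can also drive the mapper over the whole split:
loader = build_detection_test_loader(cfg, "visual_genome_val", mapper=mapper)
```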
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from .load_vg_json import load_vg_json

SPLITS = {
    "visual_genome_train": ("vg/images", "vg/annotations/train.json"),
    "visual_genome_val": ("vg/images", "vg/annotations/val.json"),
}

for key, (image_root, json_file) in SPLITS.items():
    # Assume pre-defined datasets live in `./datasets`.
    json_file = os.path.join("datasets", json_file)
    image_root = os.path.join("datasets", image_root)

    DatasetCatalog.register(
        key,
        lambda key=key, json_file=json_file, image_root=image_root: load_vg_json(
            json_file, image_root, key
        ),
    )

    MetadataCatalog.get(key).set(
        json_file=json_file, image_root=image_root
    )
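Note the `key=key, json_file=json_file, image_root=image_root` defaults in the lambda: they bind the loop variables at registration time, so each split keeps its own paths instead of all closures seeing the last iteration. A small, hypothetical inspection of the result (it requires the `datasets/vg/...` files to actually exist on disk):

```python
# Hypothetical inspection of the registered splits.
from detectron2.data import DatasetCatalog, MetadataCatalog

meta = MetadataCatalog.get("visual_genome_train")
print(meta.json_file, meta.image_root)               # paths set in the loop above

records = DatasetCatalog.get("visual_genome_train")  # lazily calls load_vg_json
print(len(records))
print(sorted(records[0].keys()))                     # file_name, height, width, image_id, annotations
print(len(meta.thing_classes))                       # category names filled in by load_vg_json
```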
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import torch

from detectron2.structures import (
    Boxes,
    BoxMode,
    Instances,
)


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to box, segmentation and keypoints annotations of a single instance.

    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for segmentation polygons & keypoints.
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.

    Returns:
        dict:
            the same input dict with fields "bbox", "segmentation", "keypoints"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    # Note that bbox is 1d (per-instance bounding box)
    annotation["bbox"] = transforms.apply_box([bbox])[0]
    annotation["bbox_mode"] = BoxMode.XYXY_ABS

    if "attributes" in annotation:
        annotation["attributes"] = annotation["attributes"]

    return annotation


def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = Boxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    # attributes = [obj["attributes"] for obj in annos]
    attributes = []
    for obj in annos:
        if "attributes" in obj.keys():
            attributes.append(obj["attributes"])
        else:
            attributes.append([-1] * 16)
    attributes = torch.tensor(attributes, dtype=torch.int64)
    target.gt_attributes = attributes

    return target
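A toy illustration of `annotations_to_instances` on two hand-made annotations; the box coordinates, category ids, and attribute values below are fabricated purely for the example.

```python
# Illustrative only: two fabricated annotations in XYWH_ABS format for a 480x640 image,
# assuming annotations_to_instances is importable from this module.
import torch
from detectron2.structures import BoxMode

annos = [
    {"bbox": [10, 20, 100, 50], "bbox_mode": BoxMode.XYWH_ABS, "category_id": 3,
     "attributes": [7] + [-1] * 15},
    {"bbox": [200, 100, 80, 80], "bbox_mode": BoxMode.XYWH_ABS, "category_id": 12},
]
inst = annotations_to_instances(annos, (480, 640))
print(inst.gt_boxes.tensor)      # boxes converted to XYXY_ABS and clipped to the image
print(inst.gt_classes)           # tensor([ 3, 12])
print(inst.gt_attributes.shape)  # torch.Size([2, 16]); missing attributes padded with -1
```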
@@ -0,0 +1,198 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import logging
import contextlib
import os

from fvcore.common.timer import Timer
from detectron2.structures import BoxMode
from fvcore.common.file_io import PathManager


from detectron2.data import MetadataCatalog

"""
This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
"""


logger = logging.getLogger(__name__)

__all__ = ["load_vg_json"]


def load_vg_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a json file with COCO's instances annotation format.
    Currently supports instance detection, instance segmentation,
    and person keypoints annotations.

    Args:
        json_file (str): full path to the json file in COCO instances annotation format.
        image_root (str): the directory where the images in this json file exist.
        dataset_name (str): the name of the dataset (e.g., coco_2017_train).
            If provided, this function will also put "thing_classes" into
            the metadata associated with this dataset.
        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
            loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
            "category_id", "segmentation"). The values for these keys will be returned as-is.
            For example, the densepose annotations are loaded in this way.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # The categories in a custom json file may not be sorted.
        thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
        meta.thing_classes = thing_classes

        # In COCO, certain category ids are artificially removed,
        # and by convention they are always ignored.
        # We deal with COCO's id issue and translate
        # the category ids to contiguous ids in [0, 80).

        # It works by looking at the "categories" field in the json, therefore
        # if users' own json also have incontiguous ids, we'll
        # apply this mapping as well but print a warning.
        if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
            if "coco" not in dataset_name:
                logger.warning(
                    """
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
                )
        id_map = {v: i for i, v in enumerate(cat_ids)}
        meta.thing_dataset_id_to_contiguous_id = id_map

    # sort indices for reproducible results
    img_ids = sorted(list(coco_api.imgs.keys()))
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'iscrowd': 0,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]

    if "minival" not in json_file:
        # The popular valminusminival & minival annotations for COCO2014 contain this bug.
        # However the ratio of buggy annotations there is tiny and does not affect accuracy.
        # Therefore we explicitly white-list them.
        ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
        assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
            json_file
        )

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))

    dataset_dicts = []

    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])

    num_instances_without_valid_segmentation = 0
    max_attributes_per_ins = 16

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.

            # The original COCO valminusminival2014 & minival2014 annotation files
            # actually contain bugs that, together with certain ways of using COCO API,
            # can trigger this assertion.
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0

            obj = {key: anno[key] for key in ann_keys if key in anno}

            attr = anno.get("attribute", None)
            if attr:
                attributes = [-1 for _ in range(max_attributes_per_ins)]
                for idx, a in enumerate(attr):
                    attributes[idx] = a - 1
                obj["attributes"] = attributes

            segm = anno.get("segmentation", None)
            if segm:  # either list[list[float]] or dict(RLE)
                if not isinstance(segm, dict):
                    # filter out invalid polygons (< 3 points)
                    segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                    if len(segm) == 0:
                        num_instances_without_valid_segmentation += 1
                        continue  # ignore this instance
                obj["segmentation"] = segm

            keypts = anno.get("keypoints", None)
            if keypts:  # list[int]
                for idx, v in enumerate(keypts):
                    if idx % 3 != 2:
                        # COCO's segmentation coordinates are floating points in [0, H or W],
                        # but keypoint coordinates are integers in [0, H-1 or W-1]
                        # Therefore we assume the coordinates are "pixel indices" and
                        # add 0.5 to convert to floating point coordinates.
                        keypts[idx] = v + 0.5
                obj["keypoints"] = keypts

            obj["bbox_mode"] = BoxMode.XYWH_ABS
            if id_map:
                obj["category_id"] = id_map[obj["category_id"]]
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    if num_instances_without_valid_segmentation > 0:
        logger.warn(
            "Filtered out {} instances without valid segmentation. "
            "There might be issues in your dataset generation process.".format(
                num_instances_without_valid_segmentation
            )
        )
    return dataset_dicts
@@ -0,0 +1,80 @@
import cv2
import PIL.Image as Image
import numpy as np
from fvcore.transforms.transform import Transform
from detectron2.data.transforms import TransformGen


class ResizeTransform(Transform):
    """
    Resize the image to a target size.
    """

    def __init__(self, h, w, im_scale, pixel_mean):
        """
        Args:
            h, w (int): original image size
            im_scale: im_scale of new_h/h or new_w/w
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img):
        assert img.shape[:2] == (self.h, self.w)
        img_norm = img.astype(np.float32, copy=True) - np.asarray(self.pixel_mean)
        im = cv2.resize(
            img_norm,
            None,
            None,
            fx=self.im_scale,
            fy=self.im_scale,
            interpolation=cv2.INTER_LINEAR
        )
        ret = np.asarray(im)
        return ret

    def apply_coords(self, coords):
        coords[:, 0] = coords[:, 0] * (self.im_scale)
        coords[:, 1] = coords[:, 1] * (self.im_scale)
        return coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation


class ResizeShortestEdge(TransformGen):
    """
    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    def __init__(self, min_size, max_size, pixel_mean):
        """
        Args:
            min_size (int): minimum allowed smallest edge length.
            max_size (int): maximum allowed longest edge length.
        """
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.pixel_mean = pixel_mean

        self._init(locals())

    def get_transform(self, img):
        h, w = img.shape[:2]

        im_shape = img.shape
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])

        im_scale = float(self.min_size if not type(self.min_size) is tuple else self.min_size[0]) / float(im_size_min)

        # Prevent the biggest axis from being more than max_size
        if np.round(im_scale * im_size_max) > self.max_size:
            im_scale = float(self.max_size) / float(im_size_max)

        return ResizeTransform(h, w, im_scale, self.pixel_mean)
(8 binary image files added, 72-322 KiB each)
@@ -0,0 +1,20 @@
from evaluation import evalrank, evalrank2, evalrank3, evalrank_vse, evalrank_maxpool, evalrank_f_c, evalrank_fanhua
import numpy as np
from transformers import BertTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# evalrank("./runs/bert_adam_bcan_gpo_vseinfty_bcan/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank2("../bcan_gpo/runs/bert_adam_bcan_gpo/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank2("../bcan_gpo/runs/bert_adam_bcan_gpo2/model_best.pth.tar", "../SCAN-master/data/", "test")

# evalrank2("../bcan_gpo/runs/bert_adam_bcan_gpo/model_best.pth.tar", "../SCAN-master/data/", "test")

# evalrank("./runs/bert_adam_bcan_gpo_vseinfty_bcan_coco/model_best.pth.tar", "../SCAN-master/data/", "test")
evalrank_fanhua("./runs/bert_adam_bcan_gpo_vseinfty_bcan_coco/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank3("../pycharmProject/runs/bigru_bcan_adam_mean_base2/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank3("./runs/bigru_adam_bcan_mean2/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank("./runs/bert_adam_bcan_gpo_shuffle_test/model_best.pth.tar", "../SCAN-master/data/", "shuffle_dev")
# evalrank("./runs/bert/model_best.pth.tar", "./data/", "test")
@ -0,0 +1,336 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# pylint: disable=no-member
|
||||
"""
|
||||
TridentNet Training Script.
|
||||
adf
|
||||
This script is a simplified version of the training script in detectron2/tools.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
# import tqdm
|
||||
import cv2
|
||||
import numpy as np
|
||||
sys.path.append('detectron2')
|
||||
|
||||
import detectron2.utils.comm as comm
|
||||
from detectron2.checkpoint import DetectionCheckpointer
|
||||
from detectron2.data import build_detection_test_loader, build_detection_train_loader
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.engine import DefaultTrainer, default_setup, launch
|
||||
from detectron2.evaluation import COCOEvaluator, verify_results
|
||||
from detectron2.structures import Instances
|
||||
from detectron2.modeling import build_model
|
||||
|
||||
from utils.utils import mkdir, save_features
|
||||
from utils.extract_utils import get_image_blob, save_bbox, save_roi_features_by_bbox, save_roi_features, image_features
|
||||
from utils.progress_bar import ProgressBar
|
||||
from models import add_config
|
||||
from models.bua.box_regression import BUABoxes
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
|
||||
def switch_extract_mode(mode):
|
||||
if mode == 'roi_feats':
|
||||
switch_cmd = ['MODEL.BUA.EXTRACTOR.MODE', 1]
|
||||
elif mode == 'bboxes':
|
||||
switch_cmd = ['MODEL.BUA.EXTRACTOR.MODE', 2]
|
||||
elif mode == 'bbox_feats':
|
||||
switch_cmd = ['MODEL.BUA.EXTRACTOR.MODE', 3, 'MODEL.PROPOSAL_GENERATOR.NAME', 'PrecomputedProposals']
|
||||
else:
|
||||
print('Wrong extract mode! ')
|
||||
exit()
|
||||
return switch_cmd
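# Illustrative mapping implemented above, e.g.:
#   switch_extract_mode('roi_feats')  -> ['MODEL.BUA.EXTRACTOR.MODE', 1]
#   switch_extract_mode('bbox_feats') -> ['MODEL.BUA.EXTRACTOR.MODE', 3,
#                                         'MODEL.PROPOSAL_GENERATOR.NAME', 'PrecomputedProposals']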
|
||||
|
||||
def set_min_max_boxes(min_max_boxes):
|
||||
if min_max_boxes == 'min_max_default':
|
||||
return []
|
||||
try:
|
||||
min_boxes = int(min_max_boxes.split(',')[0])
|
||||
max_boxes = int(min_max_boxes.split(',')[1])
|
||||
except:
|
||||
print('Illegal min-max boxes setting, using config default. ')
|
||||
return []
|
||||
cmd = ['MODEL.BUA.EXTRACTOR.MIN_BOXES', min_boxes,
|
||||
'MODEL.BUA.EXTRACTOR.MAX_BOXES', max_boxes]
|
||||
return cmd
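# Illustrative call (the box counts are example values, not project defaults):
#   set_min_max_boxes('36,36') -> ['MODEL.BUA.EXTRACTOR.MIN_BOXES', 36,
#                                  'MODEL.BUA.EXTRACTOR.MAX_BOXES', 36]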
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
Create configs and perform basic setups.
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_config(args, cfg)
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.merge_from_list(switch_extract_mode(args.extract_mode))
|
||||
cfg.merge_from_list(set_min_max_boxes(args.min_max_boxes))
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
||||
def generate_npz(extract_mode, *args):
|
||||
if extract_mode == 1:
|
||||
save_roi_features(*args)
|
||||
elif extract_mode == 2:
|
||||
save_bbox(*args)
|
||||
elif extract_mode == 3:
|
||||
save_roi_features_by_bbox(*args)
|
||||
else:
|
||||
print('Invalid Extract Mode! ')
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def extract_feat(split_idx, img_list, cfg, args, actor: ActorHandle):
|
||||
num_images = len(img_list)
|
||||
print('Number of images on split{}: {}.'.format(split_idx, num_images))
|
||||
#print(cfg)
|
||||
model = DefaultTrainer.build_model(cfg)
|
||||
print("111111111111111111111111111111111111111111111")
|
||||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
print("222222222222222222222222222222222222222222222222")
|
||||
model.eval()
|
||||
|
||||
for im_file in (img_list):
|
||||
print("33333333333333333333333333333333333333333333")
|
||||
# if os.path.exists(os.path.join(args.output_dir, im_file.split('.')[0]+'.npz')):
|
||||
# actor.update.remote(1)
|
||||
# continue
|
||||
im = cv2.imread(im_file)
|
||||
|
||||
if im is None:
|
||||
print(im_file, "could not be read, skipping.")
|
||||
actor.update.remote(1)
|
||||
continue
|
||||
dataset_dict = get_image_blob(im, cfg.MODEL.PIXEL_MEAN)
|
||||
# extract roi features
|
||||
if cfg.MODEL.BUA.EXTRACTOR.MODE == 1:
|
||||
attr_scores = None
|
||||
print("444444444444444444444444444444444")
|
||||
with torch.set_grad_enabled(False):
|
||||
if cfg.MODEL.BUA.ATTRIBUTE_ON:
|
||||
print("55555555555555555555555555555")
|
||||
boxes, scores, features_pooled, attr_scores = model([dataset_dict])
|
||||
print("6666666666666666666666666")
|
||||
else:
|
||||
boxes, scores, features_pooled = model([dataset_dict])
|
||||
boxes = [box.tensor.cpu() for box in boxes]
|
||||
scores = [score.cpu() for score in scores]
|
||||
#torch.set_printoptions(subprocess=True)
|
||||
#print(features_pooled)
|
||||
features_pooled = [feat.cpu() for feat in features_pooled]
|
||||
|
||||
if attr_scores is not None:
|
||||
attr_scores = [attr_score.cpu() for attr_score in attr_scores]
|
||||
print("77777777777777777777777777")
|
||||
im_feat = image_features(cfg, im_file, im, dataset_dict,
|
||||
boxes, scores, features_pooled, attr_scores)
|
||||
print("888888888888888888888888")
|
||||
actor.update.remote(1)
|
||||
return im_feat
|
||||
|
||||
|
||||
def feature(image_path):
|
||||
parser = argparse.ArgumentParser(description="PyTorch Object Detection2 Inference")
|
||||
parser.add_argument(
|
||||
"--config-file",
|
||||
default="configs/bua-caffe/extract-bua-caffe-r152.yaml",
|
||||
metavar="FILE",
|
||||
help="path to config file",
|
||||
)
|
||||
|
||||
parser.add_argument('--num-cpus', default=1, type=int,
|
||||
help='number of cpus to use for ray, 0 means no limit')
|
||||
|
||||
parser.add_argument('--gpus', dest='gpu_id', help='GPU id(s) to use',
|
||||
default='0', type=str)
|
||||
|
||||
parser.add_argument("--mode", default="caffe", type=str, help="bua_caffe, ...")
|
||||
|
||||
parser.add_argument('--extract-mode', default='roi_feats', type=str,
|
||||
help="'roi_feats', 'bboxes' and 'bbox_feats' indicates \
|
||||
'extract roi features directly', 'extract bboxes only' and \
|
||||
'extract roi features with pre-computed bboxes' respectively")
|
||||
|
||||
parser.add_argument('--min-max-boxes', default='min_max_default', type=str,
|
||||
help='the number of min-max boxes of extractor')
|
||||
|
||||
parser.add_argument('--out-dir', dest='output_dir',
|
||||
help='output directory for features',
|
||||
default="features")
|
||||
parser.add_argument('--image-dir', dest='image_dir',
|
||||
help='directory with images',
|
||||
default="image")
|
||||
parser.add_argument('--bbox-dir', dest='bbox_dir',
|
||||
help='directory with bbox',
|
||||
default="bbox")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
action="store_true",
|
||||
help="whether to attempt to resume from the checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"opts",
|
||||
help="Modify config options using the command-line",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = setup(args)
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
|
||||
num_gpus = len(args.gpu_id.split(','))
|
||||
|
||||
MIN_BOXES = cfg.MODEL.BUA.EXTRACTOR.MIN_BOXES
|
||||
MAX_BOXES = cfg.MODEL.BUA.EXTRACTOR.MAX_BOXES
|
||||
CONF_THRESH = cfg.MODEL.BUA.EXTRACTOR.CONF_THRESH
|
||||
|
||||
|
||||
|
||||
print("---------------------------------------------------"+str(args.num_cpus))
|
||||
if args.num_cpus != 0:
|
||||
#print(ray.init(num_cpus=args.num_cpus))
|
||||
ray.init(num_cpus=args.num_cpus)
|
||||
else:
|
||||
#print(ray.init())
|
||||
ray.init()
|
||||
|
||||
# Extract features.
|
||||
imglist = []
|
||||
imglist.append(image_path)
|
||||
num_images = len(imglist)
|
||||
print('Number of images: {}.'.format(num_images))
|
||||
img_lists = [imglist[i::num_gpus] for i in range(num_gpus)]
|
||||
|
||||
pb = ProgressBar(len(imglist))
|
||||
actor = pb.actor
|
||||
|
||||
print('Number of GPUs: {}.'.format(num_gpus))
|
||||
extract_feat_list = []
|
||||
#model = DefaultTrainer.build_model(cfg)
|
||||
for i in range(num_gpus):
|
||||
extract_feat_list.append(extract_feat.remote(i, img_lists[i], cfg, args, actor))
|
||||
|
||||
pb.print_until_done()
|
||||
img_fe = ray.get(extract_feat_list)
|
||||
#img_fe = np.array(img_fe)
|
||||
#print(img_fe.shape)
|
||||
ray.get(actor.get_counter.remote())
|
||||
return img_fe[0]
|
||||
|
||||
def loadModel():
|
||||
parser = argparse.ArgumentParser(description="PyTorch Object Detection2 Inference")
|
||||
parser.add_argument(
|
||||
"--config-file",
|
||||
default="configs/bua-caffe/extract-bua-caffe-r152.yaml",
|
||||
metavar="FILE",
|
||||
help="path to config file",
|
||||
)
|
||||
|
||||
parser.add_argument('--num-cpus', default=1, type=int,
|
||||
help='number of cpus to use for ray, 0 means no limit')
|
||||
|
||||
parser.add_argument('--gpus', dest='gpu_id', help='GPU id(s) to use',
|
||||
default='0', type=str)
|
||||
|
||||
parser.add_argument("--mode", default="caffe", type=str, help="bua_caffe, ...")
|
||||
|
||||
parser.add_argument('--extract-mode', default='roi_feats', type=str,
|
||||
help="'roi_feats', 'bboxes' and 'bbox_feats' indicates \
|
||||
'extract roi features directly', 'extract bboxes only' and \
|
||||
'extract roi features with pre-computed bboxes' respectively")
|
||||
|
||||
parser.add_argument('--min-max-boxes', default='min_max_default', type=str,
|
||||
help='the number of min-max boxes of extractor')
|
||||
|
||||
parser.add_argument('--out-dir', dest='output_dir',
|
||||
help='output directory for features',
|
||||
default="features")
|
||||
parser.add_argument('--image-dir', dest='image_dir',
|
||||
help='directory with images',
|
||||
default="image")
|
||||
parser.add_argument('--bbox-dir', dest='bbox_dir',
|
||||
help='directory with bbox',
|
||||
default="bbox")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
action="store_true",
|
||||
help="whether to attempt to resume from the checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"opts",
|
||||
help="Modify config options using the command-line",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = setup(args)
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
|
||||
num_gpus = len(args.gpu_id.split(','))
|
||||
|
||||
MIN_BOXES = cfg.MODEL.BUA.EXTRACTOR.MIN_BOXES
|
||||
MAX_BOXES = cfg.MODEL.BUA.EXTRACTOR.MAX_BOXES
|
||||
CONF_THRESH = cfg.MODEL.BUA.EXTRACTOR.CONF_THRESH
|
||||
|
||||
print("---------------------------------------------------" + str(args.num_cpus))
|
||||
if args.num_cpus != 0:
|
||||
# print(ray.init(num_cpus=args.num_cpus))
|
||||
ray.init(num_cpus=args.num_cpus)
|
||||
else:
|
||||
# print(ray.init())
|
||||
ray.init()
|
||||
|
||||
model = DefaultTrainer.build_model(cfg)
|
||||
print("111111111111111111111111111111111111111111111")
|
||||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
print("222222222222222222222222222222222222222222222222")
|
||||
model.eval()
|
||||
return model, cfg
|
||||
|
||||
def imgFeature(model, cfg, im_file):
|
||||
print("33333333333333333333333333333333333333333333")
|
||||
# if os.path.exists(os.path.join(args.output_dir, im_file.split('.')[0]+'.npz')):
|
||||
# actor.update.remote(1)
|
||||
# continue
|
||||
im = cv2.imread(im_file)
|
||||
if im is None:
|
||||
print(im_file, "could not be read.")
|
||||
return 0
|
||||
dataset_dict = get_image_blob(im, cfg.MODEL.PIXEL_MEAN)
|
||||
# extract roi features
|
||||
if cfg.MODEL.BUA.EXTRACTOR.MODE == 1:
|
||||
attr_scores = None
|
||||
print("444444444444444444444444444444444")
|
||||
with torch.set_grad_enabled(False):
|
||||
if cfg.MODEL.BUA.ATTRIBUTE_ON:
|
||||
print("55555555555555555555555555555")
|
||||
boxes, scores, features_pooled, attr_scores = model([dataset_dict])
|
||||
print("6666666666666666666666666")
|
||||
else:
|
||||
boxes, scores, features_pooled = model([dataset_dict])
|
||||
boxes = [box.tensor.cpu() for box in boxes]
|
||||
scores = [score.cpu() for score in scores]
|
||||
# torch.set_printoptions(subprocess=True)
|
||||
# print(features_pooled)
|
||||
features_pooled = [feat.cpu() for feat in features_pooled]
|
||||
|
||||
if attr_scores is not None:
|
||||
attr_scores = [attr_score.cpu() for attr_score in attr_scores]
|
||||
print("77777777777777777777777777")
|
||||
im_feat = image_features(cfg, im_file, im, dataset_dict,
|
||||
boxes, scores, features_pooled, attr_scores)
|
||||
print("888888888888888888888888")
|
||||
return im_feat
|
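# Sketch of the intended two-step API (assumes the default config file and weights are available):
#   model, cfg = loadModel()
#   im_feat = imgFeature(model, cfg, "./image/134206.jpg")
# loadModel() parses the same command-line options as feature(), so the config file,
# GPU ids and extract mode can be overridden on the command line.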
||||
|
||||
if __name__ == "__main__":
|
||||
feature("./image/134206.jpg")
|
|
@ -0,0 +1,88 @@
|
# coding=utf-8
import torch
import torch.nn as nn
import math
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def positional_encoding_1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                          -(math.log(10000.0) / d_model)))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)

    return pe


class GPO(nn.Module):
    def __init__(self, d_pe, d_hidden):
        super(GPO, self).__init__()
        self.d_pe = d_pe
        self.d_hidden = d_hidden

        self.pe_database = {}
        self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(self.d_hidden, 1, bias=False)

        # for p in self.parameters():
        #     p.requires_grad = False

    def compute_pool_weights(self, lengths, features):
        max_len = int(lengths.max())
        pe_max_len = self.get_pe(max_len)
        pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
        mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
        mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
        pes = pes.masked_fill(mask == 0, 0)

        self.gru.flatten_parameters()
        packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.gru(packed)
        padded = pad_packed_sequence(out, batch_first=True)
        out_emb, out_len = padded
        out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2
        scores = self.linear(out_emb)
        scores[torch.where(mask == 0)] = -10000

        weights = torch.softmax(scores / 0.1, 1)
        return weights, mask

    def forward(self, features, lengths):
        """
        :param features: features with shape B x K x D
        :param lengths: B x 1, specify the length of each data sample.
        :return: pooled feature with shape B x D
        """
        pool_weights, mask = self.compute_pool_weights(lengths, features)

        features = features[:, :int(lengths.max()), :]
        sorted_features = features.masked_fill(mask == 0, -10000)
        sorted_features = sorted_features.sort(dim=1, descending=True)[0]
        sorted_features = sorted_features.masked_fill(mask == 0, 0)

        pooled_features = (sorted_features * pool_weights).sum(1)
        return pooled_features, pool_weights

    def get_pe(self, length):
        """
        :param length: the length of the sequence
        :return: the positional encoding of the given length
        """
        length = int(length)
        if length in self.pe_database:
            return self.pe_database[length]
        else:
            pe = positional_encoding_1d(self.d_pe, length)
            self.pe_database[length] = pe
            return pe
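# Minimal usage sketch of GPO (not part of the original module); shapes are illustrative.
if __name__ == "__main__":
    gpo = GPO(32, 32)
    feats = torch.randn(2, 5, 256)      # B x K x D batch of zero-padded features
    lens = torch.tensor([5., 3.])       # valid length of each sample
    pooled, weights = gpo(feats, lens)
    print(pooled.shape, weights.shape)  # torch.Size([2, 256]) torch.Size([2, 5, 1])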
@ -0,0 +1,371 @@
|
|||
"""COCO dataset loader"""
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from imageio import imread
|
||||
import random
|
||||
import json
|
||||
import cv2
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RawImageDataset(data.Dataset):
|
||||
"""
|
||||
Load precomputed captions and image features
|
||||
Possible options: f30k_precomp, coco_precomp
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, data_name, data_split, tokenzier, opt, train):
|
||||
self.opt = opt
|
||||
self.train = train
|
||||
self.data_path = data_path
|
||||
self.data_name = data_name
|
||||
self.tokenizer = tokenzier
|
||||
|
||||
loc_cap = osp.join(data_path, 'precomp')
|
||||
loc_image = osp.join(data_path, 'precomp')
|
||||
loc_mapping = osp.join(data_path, 'id_mapping.json')
|
||||
if 'coco' in data_name:
|
||||
self.image_base = osp.join(data_path, 'images')
|
||||
else:
|
||||
self.image_base = osp.join(data_path, 'flickr30k-images')
|
||||
|
||||
with open(loc_mapping, 'r') as f_mapping:
|
||||
self.id_to_path = json.load(f_mapping)
|
||||
|
||||
# Read Captions
|
||||
self.captions = []
|
||||
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
|
||||
for line in f:
|
||||
self.captions.append(line.strip())
|
||||
|
||||
# Get the image ids
|
||||
with open(osp.join(loc_image, '{}_ids.txt'.format(data_split)), 'r') as f:
|
||||
image_ids = f.readlines()
|
||||
self.images = [int(x.strip()) for x in image_ids]
|
||||
|
||||
# Set related parameters according to the pre-trained backbone **
|
||||
assert 'backbone' in opt.precomp_enc_type
|
||||
|
||||
self.backbone_source = opt.backbone_source
|
||||
self.base_target_size = 256
|
||||
self.crop_ratio = 0.875
|
||||
self.train_scale_rate = 1
|
||||
if hasattr(opt, 'input_scale_factor') and opt.input_scale_factor != 1:
|
||||
self.base_target_size = int(self.base_target_size * opt.input_scale_factor)
|
||||
logger.info('Input images are scaled by factor {}'.format(opt.input_scale_factor))
|
||||
if 'detector' in self.backbone_source:
|
||||
self.pixel_means = np.array([[[102.9801, 115.9465, 122.7717]]])
|
||||
else:
|
||||
self.imagenet_mean = [0.485, 0.456, 0.406]
|
||||
self.imagenet_std = [0.229, 0.224, 0.225]
|
||||
|
||||
self.length = len(self.captions)
|
||||
|
||||
# rkiros data has redundancy in images, we divide by 5, 10crop doesn't
|
||||
num_images = len(self.images)
|
||||
|
||||
if num_images != self.length:
|
||||
self.im_div = 5
|
||||
else:
|
||||
self.im_div = 1
|
||||
# the development set for coco is large and so validation would be slow
|
||||
if data_split == 'dev':
|
||||
self.length = 5000
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_index = index // self.im_div
|
||||
caption = self.captions[index]
|
||||
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
|
||||
|
||||
# Convert caption (string) to word ids (with Size Augmentation at training time).
|
||||
target = process_caption(self.tokenizer, caption_tokens, self.train)
|
||||
|
||||
image_id = self.images[img_index]
|
||||
image_path = os.path.join(self.image_base, self.id_to_path[str(image_id)])
|
||||
im_in = np.array(imread(image_path))
|
||||
processed_image = self._process_image(im_in)
|
||||
image = torch.Tensor(processed_image)
|
||||
image = image.permute(2, 0, 1)
|
||||
return image, target, index, img_index
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def _process_image(self, im_in):
|
||||
"""
|
||||
Converts an image into a network input, with pre-processing including re-scaling, padding, etc, and data
|
||||
augmentation.
|
||||
"""
|
||||
if len(im_in.shape) == 2:
|
||||
im_in = im_in[:, :, np.newaxis]
|
||||
im_in = np.concatenate((im_in, im_in, im_in), axis=2)
|
||||
|
||||
if 'detector' in self.backbone_source:
|
||||
im_in = im_in[:, :, ::-1]
|
||||
im = im_in.astype(np.float32, copy=True)
|
||||
|
||||
if self.train:
|
||||
target_size = self.base_target_size * self.train_scale_rate
|
||||
else:
|
||||
target_size = self.base_target_size
|
||||
|
||||
# 2. Random crop when in training mode; otherwise skip cropping
|
||||
if self.train:
|
||||
crop_ratio = np.random.random() * 0.4 + 0.6
|
||||
crop_size_h = int(im.shape[0] * crop_ratio)
|
||||
crop_size_w = int(im.shape[1] * crop_ratio)
|
||||
processed_im = self._crop(im, crop_size_h, crop_size_w, random=True)
|
||||
else:
|
||||
processed_im = im
|
||||
|
||||
# 3. Resize to the target resolution
|
||||
im_shape = processed_im.shape
|
||||
im_scale_x = float(target_size) / im_shape[1]
|
||||
im_scale_y = float(target_size) / im_shape[0]
|
||||
processed_im = cv2.resize(processed_im, None, None, fx=im_scale_x, fy=im_scale_y,
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if self.train:
|
||||
if np.random.random() > 0.5:
|
||||
processed_im = self._hori_flip(processed_im)
|
||||
|
||||
# Normalization
|
||||
if 'detector' in self.backbone_source:
|
||||
processed_im = self._detector_norm(processed_im)
|
||||
else:
|
||||
processed_im = self._imagenet_norm(processed_im)
|
||||
|
||||
return processed_im
|
||||
|
||||
def _imagenet_norm(self, im_in):
|
||||
im_in = im_in.astype(np.float32)
|
||||
im_in = im_in / 255
|
||||
for i in range(im_in.shape[-1]):
|
||||
im_in[:, :, i] = (im_in[:, :, i] - self.imagenet_mean[i]) / self.imagenet_std[i]
|
||||
return im_in
|
||||
|
||||
def _detector_norm(self, im_in):
|
||||
im_in = im_in.astype(np.float32)
|
||||
im_in -= self.pixel_means
|
||||
return im_in
|
||||
|
||||
@staticmethod
|
||||
def _crop(im, crop_size_h, crop_size_w, random):
|
||||
h, w = im.shape[0], im.shape[1]
|
||||
if random:
|
||||
if w - crop_size_w == 0:
|
||||
x_start = 0
|
||||
else:
|
||||
x_start = np.random.randint(w - crop_size_w, size=1)[0]
|
||||
if h - crop_size_h == 0:
|
||||
y_start = 0
|
||||
else:
|
||||
y_start = np.random.randint(h - crop_size_h, size=1)[0]
|
||||
else:
|
||||
x_start = (w - crop_size_w) // 2
|
||||
y_start = (h - crop_size_h) // 2
|
||||
|
||||
cropped_im = im[y_start:y_start + crop_size_h, x_start:x_start + crop_size_w, :]
|
||||
|
||||
return cropped_im
|
||||
|
||||
@staticmethod
|
||||
def _hori_flip(im):
|
||||
im = np.fliplr(im).copy()
|
||||
return im
|
||||
|
||||
|
||||
class PrecompRegionDataset(data.Dataset):
|
||||
"""
|
||||
Load precomputed captions and image features for COCO or Flickr
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
|
||||
self.tokenizer = tokenizer
|
||||
self.opt = opt
|
||||
self.train = train
|
||||
self.data_path = data_path
|
||||
self.data_name = data_name
|
||||
|
||||
# loc_cap = osp.join(data_path, 'precomp')
|
||||
# loc_image = osp.join(data_path, 'precomp')
|
||||
loc_cap = data_path
|
||||
loc_image = data_path
|
||||
|
||||
# Captions
|
||||
self.captions = []
|
||||
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
|
||||
for line in f:
|
||||
self.captions.append(line.strip())
|
||||
# Image features
|
||||
self.images = np.load(os.path.join(loc_image, '%s_ims.npy' % data_split), mmap_mode = 'r')
|
||||
|
||||
self.length = len(self.captions)
|
||||
# rkiros data has redundancy in images, we divide by 5, 10crop doesn't
|
||||
num_images = len(self.images)
|
||||
|
||||
if num_images != self.length:
|
||||
self.im_div = 5
|
||||
else:
|
||||
self.im_div = 1
|
||||
# the development set for coco is large and so validation would be slow
|
||||
if data_split == 'dev':
|
||||
self.length = 5000
|
||||
# if data_split == 'test':
|
||||
# self.length = 5000
|
||||
|
||||
def __getitem__(self, index):
|
||||
# handle the image redundancy
|
||||
img_index = index // self.im_div
|
||||
caption = self.captions[index]
|
||||
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
|
||||
|
||||
# Convert caption (string) to word ids (with Size Augmentation at training time)
|
||||
target = process_caption(self.tokenizer, caption_tokens, self.train)
|
||||
image = self.images[img_index]
|
||||
if self.train: # Size augmentation for region feature
|
||||
num_features = image.shape[0]
|
||||
rand_list = np.random.rand(num_features)
|
||||
image = image[np.where(rand_list > 0.20)]
|
||||
image = torch.Tensor(image)
|
||||
return image, target, index, img_index
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def process_caption(tokenizer, tokens, train=True):
|
||||
output_tokens = []
|
||||
deleted_idx = []
|
||||
|
||||
for i, token in enumerate(tokens):
|
||||
sub_tokens = tokenizer.wordpiece_tokenizer.tokenize(token)
|
||||
prob = random.random()
|
||||
|
||||
if prob < 0.20 and train: # mask/remove the tokens only during training
|
||||
prob /= 0.20
|
||||
|
||||
# 50% randomly change token to mask token
|
||||
if prob < 0.5:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append("[MASK]")
|
||||
# 10% randomly change token to random token
|
||||
elif prob < 0.6:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append(random.choice(list(tokenizer.vocab.keys())))
|
||||
# -> rest 40%: append the token but record its index so it is dropped below
|
||||
else:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append(sub_token)
|
||||
deleted_idx.append(len(output_tokens) - 1)
|
||||
else:
|
||||
for sub_token in sub_tokens:
|
||||
# no masking token (will be ignored by loss function later)
|
||||
output_tokens.append(sub_token)
|
||||
|
||||
if len(deleted_idx) != 0:
|
||||
output_tokens = [output_tokens[i] for i in range(len(output_tokens)) if i not in deleted_idx]
|
||||
|
||||
output_tokens = ['[CLS]'] + output_tokens + ['[SEP]']
|
||||
target = tokenizer.convert_tokens_to_ids(output_tokens)
|
||||
target = torch.Tensor(target)
|
||||
return target
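# Note on the augmentation above: during training roughly 20% of the caption tokens are
# perturbed; of those, 50% become [MASK], 10% become a random vocabulary token, and the
# remaining 40% are recorded in deleted_idx and removed from the caption.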
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
"""Build mini-batch tensors from a list of (image, caption) tuples.
|
||||
Args:
|
||||
data: list of (image, caption) tuple.
|
||||
- image: torch tensor of shape (3, 256, 256).
|
||||
- caption: torch tensor of shape (?); variable length.
|
||||
|
||||
Returns:
|
||||
images: torch tensor of shape (batch_size, 3, 256, 256).
|
||||
targets: torch tensor of shape (batch_size, padded_length).
|
||||
lengths: list; valid length for each padded caption.
|
||||
"""
|
||||
images, captions, ids, img_ids = zip(*data)
|
||||
if len(images[0].shape) == 2: # region feature
|
||||
# Sort a data list by caption length
|
||||
# Merge images (convert tuple of 3D tensor to 4D tensor)
|
||||
# images = torch.stack(images, 0)
|
||||
img_lengths = [len(image) for image in images]
|
||||
all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1))
|
||||
for i, image in enumerate(images):
|
||||
end = img_lengths[i]
|
||||
all_images[i, :end] = image[:end]
|
||||
img_lengths = torch.Tensor(img_lengths)
|
||||
|
||||
# Merge captions (convert tuple of 1D tensor to 2D tensor)
|
||||
lengths = [len(cap) for cap in captions]
|
||||
targets = torch.zeros(len(captions), max(lengths)).long()
|
||||
|
||||
for i, cap in enumerate(captions):
|
||||
end = lengths[i]
|
||||
targets[i, :end] = cap[:end]
|
||||
|
||||
return all_images, img_lengths, targets, lengths, ids
|
||||
else: # raw input image
|
||||
# Merge images (convert tuple of 3D tensor to 4D tensor)
|
||||
images = torch.stack(images, 0)
|
||||
|
||||
# Merge captions (convert tuple of 1D tensor to 2D tensor)
|
||||
lengths = [len(cap) for cap in captions]
|
||||
targets = torch.zeros(len(captions), max(lengths)).long()
|
||||
for i, cap in enumerate(captions):
|
||||
end = lengths[i]
|
||||
targets[i, :end] = cap[:end]
|
||||
return images, targets, lengths, ids
|
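# Example of what the region-feature branch returns (the 2048-d region size is illustrative):
#   all_images:  FloatTensor (B, max_n_regions, 2048), zero-padded per sample
#   img_lengths: FloatTensor (B,) with the true number of regions per image
#   targets:     LongTensor  (B, max_caption_len) of padded token ids
#   lengths:     list of true caption lengths; ids: dataset indices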
||||
|
||||
|
||||
def get_loader(data_path, data_name, data_split, tokenizer, opt, batch_size=100,
|
||||
shuffle=True, num_workers=0, train=True):
|
||||
"""Returns torch.utils.data.DataLoader for custom coco dataset."""
|
||||
if train:
|
||||
drop_last = True
|
||||
else:
|
||||
drop_last = False
|
||||
if opt.precomp_enc_type == 'basic':
|
||||
dset = PrecompRegionDataset(data_path, data_name, data_split, tokenizer, opt, train)
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=num_workers,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
dset = RawImageDataset(data_path, data_name, data_split, tokenizer, opt, train)
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn)
|
||||
return data_loader
|
||||
|
||||
|
||||
def get_loaders(data_path, data_name, tokenizer, batch_size, workers, opt):
|
||||
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
|
||||
batch_size, True, workers)
|
||||
val_loader = get_loader(data_path, data_name, 'dev', tokenizer, opt,
|
||||
batch_size, False, workers, train=False)
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def get_train_loader(data_path, data_name, tokenizer, batch_size, workers, opt, shuffle):
|
||||
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
|
||||
batch_size, shuffle, workers)
|
||||
return train_loader
|
||||
|
||||
|
||||
def get_test_loader(split_name, data_name, tokenizer, batch_size, workers, opt):
|
||||
test_loader = get_loader(opt.data_path, data_name, split_name, tokenizer, opt,
|
||||
batch_size, False, workers, train=False)
|
||||
return test_loader
|
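# Illustrative usage (assumes `opt` carries data_path, precomp_enc_type, etc. as configured in train.py):
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   train_loader, val_loader = get_loaders(opt.data_path, 'coco_precomp', tokenizer, 128, 4, opt)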
|
@ -0,0 +1,16 @@
import os.path as osp
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


root_dir = osp.abspath(osp.dirname(osp.join(__file__, '..')))

# Add lib to PYTHONPATH
lib_path = osp.join(root_dir, 'lib')
datasets_path = osp.join(root_dir, 'datasets')
add_path(lib_path)
add_path(datasets_path)
|
@ -0,0 +1,369 @@
|
|||
"""COCO dataset loader"""
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from imageio import imread
|
||||
import random
|
||||
import json
|
||||
import cv2
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RawImageDataset(data.Dataset):
|
||||
"""
|
||||
Load precomputed captions and image features
|
||||
Possible options: f30k_precomp, coco_precomp
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, data_name, data_split, tokenzier, opt, train):
|
||||
self.opt = opt
|
||||
self.train = train
|
||||
self.data_path = data_path
|
||||
self.data_name = data_name
|
||||
self.tokenizer = tokenzier
|
||||
|
||||
loc_cap = osp.join(data_path, 'precomp')
|
||||
loc_image = osp.join(data_path, 'precomp')
|
||||
loc_mapping = osp.join(data_path, 'id_mapping.json')
|
||||
if 'coco' in data_name:
|
||||
self.image_base = osp.join(data_path, 'images')
|
||||
else:
|
||||
self.image_base = osp.join(data_path, 'flickr30k-images')
|
||||
|
||||
with open(loc_mapping, 'r') as f_mapping:
|
||||
self.id_to_path = json.load(f_mapping)
|
||||
|
||||
# Read Captions
|
||||
self.captions = []
|
||||
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
|
||||
for line in f:
|
||||
self.captions.append(line.strip())
|
||||
|
||||
# Get the image ids
|
||||
with open(osp.join(loc_image, '{}_ids.txt'.format(data_split)), 'r') as f:
|
||||
image_ids = f.readlines()
|
||||
self.images = [int(x.strip()) for x in image_ids]
|
||||
|
||||
# Set related parameters according to the pre-trained backbone **
|
||||
assert 'backbone' in opt.precomp_enc_type
|
||||
|
||||
self.backbone_source = opt.backbone_source
|
||||
self.base_target_size = 256
|
||||
self.crop_ratio = 0.875
|
||||
self.train_scale_rate = 1
|
||||
if hasattr(opt, 'input_scale_factor') and opt.input_scale_factor != 1:
|
||||
self.base_target_size = int(self.base_target_size * opt.input_scale_factor)
|
||||
logger.info('Input images are scaled by factor {}'.format(opt.input_scale_factor))
|
||||
if 'detector' in self.backbone_source:
|
||||
self.pixel_means = np.array([[[102.9801, 115.9465, 122.7717]]])
|
||||
else:
|
||||
self.imagenet_mean = [0.485, 0.456, 0.406]
|
||||
self.imagenet_std = [0.229, 0.224, 0.225]
|
||||
|
||||
self.length = len(self.captions)
|
||||
|
||||
# rkiros data has redundancy in images, we divide by 5, 10crop doesn't
|
||||
num_images = len(self.images)
|
||||
|
||||
if num_images != self.length:
|
||||
self.im_div = 5
|
||||
else:
|
||||
self.im_div = 1
|
||||
# the development set for coco is large and so validation would be slow
|
||||
if data_split == 'dev':
|
||||
self.length = 5000
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_index = index // self.im_div
|
||||
caption = self.captions[index]
|
||||
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
|
||||
|
||||
# Convert caption (string) to word ids (with Size Augmentation at training time).
|
||||
target = process_caption(self.tokenizer, caption_tokens, self.train)
|
||||
|
||||
image_id = self.images[img_index]
|
||||
image_path = os.path.join(self.image_base, self.id_to_path[str(image_id)])
|
||||
im_in = np.array(imread(image_path))
|
||||
processed_image = self._process_image(im_in)
|
||||
image = torch.Tensor(processed_image)
|
||||
image = image.permute(2, 0, 1)
|
||||
return image, target, index, img_index
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def _process_image(self, im_in):
|
||||
"""
|
||||
Converts an image into a network input, with pre-processing including re-scaling, padding, etc, and data
|
||||
augmentation.
|
||||
"""
|
||||
if len(im_in.shape) == 2:
|
||||
im_in = im_in[:, :, np.newaxis]
|
||||
im_in = np.concatenate((im_in, im_in, im_in), axis=2)
|
||||
|
||||
if 'detector' in self.backbone_source:
|
||||
im_in = im_in[:, :, ::-1]
|
||||
im = im_in.astype(np.float32, copy=True)
|
||||
|
||||
if self.train:
|
||||
target_size = self.base_target_size * self.train_scale_rate
|
||||
else:
|
||||
target_size = self.base_target_size
|
||||
|
||||
# 2. Random crop when in training mode; otherwise skip cropping
|
||||
if self.train:
|
||||
crop_ratio = np.random.random() * 0.4 + 0.6
|
||||
crop_size_h = int(im.shape[0] * crop_ratio)
|
||||
crop_size_w = int(im.shape[1] * crop_ratio)
|
||||
processed_im = self._crop(im, crop_size_h, crop_size_w, random=True)
|
||||
else:
|
||||
processed_im = im
|
||||
|
||||
# 3. Resize to the target resolution
|
||||
im_shape = processed_im.shape
|
||||
im_scale_x = float(target_size) / im_shape[1]
|
||||
im_scale_y = float(target_size) / im_shape[0]
|
||||
processed_im = cv2.resize(processed_im, None, None, fx=im_scale_x, fy=im_scale_y,
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if self.train:
|
||||
if np.random.random() > 0.5:
|
||||
processed_im = self._hori_flip(processed_im)
|
||||
|
||||
# Normalization
|
||||
if 'detector' in self.backbone_source:
|
||||
processed_im = self._detector_norm(processed_im)
|
||||
else:
|
||||
processed_im = self._imagenet_norm(processed_im)
|
||||
|
||||
return processed_im
|
||||
|
||||
def _imagenet_norm(self, im_in):
|
||||
im_in = im_in.astype(np.float32)
|
||||
im_in = im_in / 255
|
||||
for i in range(im_in.shape[-1]):
|
||||
im_in[:, :, i] = (im_in[:, :, i] - self.imagenet_mean[i]) / self.imagenet_std[i]
|
||||
return im_in
|
||||
|
||||
def _detector_norm(self, im_in):
|
||||
im_in = im_in.astype(np.float32)
|
||||
im_in -= self.pixel_means
|
||||
return im_in
|
||||
|
||||
@staticmethod
|
||||
def _crop(im, crop_size_h, crop_size_w, random):
|
||||
h, w = im.shape[0], im.shape[1]
|
||||
if random:
|
||||
if w - crop_size_w == 0:
|
||||
x_start = 0
|
||||
else:
|
||||
x_start = np.random.randint(w - crop_size_w, size=1)[0]
|
||||
if h - crop_size_h == 0:
|
||||
y_start = 0
|
||||
else:
|
||||
y_start = np.random.randint(h - crop_size_h, size=1)[0]
|
||||
else:
|
||||
x_start = (w - crop_size_w) // 2
|
||||
y_start = (h - crop_size_h) // 2
|
||||
|
||||
cropped_im = im[y_start:y_start + crop_size_h, x_start:x_start + crop_size_w, :]
|
||||
|
||||
return cropped_im
|
||||
|
||||
@staticmethod
|
||||
def _hori_flip(im):
|
||||
im = np.fliplr(im).copy()
|
||||
return im
|
||||
|
||||
|
||||
class PrecompRegionDataset(data.Dataset):
|
||||
"""
|
||||
Load precomputed captions and image features for COCO or Flickr
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
|
||||
self.tokenizer = tokenizer
|
||||
self.opt = opt
|
||||
self.train = train
|
||||
self.data_path = data_path
|
||||
self.data_name = data_name
|
||||
|
||||
# loc_cap = osp.join(data_path, 'precomp')
# loc_image = osp.join(data_path, 'precomp')
loc_cap = data_path
loc_image = data_path
|
||||
|
||||
# Captions
|
||||
self.captions = []
|
||||
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
|
||||
for line in f:
|
||||
self.captions.append(line.strip())
|
||||
# Image features
|
||||
self.images = np.load(os.path.join(loc_image, '%s_ims.npy' % data_split), mmap_mode = 'r')
|
||||
|
||||
self.length = len(self.captions)
|
||||
# rkiros data has redundancy in images, we divide by 5, 10crop doesn't
|
||||
num_images = len(self.images)
|
||||
|
||||
if num_images != self.length:
|
||||
self.im_div = 5
|
||||
else:
|
||||
self.im_div = 1
|
||||
# the development set for coco is large and so validation would be slow
|
||||
if data_split == 'dev':
|
||||
self.length = 5000
|
||||
|
||||
def __getitem__(self, index):
|
||||
# handle the image redundancy
|
||||
img_index = index // self.im_div
|
||||
caption = self.captions[index]
|
||||
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
|
||||
|
||||
# Convert caption (string) to word ids (with Size Augmentation at training time)
|
||||
target = process_caption(self.tokenizer, caption_tokens, self.train)
|
||||
image = self.images[img_index]
|
||||
if self.train: # Size augmentation for region feature
|
||||
num_features = image.shape[0]
|
||||
rand_list = np.random.rand(num_features)
|
||||
image = image[np.where(rand_list > 0.20)]
|
||||
image = torch.Tensor(image)
|
||||
return image, target, index, img_index
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def process_caption(tokenizer, tokens, train=True):
|
||||
output_tokens = []
|
||||
deleted_idx = []
|
||||
|
||||
for i, token in enumerate(tokens):
|
||||
sub_tokens = tokenizer.wordpiece_tokenizer.tokenize(token)
|
||||
prob = random.random()
|
||||
|
||||
if prob < 0.20 and train: # mask/remove the tokens only during training
|
||||
prob /= 0.20
|
||||
|
||||
# 50% randomly change token to mask token
|
||||
if prob < 0.5:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append("[MASK]")
|
||||
# 10% randomly change token to random token
|
||||
elif prob < 0.6:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append(random.choice(list(tokenizer.vocab.keys())))
|
||||
# -> rest 40%: append the token but record its index so it is dropped below
|
||||
else:
|
||||
for sub_token in sub_tokens:
|
||||
output_tokens.append(sub_token)
|
||||
deleted_idx.append(len(output_tokens) - 1)
|
||||
else:
|
||||
for sub_token in sub_tokens:
|
||||
# no masking token (will be ignored by loss function later)
|
||||
output_tokens.append(sub_token)
|
||||
|
||||
if len(deleted_idx) != 0:
|
||||
output_tokens = [output_tokens[i] for i in range(len(output_tokens)) if i not in deleted_idx]
|
||||
|
||||
output_tokens = ['[CLS]'] + output_tokens + ['[SEP]']
|
||||
target = tokenizer.convert_tokens_to_ids(output_tokens)
|
||||
target = torch.Tensor(target)
|
||||
return target
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
"""Build mini-batch tensors from a list of (image, caption) tuples.
|
||||
Args:
|
||||
data: list of (image, caption) tuple.
|
||||
- image: torch tensor of shape (3, 256, 256).
|
||||
- caption: torch tensor of shape (?); variable length.
|
||||
|
||||
Returns:
|
||||
images: torch tensor of shape (batch_size, 3, 256, 256).
|
||||
targets: torch tensor of shape (batch_size, padded_length).
|
||||
lengths: list; valid length for each padded caption.
|
||||
"""
|
||||
images, captions, ids, img_ids = zip(*data)
|
||||
if len(images[0].shape) == 2: # region feature
|
||||
# Sort a data list by caption length
|
||||
# Merge images (convert tuple of 3D tensor to 4D tensor)
|
||||
# images = torch.stack(images, 0)
|
||||
img_lengths = [len(image) for image in images]
|
||||
all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1))
|
||||
for i, image in enumerate(images):
|
||||
end = img_lengths[i]
|
||||
all_images[i, :end] = image[:end]
|
||||
img_lengths = torch.Tensor(img_lengths)
|
||||
|
||||
# Merge captions (convert tuple of 1D tensor to 2D tensor)
|
||||
lengths = [len(cap) for cap in captions]
|
||||
targets = torch.zeros(len(captions), max(lengths)).long()
|
||||
|
||||
for i, cap in enumerate(captions):
|
||||
end = lengths[i]
|
||||
targets[i, :end] = cap[:end]
|
||||
|
||||
return all_images, img_lengths, targets, lengths, ids
|
||||
else: # raw input image
|
||||
# Merge images (convert tuple of 3D tensor to 4D tensor)
|
||||
images = torch.stack(images, 0)
|
||||
|
||||
# Merge captions (convert tuple of 1D tensor to 2D tensor)
|
||||
lengths = [len(cap) for cap in captions]
|
||||
targets = torch.zeros(len(captions), max(lengths)).long()
|
||||
for i, cap in enumerate(captions):
|
||||
end = lengths[i]
|
||||
targets[i, :end] = cap[:end]
|
||||
return images, targets, lengths, ids
|
||||
|
||||
|
||||
def get_loader(data_path, data_name, data_split, tokenizer, opt, batch_size=100,
|
||||
shuffle=True, num_workers=0, train=True):
|
||||
"""Returns torch.utils.data.DataLoader for custom coco dataset."""
|
||||
if train:
|
||||
drop_last = True
|
||||
else:
|
||||
drop_last = False
|
||||
if opt.precomp_enc_type == 'basic':
|
||||
dset = PrecompRegionDataset(data_path, data_name, data_split, tokenizer, opt, train)
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=num_workers,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
dset = RawImageDataset(data_path, data_name, data_split, tokenizer, opt, train)
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn)
|
||||
return data_loader
|
||||
|
||||
|
||||
def get_loaders(data_path, data_name, tokenizer, batch_size, workers, opt):
|
||||
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
|
||||
batch_size, True, workers)
|
||||
val_loader = get_loader(data_path, data_name, 'dev', tokenizer, opt,
|
||||
batch_size, False, workers, train=False)
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def get_train_loader(data_path, data_name, tokenizer, batch_size, workers, opt, shuffle):
|
||||
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
|
||||
batch_size, shuffle, workers)
|
||||
return train_loader
|
||||
|
||||
|
||||
def get_test_loader(split_name, data_name, tokenizer, batch_size, workers, opt):
|
||||
test_loader = get_loader(opt.data_path, data_name, split_name, tokenizer, opt,
|
||||
batch_size, False, workers, train=False)
|
||||
return test_loader
|
|
@ -0,0 +1,228 @@
|
|||
"""VSE modules"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from transformers import BertModel, BertConfig, BertLayer
|
||||
|
||||
from lib.modules.resnet import ResnetFeatureExtractor
|
||||
from lib.modules.aggr.gpo import GPO
|
||||
from lib.modules.mlp import MLP
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def l1norm(X, dim, eps=1e-8):
|
||||
"""L1-normalize columns of X
|
||||
"""
|
||||
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def l2norm(X, dim, eps=1e-8):
|
||||
"""L2-normalize columns of X
|
||||
"""
|
||||
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def maxk_pool1d_var(x, dim, k, lengths):
|
||||
results = list()
|
||||
lengths = list(lengths.cpu().numpy())
|
||||
lengths = [int(x) for x in lengths]
|
||||
for idx, length in enumerate(lengths):
|
||||
k = min(k, length)
|
||||
max_k_i = maxk(x[idx, :length, :], dim - 1, k).mean(dim - 1)
|
||||
results.append(max_k_i)
|
||||
results = torch.stack(results, dim=0)
|
||||
return results
|
||||
|
||||
|
||||
def maxk_pool1d(x, dim, k):
|
||||
max_k = maxk(x, dim, k)
|
||||
return max_k.mean(dim)
|
||||
|
||||
|
||||
def maxk(x, dim, k):
|
||||
index = x.topk(k, dim=dim)[1]
|
||||
return x.gather(dim, index)
|
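# e.g. maxk_pool1d(torch.tensor([[1., 3., 2., 5.]]), dim=1, k=2) -> tensor([4.])
# (the mean of the top-2 values 5 and 3); maxk_pool1d_var applies the same pooling
# per sample while ignoring padded positions beyond each length.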
||||
|
||||
|
||||
def get_text_encoder(embed_size, no_txtnorm=False):
|
||||
return EncoderText(embed_size, no_txtnorm=no_txtnorm)
|
||||
|
||||
class TransformerMapping(nn.Module):
|
||||
""" Self-attention layer for image branch
|
||||
"""
|
||||
def __init__(self, opt):
|
||||
super(TransformerMapping, self).__init__()
|
||||
self.opt = opt
|
||||
self.no_imgnorm = opt.no_imgnorm
|
||||
bert_config = BertConfig.from_json_file(opt.trans_cfg)
|
||||
self.layer = BertLayer(bert_config)
|
||||
self.mapping = nn.Linear(opt.img_dim, opt.embed_size)
|
||||
#self.mapping2 = nn.Linear(opt.final_dims, opt.final_dims)
|
||||
self.mlp = MLP(opt.img_dim, opt.embed_size // 2, opt.embed_size, 2)
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
def forward(self, image, image_lengths):
|
||||
# x: (batch_size, patch_num, img_dim)
|
||||
x = self.mapping(image) # x: (batch_size, patch_num, final_dims)
|
||||
attention_mask = torch.ones(x.size(0), x.size(1))
|
||||
if torch.cuda.is_available():
|
||||
attention_mask = attention_mask.cuda()
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.float()
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
hidden_states = self.layer(x, extended_attention_mask)
|
||||
# hidden_states = self.mapping2(hidden_states)
|
||||
#embed = torch.mean(hidden_states, 1) # (batch_size, final_dims)
|
||||
#codes = torch.nn.functional.normalize(embed, p=2, dim=1) # (N, C)
|
||||
#return codes
|
||||
#print(hidden_states)
|
||||
features = self.mlp(image) + hidden_states[0]
|
||||
features, pool_weights = self.gpool(features, image_lengths)
|
||||
#print(features.shape)
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
return features
|
||||
|
||||
|
||||
def get_image_encoder(data_name, img_dim, embed_size, precomp_enc_type='basic',
|
||||
backbone_source=None, backbone_path=None, no_imgnorm=False):
|
||||
"""A wrapper to image encoders. Chooses between an different encoders
|
||||
that uses precomputed image features.
|
||||
"""
|
||||
if precomp_enc_type == 'basic':
|
||||
img_enc = EncoderImageAggr(
|
||||
img_dim, embed_size, precomp_enc_type, no_imgnorm)
|
||||
elif precomp_enc_type == 'backbone':
|
||||
backbone_cnn = ResnetFeatureExtractor(backbone_source, backbone_path, fixed_blocks=2)
|
||||
img_enc = EncoderImageFull(backbone_cnn, img_dim, embed_size, precomp_enc_type, no_imgnorm)
|
||||
else:
|
||||
raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type))
|
||||
|
||||
return img_enc
|
||||
|
||||
|
||||
class EncoderImageAggr(nn.Module):
|
||||
def __init__(self, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
|
||||
super(EncoderImageAggr, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_imgnorm = no_imgnorm
|
||||
self.fc = nn.Linear(img_dim, embed_size)
|
||||
self.precomp_enc_type = precomp_enc_type
|
||||
if precomp_enc_type == 'basic':
|
||||
self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2)
|
||||
self.gpool = GPO(32, 32)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Xavier initialization for the fully connected layer
|
||||
"""
|
||||
r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
|
||||
self.fc.out_features)
|
||||
self.fc.weight.data.uniform_(-r, r)
|
||||
self.fc.bias.data.fill_(0)
|
||||
|
||||
def forward(self, images, image_lengths):
|
||||
"""Extract image feature vectors."""
|
||||
#print(images.shape)
|
||||
features = self.fc(images)
|
||||
#features = torch.mean(features, 1)
|
||||
if self.precomp_enc_type == 'basic':
|
||||
# When using pre-extracted region features, add an extra MLP for the embedding transformation
|
||||
features = self.mlp(images) + features
|
||||
|
||||
features, pool_weights = self.gpool(features, image_lengths)
|
||||
#print(features.shape)
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
|
||||
return features
|
||||
|
||||
class EncoderImageFull(nn.Module):
|
||||
def __init__(self, backbone_cnn, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
|
||||
super(EncoderImageFull, self).__init__()
|
||||
self.backbone = backbone_cnn
|
||||
self.image_encoder = EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm)
|
||||
self.backbone_freezed = False
|
||||
|
||||
def forward(self, images):
|
||||
"""Extract image feature vectors."""
|
||||
base_features = self.backbone(images)
|
||||
|
||||
if self.training:
|
||||
# Size Augmentation during training, randomly drop grids
|
||||
base_length = base_features.size(1)
|
||||
features = []
|
||||
feat_lengths = []
|
||||
rand_list_1 = np.random.rand(base_features.size(0), base_features.size(1))
|
||||
rand_list_2 = np.random.rand(base_features.size(0))
|
||||
for i in range(base_features.size(0)):
|
||||
if rand_list_2[i] > 0.2:
|
||||
feat_i = base_features[i][np.where(rand_list_1[i] > 0.20 * rand_list_2[i])]
|
||||
len_i = len(feat_i)
|
||||
pads_i = torch.zeros(base_length - len_i, base_features.size(-1)).to(base_features.device)
|
||||
feat_i = torch.cat([feat_i, pads_i], dim=0)
|
||||
else:
|
||||
feat_i = base_features[i]
|
||||
len_i = base_length
|
||||
feat_lengths.append(len_i)
|
||||
features.append(feat_i)
|
||||
base_features = torch.stack(features, dim=0)
|
||||
base_features = base_features[:, :max(feat_lengths), :]
|
||||
feat_lengths = torch.tensor(feat_lengths).to(base_features.device)
|
||||
else:
|
||||
feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device)
|
||||
feat_lengths[:] = base_features.size(1)
|
||||
|
||||
features = self.image_encoder(base_features, feat_lengths)
|
||||
|
||||
return features
|
||||
|
||||
def freeze_backbone(self):
|
||||
for param in self.backbone.parameters():
|
||||
param.requires_grad = False
|
||||
logger.info('Backbone frozen.')
|
||||
|
||||
def unfreeze_backbone(self, fixed_blocks):
|
||||
for param in self.backbone.parameters(): # open up all params first, then adjust the base parameters
|
||||
param.requires_grad = True
|
||||
self.backbone.set_fixed_blocks(fixed_blocks)
|
||||
self.backbone.unfreeze_base()
|
||||
logger.info('Backbone unfrozen, fixed blocks {}'.format(self.backbone.get_fixed_blocks()))
|
||||
|
||||
|
||||
# Language Model with BERT
|
||||
class EncoderText(nn.Module):
|
||||
def __init__(self, embed_size, no_txtnorm=False):
|
||||
super(EncoderText, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
self.bert = BertModel.from_pretrained('bert-base-uncased')
|
||||
self.linear = nn.Linear(768, embed_size)
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
bert_attention_mask = (x != 0).float()
|
||||
bert_emb = self.bert(x, bert_attention_mask)[0] # B x N x D
|
||||
cap_len = lengths
|
||||
|
||||
cap_emb = self.linear(bert_emb)
|
||||
|
||||
pooled_features, pool_weights = self.gpool(cap_emb, cap_len.to(cap_emb.device))
|
||||
#pooled_features = torch.mean(cap_emb, 1)
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
pooled_features = l2norm(pooled_features, dim=-1)
|
||||
|
||||
return pooled_features
|
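# Sketch of how the encoders are typically constructed (argument values are illustrative):
#   img_enc = get_image_encoder('coco_precomp', img_dim=2048, embed_size=1024, precomp_enc_type='basic')
#   txt_enc = get_text_encoder(1024)
#   img_emb = img_enc(images, image_lengths)   # (B, 1024)
#   cap_emb = txt_enc(captions, lengths)       # (B, 1024)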
|
@ -0,0 +1,478 @@
|
|||
"""Evaluation"""
|
||||
from __future__ import print_function
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from collections import OrderedDict
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from lib.datasets import image_caption
|
||||
from lib.vse import VSEModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=0):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / (.0001 + self.count)
|
||||
|
||||
def __str__(self):
|
||||
"""String representation for logging
|
||||
"""
|
||||
# for values that should be recorded exactly e.g. iteration number
|
||||
if self.count == 0:
|
||||
return str(self.val)
|
||||
# for stats
|
||||
return '%.4f (%.4f)' % (self.val, self.avg)
|
||||
|
||||
|
||||
class LogCollector(object):
|
||||
"""A collection of logging objects that can change from train to val"""
|
||||
|
||||
def __init__(self):
|
||||
# to keep the order of logged variables deterministic
|
||||
self.meters = OrderedDict()
|
||||
|
||||
def update(self, k, v, n=0):
|
||||
# create a new meter if previously not recorded
|
||||
if k not in self.meters:
|
||||
self.meters[k] = AverageMeter()
|
||||
self.meters[k].update(v, n)
|
||||
|
||||
def __str__(self):
|
||||
"""Concatenate the meters in one log line
|
||||
"""
|
||||
s = ''
|
||||
for i, (k, v) in enumerate(self.meters.items()):
|
||||
if i > 0:
|
||||
s += ' '
|
||||
s += k + ' ' + str(v)
|
||||
return s
|
||||
|
||||
def tb_log(self, tb_logger, prefix='', step=None):
|
||||
"""Log using tensorboard
|
||||
"""
|
||||
for k, v in self.meters.items():
|
||||
tb_logger.log_value(prefix + k, v.val, step=step)
|
||||
|
||||
|
||||
def encode_data(model, data_loader, log_step=10, logging=logger.info, backbone=False):
|
||||
"""Encode all images and captions loadable by `data_loader`
|
||||
"""
|
||||
batch_time = AverageMeter()
|
||||
val_logger = LogCollector()
|
||||
|
||||
# switch to evaluate mode
|
||||
model.val_start()
|
||||
|
||||
end = time.time()
|
||||
|
||||
# np array to keep all the embeddings
|
||||
img_embs = None
|
||||
cap_embs = None
|
||||
|
||||
for i, data_i in enumerate(data_loader):
|
||||
# make sure val logger is used
|
||||
if not backbone:
|
||||
images, image_lengths, captions, lengths, ids = data_i
|
||||
else:
|
||||
images, captions, lengths, ids = data_i
|
||||
model.logger = val_logger
|
||||
|
||||
# compute the embeddings
|
||||
if not backbone:
|
||||
img_emb, cap_emb = model.forward_emb(images, captions, lengths, image_lengths=image_lengths)
|
||||
else:
|
||||
img_emb, cap_emb = model.forward_emb(images, captions, lengths)
|
||||
|
||||
if img_embs is None:
|
||||
if img_emb.dim() == 3:
|
||||
img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
|
||||
else:
|
||||
img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
|
||||
cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1)))
|
||||
cap_lens = [0] * len(data_loader.dataset)
|
||||
# cache embeddings
|
||||
img_embs[ids] = img_emb.data.cpu().numpy().copy()
|
||||
cap_embs[ids, :] = cap_emb.data.cpu().numpy().copy()
|
||||
|
||||
# measure accuracy and record loss
|
||||
model.forward_loss(img_emb, cap_emb)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % log_step == 0:
|
||||
logging('Test: [{0}/{1}]\t'
|
||||
'{e_log}\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
.format(
|
||||
i, len(data_loader.dataset) // data_loader.batch_size + 1, batch_time=batch_time,
|
||||
e_log=str(model.logger)))
|
||||
del images, captions
|
||||
return img_embs, cap_embs
|
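# encode_data returns numpy arrays indexed by dataset position: both img_embs and cap_embs
# have one row per caption entry (each image is repeated im_div times, typically 5), so the
# evaluation below keeps every 5th row of img_embs before computing similarities.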
||||
|
||||
|
||||
def eval_ensemble(results_paths, fold5=False):
|
||||
all_sims = []
|
||||
all_npts = []
|
||||
for sim_path in results_paths:
|
||||
results = np.load(sim_path, allow_pickle=True).tolist()
|
||||
npts = results['npts']
|
||||
sims = results['sims']
|
||||
all_npts.append(npts)
|
||||
all_sims.append(sims)
|
||||
all_npts = np.array(all_npts)
|
||||
all_sims = np.array(all_sims)
|
||||
assert np.all(all_npts == all_npts[0])
|
||||
npts = int(all_npts[0])
|
||||
sims = all_sims.mean(axis=0)
|
||||
|
||||
if not fold5:
|
||||
r, rt = i2t(npts, sims, return_ranks=True)
|
||||
ri, rti = t2i(npts, sims, return_ranks=True)
|
||||
ar = (r[0] + r[1] + r[2]) / 3
|
||||
ari = (ri[0] + ri[1] + ri[2]) / 3
|
||||
rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
|
||||
logger.info("rsum: %.1f" % rsum)
|
||||
logger.info("Average i2t Recall: %.1f" % ar)
|
||||
logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
|
||||
logger.info("Average t2i Recall: %.1f" % ari)
|
||||
logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
|
||||
else:
|
||||
npts = npts // 5
|
||||
results = []
|
||||
all_sims = sims.copy()
|
||||
for i in range(5):
|
||||
sims = all_sims[i * npts: (i + 1) * npts, i * npts * 5: (i + 1) * npts * 5]
|
||||
r, rt0 = i2t(npts, sims, return_ranks=True)
|
||||
logger.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
|
||||
ri, rti0 = t2i(npts, sims, return_ranks=True)
|
||||
logger.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
|
||||
|
||||
if i == 0:
|
||||
rt, rti = rt0, rti0
|
||||
ar = (r[0] + r[1] + r[2]) / 3
|
||||
ari = (ri[0] + ri[1] + ri[2]) / 3
|
||||
rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
|
||||
logger.info("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
|
||||
results += [list(r) + list(ri) + [ar, ari, rsum]]
|
||||
logger.info("-----------------------------------")
|
||||
logger.info("Mean metrics: ")
|
||||
mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
|
||||
logger.info("rsum: %.1f" % (mean_metrics[12]))
|
||||
logger.info("Average i2t Recall: %.1f" % mean_metrics[10])
|
||||
logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" %
|
||||
mean_metrics[:5])
|
||||
logger.info("Average t2i Recall: %.1f" % mean_metrics[11])
|
||||
logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" %
|
||||
mean_metrics[5:10])
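
# eval_ensemble averages the similarity matrices written by evalrank(..., save_path=...)
# and re-runs the ranking on the mean matrix, so the individual models only need to agree
# on `npts`. A minimal sketch (the .npy paths are placeholders, not files shipped with
# this repo):
#
#   eval_ensemble(['runs/bcan_equal_sims.npy', 'runs/bcan_prob_sims.npy'], fold5=False)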
|
||||
|
||||
|
||||
def evalrank(model_path, data_path=None, split='dev', fold5=False, save_path=None, cxc=False):
|
||||
"""
|
||||
Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
|
||||
cross-validation is done (only for MSCOCO). Otherwise, the full data is
|
||||
used for evaluation.
|
||||
"""
|
||||
# load model and options
|
||||
checkpoint = torch.load(model_path)
|
||||
opt = checkpoint['opt']
|
||||
opt.workers = 5
|
||||
|
||||
logger.info(opt)
|
||||
if not hasattr(opt, 'caption_loss'):
|
||||
opt.caption_loss = False
|
||||
|
||||
# load vocabulary used by the model
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
vocab = tokenizer.vocab
|
||||
opt.vocab_size = len(vocab)
|
||||
|
||||
opt.backbone_path = '/tmp/data/weights/original_updown_backbone.pth'
|
||||
if data_path is not None:
|
||||
opt.data_path = data_path
|
||||
|
||||
# construct model
|
||||
model = VSEModel(opt)
|
||||
|
||||
model.make_data_parallel()
|
||||
# load model state
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
model.val_start()
|
||||
|
||||
logger.info('Loading dataset')
|
||||
data_loader = image_caption.get_test_loader(split, opt.data_name, tokenizer,
|
||||
opt.batch_size, opt.workers, opt)
|
||||
|
||||
logger.info('Computing results...')
|
||||
with torch.no_grad():
|
||||
if opt.precomp_enc_type == 'basic':
|
||||
img_embs, cap_embs = encode_data(model, data_loader)
|
||||
else:
|
||||
img_embs, cap_embs = encode_data(model, data_loader, backbone=True)
|
||||
logger.info('Images: %d, Captions: %d' %
            (img_embs.shape[0] // 5, cap_embs.shape[0]))
|
||||
|
||||
if cxc:
|
||||
eval_cxc(img_embs, cap_embs, data_path)
|
||||
else:
|
||||
if not fold5:
|
||||
# no cross-validation, full evaluation
|
||||
img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
|
||||
start = time.time()
|
||||
|
||||
sims = compute_sim(img_embs, cap_embs)
|
||||
npts = img_embs.shape[0]
|
||||
|
||||
if save_path is not None:
|
||||
np.save(save_path, {'npts': npts, 'sims': sims})
|
||||
logger.info('Saved the similarity matrix to {}'.format(save_path))
|
||||
|
||||
end = time.time()
|
||||
logger.info("calculate similarity time: {}".format(end - start))
|
||||
|
||||
r, rt = i2t(npts, sims, return_ranks=True)
|
||||
ri, rti = t2i(npts, sims, return_ranks=True)
|
||||
ar = (r[0] + r[1] + r[2]) / 3
|
||||
ari = (ri[0] + ri[1] + ri[2]) / 3
|
||||
rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
|
||||
logger.info("rsum: %.1f" % rsum)
|
||||
logger.info("Average i2t Recall: %.1f" % ar)
|
||||
logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
|
||||
logger.info("Average t2i Recall: %.1f" % ari)
|
||||
logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
|
||||
else:
|
||||
# 5fold cross-validation, only for MSCOCO
|
||||
results = []
|
||||
for i in range(5):
|
||||
img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
|
||||
cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
|
||||
start = time.time()
|
||||
sims = compute_sim(img_embs_shard, cap_embs_shard)
|
||||
end = time.time()
|
||||
logger.info("calculate similarity time: {}".format(end - start))
|
||||
|
||||
npts = img_embs_shard.shape[0]
|
||||
r, rt0 = i2t(npts, sims, return_ranks=True)
|
||||
logger.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
|
||||
ri, rti0 = t2i(npts, sims, return_ranks=True)
|
||||
logger.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
|
||||
|
||||
if i == 0:
|
||||
rt, rti = rt0, rti0
|
||||
ar = (r[0] + r[1] + r[2]) / 3
|
||||
ari = (ri[0] + ri[1] + ri[2]) / 3
|
||||
rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
|
||||
logger.info("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
|
||||
results += [list(r) + list(ri) + [ar, ari, rsum]]
|
||||
|
||||
logger.info("-----------------------------------")
|
||||
logger.info("Mean metrics: ")
|
||||
mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
|
||||
logger.info("rsum: %.1f" % (mean_metrics[12]))
|
||||
logger.info("Average i2t Recall: %.1f" % mean_metrics[10])
|
||||
logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" %
|
||||
mean_metrics[:5])
|
||||
logger.info("Average t2i Recall: %.1f" % mean_metrics[11])
|
||||
logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" %
|
||||
mean_metrics[5:10])
|
||||
|
||||
|
||||
def compute_sim(images, captions):
|
||||
similarities = np.matmul(images, captions.T)
|
||||
return similarities
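
# Note: compute_sim is a plain matrix product. Since the encoders L2-normalize their
# pooled outputs (see the l2norm calls in the encoder code), the dot product here is
# equivalent to cosine similarity. Shapes: images (N, D), captions (5N, D) -> sims (N, 5N).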
|
||||
|
||||
|
||||
def i2t(npts, sims, return_ranks=False, mode='coco'):
|
||||
"""
|
||||
Images->Text (Image Annotation)
|
||||
Images: (N, n_region, d) matrix of images
|
||||
Captions: (5N, max_n_word, d) matrix of captions
|
||||
CapLens: (5N) array of caption lengths
|
||||
sims: (N, 5N) matrix of similarity im-cap
|
||||
"""
|
||||
ranks = np.zeros(npts)
|
||||
top1 = np.zeros(npts)
|
||||
for index in range(npts):
|
||||
inds = np.argsort(sims[index])[::-1]
|
||||
if mode == 'coco':
|
||||
rank = 1e20
|
||||
for i in range(5 * index, 5 * index + 5, 1):
|
||||
tmp = np.where(inds == i)[0][0]
|
||||
if tmp < rank:
|
||||
rank = tmp
|
||||
ranks[index] = rank
|
||||
top1[index] = inds[0]
|
||||
else:
|
||||
rank = np.where(inds == index)[0][0]
|
||||
ranks[index] = rank
|
||||
top1[index] = inds[0]
|
||||
|
||||
# Compute metrics
|
||||
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
||||
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
||||
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
||||
medr = np.floor(np.median(ranks)) + 1
|
||||
meanr = ranks.mean() + 1
|
||||
|
||||
if return_ranks:
|
||||
return (r1, r5, r10, medr, meanr), (ranks, top1)
|
||||
else:
|
||||
return (r1, r5, r10, medr, meanr)
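
# Reading the i2t numbers: for each image, `rank` is the best (lowest, 0-based) position
# of any of its 5 ground-truth captions in the similarity-sorted list, R@K is the share
# of images whose best caption lands in the top K, and medr/meanr are the 1-based
# median/mean rank. E.g. ranks = [0, 3, 12] gives R@1 = 33.3, R@5 = 66.7, R@10 = 66.7.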
|
||||
|
||||
|
||||
def t2i(npts, sims, return_ranks=False, mode='coco'):
|
||||
"""
|
||||
Text->Images (Image Search)
|
||||
Images: (N, n_region, d) matrix of images
|
||||
Captions: (5N, max_n_word, d) matrix of captions
|
||||
CapLens: (5N) array of caption lengths
|
||||
sims: (N, 5N) matrix of similarity im-cap
|
||||
"""
|
||||
# npts = images.shape[0]
|
||||
|
||||
if mode == 'coco':
|
||||
ranks = np.zeros(5 * npts)
|
||||
top1 = np.zeros(5 * npts)
|
||||
else:
|
||||
ranks = np.zeros(npts)
|
||||
top1 = np.zeros(npts)
|
||||
|
||||
# --> (5N(caption), N(image))
|
||||
sims = sims.T
|
||||
|
||||
for index in range(npts):
|
||||
if mode == 'coco':
|
||||
for i in range(5):
|
||||
inds = np.argsort(sims[5 * index + i])[::-1]
|
||||
ranks[5 * index + i] = np.where(inds == index)[0][0]
|
||||
top1[5 * index + i] = inds[0]
|
||||
else:
|
||||
inds = np.argsort(sims[index])[::-1]
|
||||
ranks[index] = np.where(inds == index)[0][0]
|
||||
top1[index] = inds[0]
|
||||
|
||||
# Compute metrics
|
||||
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
||||
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
||||
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
||||
medr = np.floor(np.median(ranks)) + 1
|
||||
meanr = ranks.mean() + 1
|
||||
if return_ranks:
|
||||
return (r1, r5, r10, medr, meanr), (ranks, top1)
|
||||
else:
|
||||
return (r1, r5, r10, medr, meanr)
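
# t2i mirrors i2t on the transposed matrix: each of the 5N captions queries the N images
# and its rank is the position of its source image, using the 5-captions-per-image
# convention of MSCOCO/Flickr30K (caption 5*i + j belongs to image i).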
|
||||
|
||||
|
||||
"""
|
||||
CxC related evaluation.
|
||||
"""
|
||||
|
||||
def eval_cxc(images, captions, data_path):
|
||||
import os
|
||||
import json
|
||||
cxc_annot_base = os.path.join(data_path, 'cxc_annots')
|
||||
img_id_path = os.path.join(cxc_annot_base, 'testall_ids.txt')
|
||||
cap_id_path = os.path.join(cxc_annot_base, 'testall_capids.txt')
|
||||
|
||||
images = images[::5, :]
|
||||
|
||||
with open(img_id_path) as f:
|
||||
img_ids = f.readlines()
|
||||
with open(cap_id_path) as f:
|
||||
cap_ids = f.readlines()
|
||||
|
||||
img_ids = [img_id.strip() for i, img_id in enumerate(img_ids) if i % 5 == 0]
|
||||
cap_ids = [cap_id.strip() for cap_id in cap_ids]
|
||||
|
||||
with open(os.path.join(cxc_annot_base, 'cxc_it.json')) as f_it:
|
||||
cxc_it = json.load(f_it)
|
||||
with open(os.path.join(cxc_annot_base, 'cxc_i2i.json')) as f_i2i:
|
||||
cxc_i2i = json.load(f_i2i)
|
||||
with open(os.path.join(cxc_annot_base, 'cxc_t2t.json')) as f_t2t:
|
||||
cxc_t2t = json.load(f_t2t)
|
||||
|
||||
sims = compute_sim(images, captions)
|
||||
t2i_recalls = cxc_inter(sims.T, img_ids, cap_ids, cxc_it['t2i'])
|
||||
i2t_recalls = cxc_inter(sims, cap_ids, img_ids, cxc_it['i2t'])
|
||||
logger.info('T2I R@1: {}, R@5: {}, R@10: {}'.format(*t2i_recalls))
|
||||
logger.info('I2T R@1: {}, R@5: {}, R@10: {}'.format(*i2t_recalls))
|
||||
|
||||
i2i_recalls = cxc_intra(images, img_ids, cxc_i2i)
|
||||
t2t_recalls = cxc_intra(captions, cap_ids, cxc_t2t, text=True)
|
||||
logger.info('I2I R@1: {}, R@5: {}, R@10: {}'.format(*i2i_recalls))
|
||||
logger.info('T2T R@1: {}, R@5: {}, R@10: {}'.format(*t2t_recalls))
|
||||
|
||||
|
||||
def cxc_inter(sims, data_ids, query_ids, annot):
|
||||
ranks = list()
|
||||
for idx, query_id in enumerate(query_ids):
|
||||
if query_id not in annot:
|
||||
raise ValueError('unexpected query id {}'.format(query_id))
|
||||
pos_data_ids = annot[query_id]
|
||||
pos_data_ids = [pos_data_id for pos_data_id in pos_data_ids if str(pos_data_id[0]) in data_ids]
|
||||
pos_data_indices = [data_ids.index(str(pos_data_id[0])) for pos_data_id in pos_data_ids]
|
||||
rank = 1e20
|
||||
inds = np.argsort(sims[idx])[::-1]
|
||||
for pos_data_idx in pos_data_indices:
|
||||
tmp = np.where(inds == pos_data_idx)[0][0]
|
||||
if tmp < rank:
|
||||
rank = tmp
|
||||
ranks.append(rank)
|
||||
ranks = np.array(ranks)
|
||||
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
||||
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
||||
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
||||
return (r1, r5, r10)
|
||||
|
||||
|
||||
def cxc_intra(embs, data_ids, annot, text=False):
|
||||
pos_thresh = 3.0 if text else 2.5 # threshold for positive pairs according to the CxC paper
|
||||
|
||||
sims = compute_sim(embs, embs)
|
||||
np.fill_diagonal(sims, 0)
|
||||
|
||||
ranks = list()
|
||||
for idx, data_id in enumerate(data_ids):
|
||||
sim_items = annot[data_id]
|
||||
pos_items = [item for item in sim_items if item[1] >= pos_thresh]
|
||||
rank = 1e20
|
||||
inds = np.argsort(sims[idx])[::-1]
|
||||
if text:
|
||||
coco_pos = list(range(idx // 5 * 5, (idx // 5 + 1) * 5))
|
||||
coco_pos.remove(idx)
|
||||
pos_indices = coco_pos
|
||||
pos_indices.extend([data_ids.index(str(pos_item[0])) for pos_item in pos_items])
|
||||
else:
|
||||
pos_indices = [data_ids.index(str(pos_item[0])) for pos_item in pos_items]
|
||||
if len(pos_indices) == 0:  # skip it since there is no positive example in the annotation
|
||||
continue
|
||||
for pos_idx in pos_indices:
|
||||
tmp = np.where(inds == pos_idx)[0][0]
|
||||
if tmp < rank:
|
||||
rank = tmp
|
||||
ranks.append(rank)
|
||||
|
||||
ranks = np.array(ranks)
|
||||
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
||||
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
||||
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
||||
return (r1, r5, r10)
|
|
@ -0,0 +1,58 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
"""
|
||||
Compute contrastive loss (max-margin based)
|
||||
"""
|
||||
|
||||
def __init__(self, opt, margin=0, max_violation=False):
|
||||
super(ContrastiveLoss, self).__init__()
|
||||
self.opt = opt
|
||||
self.margin = margin
|
||||
self.max_violation = max_violation
|
||||
|
||||
def max_violation_on(self):
|
||||
self.max_violation = True
|
||||
print('Use VSE++ objective.')
|
||||
|
||||
def max_violation_off(self):
|
||||
self.max_violation = False
|
||||
print('Use VSE0 objective.')
|
||||
|
||||
def forward(self, im, s):
|
||||
# compute image-sentence score matrix
|
||||
scores = get_sim(im, s)
|
||||
diagonal = scores.diag().view(im.size(0), 1)
|
||||
d1 = diagonal.expand_as(scores)
|
||||
d2 = diagonal.t().expand_as(scores)
|
||||
|
||||
# compare every diagonal score to scores in its column
|
||||
# caption retrieval
|
||||
cost_s = (self.margin + scores - d1).clamp(min=0)
|
||||
# compare every diagonal score to scores in its row
|
||||
# image retrieval
|
||||
cost_im = (self.margin + scores - d2).clamp(min=0)
|
||||
|
||||
# clear diagonals
|
||||
mask = torch.eye(scores.size(0)) > .5
|
||||
I = Variable(mask)
|
||||
if torch.cuda.is_available():
|
||||
I = I.cuda()
|
||||
cost_s = cost_s.masked_fill_(I, 0)
|
||||
cost_im = cost_im.masked_fill_(I, 0)
|
||||
|
||||
# keep the maximum violating negative for each query
|
||||
if self.max_violation:
|
||||
cost_s = cost_s.max(1)[0]
|
||||
cost_im = cost_im.max(0)[0]
|
||||
|
||||
return cost_s.sum() + cost_im.sum()
|
||||
|
||||
|
||||
def get_sim(images, captions):
|
||||
similarities = images.mm(captions.t())
|
||||
return similarities
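

if __name__ == '__main__':
    # Minimal smoke test (illustrative only, not part of the original repo). The `opt`
    # namespace below is an assumption that mirrors the margin/max_violation arguments
    # used by the training script.
    from argparse import Namespace

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    opt = Namespace(margin=0.2, max_violation=True)
    criterion = ContrastiveLoss(opt, margin=opt.margin, max_violation=opt.max_violation)
    im = torch.randn(8, 1024, device=device)   # fake image embeddings
    s = torch.randn(8, 1024, device=device)    # fake caption embeddings
    # Scalar: sum of hinge violations over the hardest negatives in the batch.
    print(criterion(im, s))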
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
# coding=utf-8
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
|
||||
|
||||
def positional_encoding_1d(d_model, length):
|
||||
"""
|
||||
:param d_model: dimension of the model
|
||||
:param length: length of positions
|
||||
:return: length*d_model position matrix
|
||||
"""
|
||||
if d_model % 2 != 0:
|
||||
raise ValueError("Cannot use sin/cos positional encoding with "
|
||||
"odd dim (got dim={:d})".format(d_model))
|
||||
pe = torch.zeros(length, d_model)
|
||||
position = torch.arange(0, length).unsqueeze(1)
|
||||
div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
|
||||
-(math.log(10000.0) / d_model)))
|
||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||
|
||||
return pe
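
# The code above follows the standard Transformer positional encoding:
#   pe[p, 2i]   = sin(p / 10000^(2i / d_model))
#   pe[p, 2i+1] = cos(p / 10000^(2i / d_model))
# so every position p receives a distinct d_model-dimensional fingerprint for the GRU
# in GPO to read.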
|
||||
|
||||
|
||||
class GPO(nn.Module):
|
||||
def __init__(self, d_pe, d_hidden):
|
||||
super(GPO, self).__init__()
|
||||
self.d_pe = d_pe
|
||||
self.d_hidden = d_hidden
|
||||
|
||||
self.pe_database = {}
|
||||
self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
|
||||
self.linear = nn.Linear(self.d_hidden, 1, bias=False)
|
||||
|
||||
def compute_pool_weights(self, lengths, features):
|
||||
max_len = int(lengths.max())
|
||||
pe_max_len = self.get_pe(max_len)
|
||||
pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
|
||||
mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
|
||||
mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
|
||||
pes = pes.masked_fill(mask == 0, 0)
|
||||
|
||||
self.gru.flatten_parameters()
|
||||
packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
|
||||
out, _ = self.gru(packed)
|
||||
padded = pad_packed_sequence(out, batch_first=True)
|
||||
out_emb, out_len = padded
|
||||
out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2
|
||||
scores = self.linear(out_emb)
|
||||
scores[torch.where(mask == 0)] = -10000
|
||||
|
||||
weights = torch.softmax(scores / 0.1, 1)
|
||||
return weights, mask
|
||||
|
||||
def forward(self, features, lengths):
|
||||
"""
|
||||
:param features: features with shape B x K x D
|
||||
:param lengths: B x 1, specify the length of each data sample.
|
||||
:return: pooled feature with shape B x D
|
||||
"""
|
||||
pool_weights, mask = self.compute_pool_weights(lengths, features)
|
||||
|
||||
features = features[:, :int(lengths.max()), :]
|
||||
sorted_features = features.masked_fill(mask == 0, -10000)
|
||||
sorted_features = sorted_features.sort(dim=1, descending=True)[0]
|
||||
sorted_features = sorted_features.masked_fill(mask == 0, 0)
|
||||
|
||||
pooled_features = (sorted_features * pool_weights).sum(1)
|
||||
return pooled_features, pool_weights
|
||||
|
||||
def get_pe(self, length):
|
||||
"""
|
||||
|
||||
:param length: the length of the sequence
|
||||
:return: the positional encoding of the given length
|
||||
"""
|
||||
length = int(length)
|
||||
if length in self.pe_database:
|
||||
return self.pe_database[length]
|
||||
else:
|
||||
pe = positional_encoding_1d(self.d_pe, length)
|
||||
self.pe_database[length] = pe
|
||||
return pe
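

if __name__ == '__main__':
    # Quick shape check for GPO (illustrative only, not part of the original repo).
    gpo = GPO(d_pe=32, d_hidden=32)
    feats = torch.randn(4, 20, 1024)            # B x K x D feature set
    lens = torch.tensor([20., 18., 15., 9.])    # valid length of each sample
    pooled, weights = gpo(feats, lens)
    # Expected: pooled -> torch.Size([4, 1024]), weights -> torch.Size([4, 20, 1])
    print(pooled.shape, weights.shape)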
|
|
@ -0,0 +1,47 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TwoLayerMLP(nn.Module):
|
||||
def __init__(self, num_features, hid_dim, out_dim, return_hidden=False):
|
||||
super().__init__()
|
||||
self.return_hidden = return_hidden
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(num_features, hid_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hid_dim, out_dim),
|
||||
)
|
||||
|
||||
for m in self.model:
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
if not self.return_hidden:
|
||||
return self.model(x)
|
||||
else:
|
||||
hid_feat = self.model[:2](x)
|
||||
results = self.model[2:](hid_feat)
|
||||
return hid_feat, results
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
""" Very simple multi-layer perceptron (also called FFN)"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in h + [output_dim])
|
||||
|
||||
def forward(self, x):
|
||||
B, N, D = x.size()
|
||||
x = x.reshape(B*N, D)
|
||||
for i, (bn, layer) in enumerate(zip(self.bns, self.layers)):
|
||||
x = F.relu(bn(layer(x))) if i < self.num_layers - 1 else layer(x)
|
||||
x = x.view(B, N, self.output_dim)
|
||||
return x
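
# Note: MLP flattens the (B, N, D) input to (B*N, D) so that BatchNorm1d normalizes over
# every region in the batch, then restores the (B, N, output_dim) shape before returning.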
|
||||
|
|
@ -0,0 +1,310 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
|
||||
|
||||
model_urls = {
|
||||
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers, width_mult, num_classes=1000):
|
||||
self.inplanes = 64 * width_mult
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
|
||||
self.layer1 = self._make_layer(block, 64 * width_mult, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128 * width_mult, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256 * width_mult, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512 * width_mult, layers[3], stride=2)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(7)
|
||||
self.fc = nn.Linear(512 * block.expansion * width_mult, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def resnet50(pretrained=False, width_mult=1):
|
||||
"""Constructs a ResNet-50 model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], width_mult)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet101(pretrained=False, width_mult=1):
|
||||
"""Constructs a ResNet-101 model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], width_mult)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet152(pretrained=False, width_mult=1):
|
||||
"""Constructs a ResNet-152 model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 8, 36, 3], width_mult)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
|
||||
return model
|
||||
|
||||
|
||||
class ResnetFeatureExtractor(nn.Module):
|
||||
def __init__(self, backbone_source, weights_path, pooling_size=7, fixed_blocks=2):
|
||||
super(ResnetFeatureExtractor, self).__init__()
|
||||
self.backbone_source = backbone_source
|
||||
self.weights_path = weights_path
|
||||
self.pooling_size = pooling_size
|
||||
self.fixed_blocks = fixed_blocks
|
||||
|
||||
if 'detector' in self.backbone_source:
|
||||
self.resnet = resnet101()
|
||||
elif self.backbone_source == 'imagenet':
|
||||
self.resnet = resnet101(pretrained=True)
|
||||
elif self.backbone_source == 'imagenet_res50':
|
||||
self.resnet = resnet50(pretrained=True)
|
||||
elif self.backbone_source == 'imagenet_res152':
|
||||
self.resnet = resnet152(pretrained=True)
|
||||
elif self.backbone_source == 'imagenet_resnext':
|
||||
self.resnet = torch.hub.load('pytorch/vision:v0.4.2', 'resnext101_32x8d', pretrained=True)
|
||||
elif 'wsl' in self.backbone_source:
|
||||
self.resnet = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
|
||||
else:
|
||||
raise ValueError('Unknown backbone source {}'.format(self.backbone_source))
|
||||
|
||||
self._init_modules()
|
||||
|
||||
def _init_modules(self):
|
||||
# Build resnet.
|
||||
self.base = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu,
|
||||
self.resnet.maxpool, self.resnet.layer1, self.resnet.layer2, self.resnet.layer3)
|
||||
self.top = nn.Sequential(self.resnet.layer4)
|
||||
|
||||
if self.weights_path != '':
|
||||
if 'detector' in self.backbone_source:
|
||||
if os.path.exists(self.weights_path):
|
||||
logger.info(
|
||||
'Loading pretrained backbone weights from {} for backbone source {}'.format(self.weights_path,
|
||||
self.backbone_source))
|
||||
backbone_ckpt = torch.load(self.weights_path)
|
||||
self.base.load_state_dict(backbone_ckpt['base'])
|
||||
self.top.load_state_dict(backbone_ckpt['top'])
|
||||
else:
|
||||
raise ValueError('Could not find weights for backbone CNN at {}'.format(self.weights_path))
|
||||
else:
|
||||
logger.info('Did not load external checkpoints')
|
||||
self.unfreeze_base()
|
||||
|
||||
def set_fixed_blocks(self, fixed_blocks):
|
||||
self.fixed_blocks = fixed_blocks
|
||||
|
||||
def get_fixed_blocks(self):
|
||||
return self.fixed_blocks
|
||||
|
||||
def unfreeze_base(self):
|
||||
assert (0 <= self.fixed_blocks < 4)
|
||||
if self.fixed_blocks == 3:
|
||||
for p in self.base[6].parameters(): p.requires_grad = False
|
||||
for p in self.base[5].parameters(): p.requires_grad = False
|
||||
for p in self.base[4].parameters(): p.requires_grad = False
|
||||
for p in self.base[0].parameters(): p.requires_grad = False
|
||||
for p in self.base[1].parameters(): p.requires_grad = False
|
||||
if self.fixed_blocks == 2:
|
||||
for p in self.base[6].parameters(): p.requires_grad = True
|
||||
for p in self.base[5].parameters(): p.requires_grad = False
|
||||
for p in self.base[4].parameters(): p.requires_grad = False
|
||||
for p in self.base[0].parameters(): p.requires_grad = False
|
||||
for p in self.base[1].parameters(): p.requires_grad = False
|
||||
if self.fixed_blocks == 1:
|
||||
for p in self.base[6].parameters(): p.requires_grad = True
|
||||
for p in self.base[5].parameters(): p.requires_grad = True
|
||||
for p in self.base[4].parameters(): p.requires_grad = False
|
||||
for p in self.base[0].parameters(): p.requires_grad = False
|
||||
for p in self.base[1].parameters(): p.requires_grad = False
|
||||
if self.fixed_blocks == 0:
|
||||
for p in self.base[6].parameters(): p.requires_grad = True
|
||||
for p in self.base[5].parameters(): p.requires_grad = True
|
||||
for p in self.base[4].parameters(): p.requires_grad = True
|
||||
for p in self.base[0].parameters(): p.requires_grad = True
|
||||
for p in self.base[1].parameters(): p.requires_grad = True
|
||||
|
||||
logger.info('Resnet backbone now has fixed blocks {}'.format(self.fixed_blocks))
|
||||
|
||||
def freeze_base(self):
|
||||
for p in self.base.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
# Override train so that the training mode is set as we want (BN does not update the running stats)
|
||||
nn.Module.train(self, mode)
|
||||
if mode:
|
||||
# fix all bn layers
|
||||
def set_bn_eval(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('BatchNorm') != -1:
|
||||
m.eval()
|
||||
|
||||
self.base.apply(set_bn_eval)
|
||||
self.top.apply(set_bn_eval)
|
||||
|
||||
def _head_to_tail(self, pool5):
|
||||
fc7 = self.top(pool5).mean(3).mean(2)
|
||||
return fc7
|
||||
|
||||
def forward(self, im_data):
|
||||
b_s = im_data.size(0)
|
||||
base_feat = self.base(im_data)
|
||||
top_feat = self.top(base_feat)
|
||||
features = top_feat.view(b_s, top_feat.size(1), -1).permute(0, 2, 1)
|
||||
return features
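
# The extractor returns a grid of region features: for a (B, 3, H, W) batch, the layer-4
# map of shape (B, 2048, roughly H/32, W/32) is flattened and permuted to
# (B, (H/32)*(W/32), 2048), i.e. one 2048-d descriptor per spatial location.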
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
|
||||
def count_params(model):
|
||||
model_parameters = model.parameters()
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
model = resnet50(pretrained=False, width_mult=1)
|
||||
num_params = count_params(model)
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
"""VSE model"""
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from lib.encoders import get_image_encoder, get_text_encoder, TransformerMapping
|
||||
from lib.loss import ContrastiveLoss
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VSEModel(object):
|
||||
"""
|
||||
The standard VSE model
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
# Build Models
|
||||
self.grad_clip = opt.grad_clip
|
||||
#self.img_enc = TransformerMapping(opt)
|
||||
self.img_enc = get_image_encoder(opt.data_name, opt.img_dim, opt.embed_size,
|
||||
precomp_enc_type=opt.precomp_enc_type,
|
||||
backbone_source=opt.backbone_source,
|
||||
backbone_path=opt.backbone_path,
|
||||
no_imgnorm=opt.no_imgnorm)
|
||||
self.txt_enc = get_text_encoder(opt.embed_size, no_txtnorm=opt.no_txtnorm)
|
||||
if torch.cuda.is_available():
|
||||
self.img_enc.cuda()
|
||||
self.txt_enc.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
# Loss and Optimizer
|
||||
self.criterion = ContrastiveLoss(opt=opt,
|
||||
margin=opt.margin,
|
||||
max_violation=opt.max_violation)
|
||||
|
||||
params = list(self.txt_enc.parameters())
|
||||
params += list(self.img_enc.parameters())
|
||||
|
||||
self.params = params
|
||||
self.opt = opt
|
||||
|
||||
# Set up the lr for different parts of the VSE model
|
||||
decay_factor = 1e-4
|
||||
if opt.precomp_enc_type == 'basic':
|
||||
if self.opt.optim == 'adam':
|
||||
all_text_params = list(self.txt_enc.parameters())
|
||||
bert_params = list(self.txt_enc.bert.parameters())
|
||||
bert_params_ptr = [p.data_ptr() for p in bert_params]
|
||||
text_params_no_bert = list()
|
||||
for p in all_text_params:
|
||||
if p.data_ptr() not in bert_params_ptr:
|
||||
text_params_no_bert.append(p)
|
||||
self.optimizer = torch.optim.AdamW([
|
||||
{'params': text_params_no_bert, 'lr': opt.learning_rate},
|
||||
{'params': bert_params, 'lr': opt.learning_rate * 0.1},
|
||||
{'params': self.img_enc.parameters(), 'lr': opt.learning_rate},
|
||||
],
|
||||
lr=opt.learning_rate, weight_decay=decay_factor)
|
||||
elif self.opt.optim == 'sgd':
|
||||
self.optimizer = torch.optim.SGD(self.params, lr=opt.learning_rate, momentum=0.9)
|
||||
else:
|
||||
raise ValueError('Invalid optim option {}'.format(self.opt.optim))
|
||||
else:
|
||||
if self.opt.optim == 'adam':
|
||||
all_text_params = list(self.txt_enc.parameters())
|
||||
bert_params = list(self.txt_enc.bert.parameters())
|
||||
bert_params_ptr = [p.data_ptr() for p in bert_params]
|
||||
text_params_no_bert = list()
|
||||
for p in all_text_params:
|
||||
if p.data_ptr() not in bert_params_ptr:
|
||||
text_params_no_bert.append(p)
|
||||
self.optimizer = torch.optim.AdamW([
|
||||
{'params': text_params_no_bert, 'lr': opt.learning_rate},
|
||||
{'params': bert_params, 'lr': opt.learning_rate * 0.1},
|
||||
{'params': self.img_enc.backbone.top.parameters(),
|
||||
'lr': opt.learning_rate * opt.backbone_lr_factor, },
|
||||
{'params': self.img_enc.backbone.base.parameters(),
|
||||
'lr': opt.learning_rate * opt.backbone_lr_factor, },
|
||||
{'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
|
||||
], lr=opt.learning_rate, weight_decay=decay_factor)
|
||||
elif self.opt.optim == 'sgd':
|
||||
self.optimizer = torch.optim.SGD([
|
||||
{'params': self.txt_enc.parameters(), 'lr': opt.learning_rate},
|
||||
{'params': self.img_enc.backbone.parameters(), 'lr': opt.learning_rate * opt.backbone_lr_factor,
|
||||
'weight_decay': decay_factor},
|
||||
{'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
|
||||
], lr=opt.learning_rate, momentum=0.9, nesterov=True)
|
||||
else:
|
||||
raise ValueError('Invalid optim option {}'.format(self.opt.optim))
|
||||
|
||||
logger.info('Use {} as the optimizer, with init lr {}'.format(self.opt.optim, opt.learning_rate))
|
||||
|
||||
self.Eiters = 0
|
||||
self.data_parallel = False
|
||||
|
||||
def set_max_violation(self, max_violation):
|
||||
if max_violation:
|
||||
self.criterion.max_violation_on()
|
||||
else:
|
||||
self.criterion.max_violation_off()
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict()]
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.img_enc.load_state_dict(state_dict[0], strict=False)
|
||||
self.txt_enc.load_state_dict(state_dict[1], strict=False)
|
||||
|
||||
def train_start(self):
|
||||
"""switch to train mode
|
||||
"""
|
||||
self.img_enc.train()
|
||||
self.txt_enc.train()
|
||||
|
||||
def val_start(self):
|
||||
"""switch to evaluate mode
|
||||
"""
|
||||
self.img_enc.eval()
|
||||
self.txt_enc.eval()
|
||||
|
||||
def freeze_backbone(self):
|
||||
if 'backbone' in self.opt.precomp_enc_type:
|
||||
if isinstance(self.img_enc, nn.DataParallel):
|
||||
self.img_enc.module.freeze_backbone()
|
||||
else:
|
||||
self.img_enc.freeze_backbone()
|
||||
|
||||
def unfreeze_backbone(self, fixed_blocks):
|
||||
if 'backbone' in self.opt.precomp_enc_type:
|
||||
if isinstance(self.img_enc, nn.DataParallel):
|
||||
self.img_enc.module.unfreeze_backbone(fixed_blocks)
|
||||
else:
|
||||
self.img_enc.unfreeze_backbone(fixed_blocks)
|
||||
|
||||
def make_data_parallel(self):
|
||||
self.img_enc = nn.DataParallel(self.img_enc)
|
||||
self.txt_enc = nn.DataParallel(self.txt_enc)
|
||||
self.data_parallel = True
|
||||
logger.info('Image and text encoders are now wrapped in DataParallel.')
|
||||
|
||||
@property
|
||||
def is_data_parallel(self):
|
||||
return self.data_parallel
|
||||
|
||||
def forward_emb(self, images, captions, lengths, image_lengths=None):
|
||||
"""Compute the image and caption embeddings
|
||||
"""
|
||||
# Set mini-batch dataset
|
||||
if self.opt.precomp_enc_type == 'basic':
|
||||
if torch.cuda.is_available():
|
||||
images = images.cuda()
|
||||
captions = captions.cuda()
|
||||
image_lengths = image_lengths.cuda()
|
||||
img_emb = self.img_enc(images, image_lengths)
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
images = images.cuda()
|
||||
captions = captions.cuda()
|
||||
img_emb = self.img_enc(images)
|
||||
|
||||
#lengths = torch.Tensor(lengths).cuda()
|
||||
lengths = torch.tensor(lengths).cuda()
|
||||
cap_emb = self.txt_enc(captions, lengths)
|
||||
return img_emb, cap_emb
|
||||
|
||||
def forward_loss(self, img_emb, cap_emb):
|
||||
"""Compute the loss given pairs of image and caption embeddings
|
||||
"""
|
||||
loss = self.criterion(img_emb, cap_emb)
|
||||
#self.logger.update('Le', loss.data.item(), img_emb.size(0))
|
||||
return loss
|
||||
|
||||
def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None):
|
||||
"""One training step given images and captions.
|
||||
"""
|
||||
# self.Eiters += 1
|
||||
# self.logger.update('Eit', self.Eiters)
|
||||
# self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
|
||||
|
||||
# compute the embeddings
|
||||
img_emb, cap_emb = self.forward_emb(images, captions, lengths, image_lengths=image_lengths)
|
||||
|
||||
# measure accuracy and record loss
|
||||
#self.optimizer.zero_grad()
|
||||
loss = self.forward_loss(img_emb, cap_emb)
|
||||
|
||||
if warmup_alpha is not None:
|
||||
loss = loss * warmup_alpha
|
||||
|
||||
# compute gradient and update
|
||||
# loss.backward()
|
||||
# if self.grad_clip > 0:
|
||||
# clip_grad_norm_(self.params, self.grad_clip)
|
||||
# self.optimizer.step()
|
||||
# return loss.data.item()
|
||||
return loss
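
# train_emb only returns the (optionally warmup-scaled) loss here; the backward pass and
# optimizer step are expected to live in the training loop. A sketch, assuming the batch
# variables produced by the data loader (names are illustrative):
#
#   loss = model.train_emb(images, captions, lengths, image_lengths=image_lengths)
#   model.optimizer.zero_grad()
#   loss.backward()
#   if model.grad_clip > 0:
#       clip_grad_norm_(model.params, model.grad_clip)
#   model.optimizer.step()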
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TwoLayerMLP(nn.Module):
|
||||
def __init__(self, num_features, hid_dim, out_dim, return_hidden=False):
|
||||
super().__init__()
|
||||
self.return_hidden = return_hidden
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(num_features, hid_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hid_dim, out_dim),
|
||||
)
|
||||
|
||||
for m in self.model:
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
for p in self.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
if not self.return_hidden:
|
||||
return self.model(x)
|
||||
else:
|
||||
hid_feat = self.model[:2](x)
|
||||
results = self.model[2:](hid_feat)
|
||||
return hid_feat, results
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
""" Very simple multi-layer perceptron (also called FFN)"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in h + [output_dim])
|
||||
|
||||
# for p in self.parameters():
|
||||
# p.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
B, N, D = x.size()
|
||||
x = x.reshape(B*N, D)
|
||||
for i, (bn, layer) in enumerate(zip(self.bns, self.layers)):
|
||||
x = F.relu(bn(layer(x))) if i < self.num_layers - 1 else layer(x)
|
||||
x = x.view(B, N, self.output_dim)
|
||||
return x
|
||||
|
|
@ -0,0 +1,904 @@
|
|||
# -----------------------------------------------------------
|
||||
# "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network"
|
||||
# Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu*
|
||||
#
|
||||
# ---------------------------------------------------------------
|
||||
"""BCAN model"""
|
||||
import copy
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init
|
||||
import torchtext
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer, BertModel, BertConfig
|
||||
|
||||
from gpo import GPO
|
||||
from mlp import MLP
|
||||
|
||||
|
||||
def l1norm(X, dim, eps=1e-8):
|
||||
"""L1-normalize columns of X
|
||||
"""
|
||||
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def l2norm(X, dim, eps=1e-8):
|
||||
"""L2-normalize columns of X
|
||||
"""
|
||||
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True)
|
||||
norm = torch.sqrt(norm + eps) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
|
||||
"""Returns cosine similarity between x1 and x2, computed along dim."""
|
||||
w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
|
||||
w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
|
||||
w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
|
||||
if keep_dim:
|
||||
return w12 / (w1 * w2).clamp(min=eps)
|
||||
else:
|
||||
return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)
|
||||
|
||||
|
||||
def func_attention(query, context, g_sim, opt, eps=1e-8):
|
||||
"""
|
||||
query: (batch, queryL, d)
|
||||
context: (batch, sourceL, d)
|
||||
opt: parameters
|
||||
"""
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# Step 1: preassign attention
|
||||
# --> (batch, d, queryL)
|
||||
queryT = torch.transpose(query, 1, 2)
|
||||
|
||||
# (batch, sourceL, d)(batch, d, queryL)
|
||||
attn = torch.bmm(context, queryT)
|
||||
attn = nn.LeakyReLU(0.1)(attn)
|
||||
attn = l2norm(attn, 2)
|
||||
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = torch.transpose(attn, 1, 2).contiguous()
|
||||
# --> (batch*queryL, sourceL)
|
||||
attn = attn.view(batch_size * queryL, sourceL)
|
||||
attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = attn.view(batch_size, queryL, sourceL)
|
||||
|
||||
# Step 2: identify irrelevant fragments
|
||||
# Learning an indicator function H, one for relevant, zero for irrelevant
|
||||
if opt.correct_type == 'equal':
|
||||
re_attn = correct_equal(attn, query, context, sourceL, g_sim)
|
||||
elif opt.correct_type == 'prob':
|
||||
re_attn = correct_prob(attn, query, context, sourceL, g_sim)
|
||||
# --> (batch, d, sourceL)
|
||||
contextT = torch.transpose(context, 1, 2)
|
||||
# --> (batch, sourceL, queryL)
|
||||
re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
|
||||
# (batch x d x sourceL)(batch x sourceL x queryL)
|
||||
# --> (batch, d, queryL)
|
||||
weightedContext = torch.bmm(contextT, re_attnT)
|
||||
|
||||
# --> (batch, queryL, d)
|
||||
weightedContext = torch.transpose(weightedContext, 1, 2)
|
||||
|
||||
if torch.isnan(weightedContext).any():
    print('Warning: NaN values detected in weightedContext')
|
||||
return weightedContext, re_attn
|
||||
|
||||
|
||||
def correct_equal(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as equal
|
||||
sigma_{j} (xi - xj) = sigma_{j} xi - sigma_{j} xj
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_equal(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_equal(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_equal(attn, query, context, sourceL):
|
||||
funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
return re_attn
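
# focal_equal is the equal-confidence focal correction: funcF_i = L * x_i - sum_j x_j is
# positive exactly when attention x_i exceeds the mean over the L source fragments, so
# fattn keeps the above-average fragments, zeroes the rest, and the surviving weights are
# renormalized to sum to one.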
|
||||
|
||||
|
||||
def correct_prob(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as the sqrt
|
||||
of their similarity probability to the query fragment
|
||||
sigma_{j} (xi - xj)gj = sigma_{j} xi*gj - sigma_{j} xj*gj
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_prob(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_prob(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_prob(attn, query, context, sourceL):
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# -> (batch, queryL, sourceL, 1)
|
||||
xi = attn.unsqueeze(-1).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj = attn.unsqueeze(2).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj_confi = torch.sqrt(xj)
|
||||
|
||||
xi = xi.view(batch_size * queryL, sourceL, 1)
|
||||
xj = xj.view(batch_size * queryL, 1, sourceL)
|
||||
xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL)
|
||||
|
||||
# -> (batch*queryL, sourceL, sourceL)
|
||||
term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
|
||||
term2 = xj * xj_confi
|
||||
funcF = torch.sum(term1 - term2, dim=-1) # -> (batch*queryL, sourceL)
|
||||
funcF = funcF.view(batch_size, queryL, sourceL)
|
||||
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
if torch.isnan(re_attn).any():
    print('Warning: NaN values detected in re_attn')
|
||||
return re_attn
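
# focal_prob is the probabilistic variant: each competitor x_j is weighted by its
# confidence sqrt(x_j), so funcF_i = sum_j (x_i - x_j) * sqrt(x_j), and fragments that
# lose against confident competitors are suppressed before renormalization.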
|
||||
|
||||
|
||||
def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic',
|
||||
no_imgnorm=False):
|
||||
"""A wrapper to image encoders. Chooses between an different encoders
|
||||
that uses precomputed image features.
|
||||
"""
|
||||
img_enc = EncoderImagePrecomp(img_dim, embed_size, no_imgnorm)
|
||||
|
||||
return img_enc
|
||||
|
||||
|
||||
class EncoderImagePrecomp(nn.Module):
|
||||
|
||||
def __init__(self, img_dim, embed_size, no_imgnorm=False):
|
||||
super(EncoderImagePrecomp, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_imgnorm = no_imgnorm
|
||||
self.fc = nn.Linear(img_dim, embed_size)
|
||||
|
||||
self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2)
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Xavier initialization for the fully connected layer
|
||||
"""
|
||||
r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
|
||||
self.fc.weight.data.uniform_(-r, r)
|
||||
self.fc.bias.data.fill_(0)
|
||||
|
||||
def forward(self, images, img_lengths):
|
||||
"""Extract image feature vectors."""
|
||||
# assuming that the precomputed features are already l2-normalized
|
||||
#print(images, images.shape)
|
||||
features = self.fc(images)
|
||||
|
||||
features = self.mlp(images) + features
|
||||
features_mean, pool_weights = self.gpool(features, img_lengths)
|
||||
#features_mean = torch.mean(features, 1)
|
||||
# normalize in the joint embedding space
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
features_mean = l2norm(features_mean, dim=-1)
|
||||
|
||||
return features, features_mean
|
||||
|
||||
class EncoderTextBERT(nn.Module):
|
||||
def __init__(self, opt, order_embeddings=False, mean=True, post_transformer_layers=0):
|
||||
super().__init__()
|
||||
self.no_txtnorm = opt.no_txtnorm
|
||||
self.preextracted = opt.text_model_pre_extracted
|
||||
bert_config = BertConfig.from_pretrained(opt.text_model_pretrain,
|
||||
output_hidden_states=True,
|
||||
num_hidden_layers=opt.text_model_extraction_hidden_layer)
|
||||
bert_model = BertModel.from_pretrained(opt.text_model_pretrain, config=bert_config)
|
||||
self.order_embeddings = order_embeddings
|
||||
self.vocab_size = bert_model.config.vocab_size
|
||||
self.hidden_layer = opt.text_model_extraction_hidden_layer
|
||||
if not self.preextracted:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(opt.text_model_pretrain)
|
||||
self.bert_model = bert_model
|
||||
self.word_embeddings = self.bert_model.get_input_embeddings()
|
||||
if post_transformer_layers > 0:
|
||||
transformer_layer = nn.TransformerEncoderLayer(d_model=opt.text_model_word_dim, nhead=4,
|
||||
dim_feedforward=2048,
|
||||
dropout=opt.text_model_dropout, activation='relu')
|
||||
self.transformer_encoder = nn.TransformerEncoder(transformer_layer,
|
||||
num_layers=post_transformer_layers)
|
||||
self.post_transformer_layers = post_transformer_layers
|
||||
self.map = nn.Linear(opt.text_model_word_dim, opt.embed_size)
|
||||
self.mean = mean
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
def forward(self, x, lengths):
|
||||
'''
|
||||
x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ?
|
||||
lengths: tensor of lengths (LongTensor) of size B
|
||||
'''
|
||||
#print(x[0], x.shape)
|
||||
#print(lengths)
|
||||
if not self.preextracted or self.post_transformer_layers > 0:
|
||||
max_len = max(lengths)
|
||||
attention_mask = torch.ones(x.shape[0], max_len)
|
||||
for e, l in zip(attention_mask, lengths):
|
||||
e[l:] = 0
|
||||
attention_mask = attention_mask.to(x.device)
|
||||
|
||||
if self.preextracted:
|
||||
outputs = x
|
||||
else:
|
||||
#print(x.shape)
|
||||
outputs = self.bert_model(x, attention_mask=attention_mask)
|
||||
#print(outputs)
|
||||
#print("----")
|
||||
|
||||
outputs = outputs[2][-1]
|
||||
#print(outputs.shape)
|
||||
|
||||
if self.post_transformer_layers > 0:
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool())
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
if self.mean:
|
||||
#x = outputs.mean(dim=1)
|
||||
out = self.map(outputs)
|
||||
#out = torch.mean(out, 1)
|
||||
cap_len = torch.tensor(lengths)
|
||||
out, pool_weights = self.gpool(out, cap_len.to(out.device))
|
||||
else:
|
||||
x = outputs[:, 0, :] # from the last layer take only the first word
|
||||
|
||||
#out = self.map(x)
|
||||
outputs = self.map(outputs)
|
||||
#print(outputs.shape, outputs.dtype)
|
||||
|
||||
# normalization in the joint embedding space
|
||||
# out = l2norm(out)
|
||||
if not self.no_txtnorm:
|
||||
outputs = l2norm(outputs, dim=-1)
|
||||
out = l2norm(out, dim=-1)
|
||||
|
||||
# take absolute value, used by order embeddings
|
||||
if self.order_embeddings:
|
||||
out = torch.abs(out)
|
||||
#print(outputs.shape, out.shape)
|
||||
return outputs, lengths, out
|
||||
|
||||
def get_finetuning_params(self):
|
||||
return list(self.bert_model.parameters())
|
||||
|
||||
class EncoderTextBertVSE(nn.Module):
|
||||
def __init__(self, embed_size, no_txtnorm=False):
|
||||
super(EncoderTextBertVSE, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
self.bert = BertModel.from_pretrained('bert-base-uncased')
|
||||
self.linear = nn.Linear(768, embed_size)
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
bert_attention_mask = (x != 0).float()
|
||||
bert_emb = self.bert(x, bert_attention_mask)[0] # B x N x D
|
||||
cap_len = torch.Tensor(lengths)
|
||||
|
||||
cap_emb = self.linear(bert_emb)
|
||||
#print(cap_len.shape,cap_len.dtype)
|
||||
|
||||
cap_emb_mean, pool_weights = self.gpool(cap_emb, cap_len.to(cap_emb.device))
|
||||
# cap_emb_mean = torch.mean(cap_emb, 1)
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
cap_emb = l2norm(cap_emb, dim=-1)
|
||||
cap_emb_mean = l2norm(cap_emb_mean, dim=-1)
|
||||
# print(cap_emb.shape, cap_emb_mean.shape)
|
||||
return cap_emb, lengths, cap_emb_mean
|
||||
|
||||
def encoder_text(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False):
|
||||
txt_enc = EncoderText(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm)
|
||||
|
||||
return txt_enc
|
||||
|
||||
|
||||
class EncoderText(nn.Module):
|
||||
|
||||
def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers,
|
||||
use_bi_gru=False, no_txtnorm=False):
|
||||
super(EncoderText, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
# word embedding
|
||||
self.embed = nn.Embedding(vocab_size, word_dim)
|
||||
|
||||
# caption embedding
|
||||
self.use_bi_gru = use_bi_gru
|
||||
self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
|
||||
|
||||
self.init_weights(word2idx)
|
||||
self.gpool = GPO(32, 32)
|
||||
|
||||
def init_weights(self, word2idx):
|
||||
# self.embed.weight.data.uniform_(-0.1, 0.1)
|
||||
|
||||
wemb = torchtext.vocab.GloVe(cache=".vector_cache")
|
||||
|
||||
# quick-and-dirty trick to improve word-hit rate
|
||||
missing_words = []
|
||||
for word, idx in word2idx.items():
|
||||
if word not in wemb.stoi:
|
||||
word = word.replace('-', '').replace('.', '').replace("'", '')
|
||||
if '/' in word:
|
||||
word = word.split('/')[0]
|
||||
if word in wemb.stoi:
|
||||
self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
|
||||
else:
|
||||
missing_words.append(word)
|
||||
print('Words: {}/{} found in vocabulary; {} words missing'.format(
|
||||
len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
#print(x, x.shape, lengths, len(lengths))
|
||||
x = self.embed(x)
|
||||
packed = pack_padded_sequence(x, lengths, batch_first=True)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.rnn.flatten_parameters()
|
||||
# Forward propagate RNN
|
||||
out, _ = self.rnn(packed)
|
||||
#print(out.dtype, out.shape)
|
||||
#print("---")
|
||||
|
||||
# Reshape *final* output to (batch_size, hidden_size)
|
||||
padded = pad_packed_sequence(out, batch_first=True)
|
||||
cap_emb, cap_len = padded
|
||||
|
||||
if self.use_bi_gru:
|
||||
cap_emb = (cap_emb[:, :, :int(cap_emb.size(2) / 2)] + cap_emb[:, :, int(cap_emb.size(2) / 2):]) / 2
|
||||
|
||||
cap_emb_mean, pool_weights = self.gpool(cap_emb, cap_len.to(cap_emb.device))
|
||||
#cap_emb_mean = torch.mean(cap_emb, 1)
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
cap_emb = l2norm(cap_emb, dim=-1)
|
||||
cap_emb_mean = l2norm(cap_emb_mean, dim=-1)
|
||||
#print(cap_emb.shape, cap_emb_mean.shape)
|
||||
return cap_emb, cap_len, cap_emb_mean
|
||||
|
||||
|
||||
''' Visual self-attention module '''
|
||||
|
||||
|
||||
class V_single_modal_atten(nn.Module):
|
||||
"""
|
||||
Single Visual Modal Attention Network.
|
||||
"""
|
||||
|
||||
def __init__(self, image_dim, embed_dim, dropout_rate=0.4, img_region_num=36):
|
||||
"""
|
||||
param image_dim: dim of visual feature
|
||||
param embed_dim: dim of embedding space
|
||||
"""
|
||||
super(V_single_modal_atten, self).__init__()
|
||||
|
||||
self.fc1 = nn.Linear(image_dim, embed_dim) # embed visual feature to common space
|
||||
|
||||
self.fc2 = nn.Linear(image_dim, embed_dim) # embed memory to common space
|
||||
self.fc2_2 = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights
|
||||
self.fc4 = nn.Linear(image_dim, embed_dim) # embed attentive feature to common space
|
||||
|
||||
self.embedding_1 = nn.Sequential(self.fc1, nn.BatchNorm1d(img_region_num), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2 = nn.Sequential(self.fc2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2_2 = nn.Sequential(self.fc2_2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_3 = nn.Sequential(self.fc3)
|
||||
|
||||
self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
|
||||
|
||||
def forward(self, v_t, m_v):
|
||||
"""
|
||||
Forward propagation.
|
||||
:param v_t: encoded images, shape: (batch_size, num_regions, image_dim)
|
||||
:param m_v: previous visual memory, shape: (batch_size, image_dim)
|
||||
:return: attention weighted encoding, weights
|
||||
"""
|
||||
W_v = self.embedding_1(v_t)
|
||||
|
||||
if m_v.size()[-1] == v_t.size()[-1]:
|
||||
W_v_m = self.embedding_2(m_v)
|
||||
else:
|
||||
W_v_m = self.embedding_2_2(m_v)
|
||||
|
||||
W_v_m = W_v_m.unsqueeze(1).repeat(1, W_v.size()[1], 1)
|
||||
|
||||
h_v = W_v.mul(W_v_m)
|
||||
|
||||
a_v = self.embedding_3(h_v)
|
||||
a_v = a_v.squeeze(2)
|
||||
weights = self.softmax(a_v)
|
||||
|
||||
v_att = ((weights.unsqueeze(2) * v_t)).sum(dim=1)
|
||||
|
||||
# l2 norm
|
||||
v_att = l2norm(v_att, -1)
|
||||
|
||||
return v_att, weights
|
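# --- Illustrative usage sketch (not part of the original model) ---
# A minimal, self-contained example of how this visual self-attention block
# is expected to be called; the batch size, region count and dimensions are
# made-up values chosen only to show the tensor shapes.
def _demo_v_single_modal_atten():
    import torch
    atten = V_single_modal_atten(image_dim=1024, embed_dim=1024, img_region_num=36)
    v_t = torch.randn(8, 36, 1024)   # region features: (batch, regions, image_dim)
    m_v = torch.randn(8, 1024)       # global visual memory: (batch, image_dim)
    v_att, weights = atten(v_t, m_v)
    # v_att: (8, 1024) attended, l2-normalized summary of the raw region features
    # weights: (8, 36) softmax attention over regions
    return v_att.shape, weights.shape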
||||
|
||||
|
||||
class T_single_modal_atten(nn.Module):
|
||||
"""
|
||||
Single Textual Modal Attention Network.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim, dropout_rate=0.4):
|
||||
"""
|
||||
param embed_dim: dim of embedding space
|
||||
param dropout_rate: dropout probability
|
||||
"""
|
||||
super(T_single_modal_atten, self).__init__()
|
||||
|
||||
self.fc1 = nn.Linear(embed_dim, embed_dim) # embed visual feature to common space
|
||||
self.fc2 = nn.Linear(embed_dim, embed_dim) # embed memory to common space
|
||||
self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights
|
||||
|
||||
self.embedding_1 = nn.Sequential(self.fc1, nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2 = nn.Sequential(self.fc2, nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_3 = nn.Sequential(self.fc3)
|
||||
|
||||
self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
|
||||
|
||||
def forward(self, u_t, m_u):
|
||||
"""
|
||||
Forward propagation.
|
||||
:param u_t: encoded captions, shape: (batch_size, num_words, embed_dim)
|
||||
:param m_u: previous textual memory, shape: (batch_size, embed_dim)
|
||||
:return: attention weighted encoding, weights
|
||||
"""
|
||||
W_u = self.embedding_1(u_t)
|
||||
|
||||
W_u_m = self.embedding_2(m_u)
|
||||
W_u_m = W_u_m.unsqueeze(1).repeat(1, W_u.size()[1], 1)
|
||||
|
||||
h_u = W_u.mul(W_u_m)
|
||||
|
||||
a_u = self.embedding_3(h_u)
|
||||
a_u = a_u.squeeze(2)
|
||||
weights = self.softmax(a_u)
|
||||
|
||||
u_att = ((weights.unsqueeze(2) * u_t)).sum(dim=1)
|
||||
|
||||
# l2 norm
|
||||
u_att = l2norm(u_att, -1)
|
||||
|
||||
return u_att, weights
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
"""
|
||||
Compute contrastive loss
|
||||
"""
|
||||
|
||||
def __init__(self, margin=0):
|
||||
super(ContrastiveLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, scores):
|
||||
# compute image-sentence score matrix
|
||||
|
||||
diagonal = scores.diag().view(-1, 1)
|
||||
d1 = diagonal.expand_as(scores)
|
||||
d2 = diagonal.t().expand_as(scores)
|
||||
|
||||
# compare every diagonal score to scores in its column
|
||||
# caption retrieval
|
||||
cost_s = (self.margin + scores - d1).clamp(min=0)
|
||||
# compare every diagonal score to scores in its row
|
||||
# image retrieval
|
||||
cost_im = (self.margin + scores - d2).clamp(min=0)
|
||||
|
||||
# clear diagonals
|
||||
mask = torch.eye(scores.size(0)) > .5
|
||||
I = Variable(mask)
|
||||
if torch.cuda.is_available():
|
||||
I = I.cuda()
|
||||
cost_s = cost_s.masked_fill_(I, 0)
|
||||
cost_im = cost_im.masked_fill_(I, 0)
|
||||
|
||||
# keep the maximum violating negative for each query
|
||||
cost_s = cost_s.max(1)[0]
|
||||
cost_im = cost_im.max(0)[0]
|
||||
|
||||
return cost_s.sum() + cost_im.sum()
|
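# --- Illustrative usage sketch (not part of the original model) ---
# A toy run of the max-hinge loss above on a 3x3 image-caption score matrix;
# the margin value 0.2 is an assumption, the training scripts may use another.
def _demo_contrastive_loss():
    import torch
    criterion = ContrastiveLoss(margin=0.2)
    # Rows: images, columns: captions; the diagonal holds the matched pairs.
    scores = torch.tensor([[0.9, 0.4, 0.1],
                           [0.3, 0.8, 0.2],
                           [0.2, 0.5, 0.7]])
    if torch.cuda.is_available():
        scores = scores.cuda()  # the loss builds its diagonal mask on the GPU
    # Only the hardest violating negative per row/column contributes to the sum.
    return criterion(scores)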
||||
|
||||
|
||||
class SCAN(nn.Module):
|
||||
"""
|
||||
Stacked Cross Attention Network (SCAN) model
|
||||
"""
|
||||
|
||||
def __init__(self, word2idx, opt):
|
||||
super(SCAN, self).__init__()
|
||||
# Build Models
|
||||
self.grad_clip = opt.grad_clip
|
||||
self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size,
|
||||
precomp_enc_type=opt.precomp_enc_type,
|
||||
no_imgnorm=opt.no_imgnorm)
|
||||
# self.txt_enc = encoder_text(word2idx, opt.vocab_size, opt.word_dim,
|
||||
# opt.embed_size, opt.num_layers,
|
||||
# use_bi_gru=True,
|
||||
# no_txtnorm=opt.no_txtnorm)
|
||||
self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers)
|
||||
#self.txt_enc = EncoderTextBertVSE(opt.embed_size, opt.no_imgnorm)
|
||||
|
||||
self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size)
|
||||
self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size)
|
||||
|
||||
self.opt = opt
|
||||
self.Eiters = 0
|
||||
|
||||
def forward_emb(self, images, img_lengths, captions, lengths):
|
||||
"""Compute the image and caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images, img_lengths)
|
||||
#print(img_emb.shape,img_mean.shape)
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return img_emb, img_mean, cap_emb, cap_lens, cap_mean
|
||||
|
||||
def txt_emb(self, captions, lengths):
|
||||
"""Compute the caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return cap_emb, cap_lens, cap_mean
|
||||
|
||||
def image_emb(self, images, img_lengths):
|
||||
"""Compute the image embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images, img_lengths)
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
return img_emb, img_mean
|
||||
|
||||
def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
|
||||
"""Compute the loss given pairs of image and caption embeddings
|
||||
"""
|
||||
scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)
|
||||
|
||||
return scores
|
||||
|
||||
def forward(self, images, img_lengths, captions, lengths, ids=None, *args):
|
||||
# compute the embeddings
|
||||
lengths = lengths.cpu().numpy().tolist()
|
||||
img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, img_lengths, captions, lengths)
|
||||
scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
|
||||
return scores
|
||||
|
||||
def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
|
||||
similarities = []
|
||||
n_image = images.size(0)
|
||||
n_caption = captions.size(0)
|
||||
g_sims = cap_mean.mm(img_mean.t())
|
||||
now = time.time()
|
||||
for i in range(n_caption):
|
||||
# Get the i-th text description
|
||||
n_word = cap_lens[i]
|
||||
g_sim = g_sims[i]
|
||||
cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
|
||||
# --> (n_image, n_word, d)
|
||||
cap_i_expand = cap_i.repeat(n_image, 1, 1)
|
||||
|
||||
# t2i process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
|
||||
t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
|
||||
t2i_sim = t2i_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# i2t process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
|
||||
i2t_sim = cosine_similarity(images, weiContext, dim=2)
|
||||
i2t_sim = i2t_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# Overall similarity for image and text
|
||||
sim = t2i_sim + i2t_sim
|
||||
|
||||
similarities.append(sim)
|
||||
|
||||
# (n_image, n_caption)
|
||||
similarities = torch.cat(similarities, 1)
|
||||
|
||||
if self.training:
|
||||
similarities = similarities.transpose(0, 1)
|
||||
#print('Time:{:.4f}'.format(time.time() - now))
|
||||
return similarities
|
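# --- Shape walkthrough (illustrative, not part of the original model) ---
# xattn_score above builds one similarity column per caption: the caption
# words are tiled to (n_image, n_word, d), the t2i and i2t cosine terms are
# averaged to a single (n_image, 1) column, and the columns are concatenated
# into an (n_image, n_caption) matrix (transposed during training).  The toy
# tensors below only mimic that final concatenation step.
def _demo_similarity_concat(n_image=4, n_caption=5):
    import torch
    columns = [torch.randn(n_image, 1) for _ in range(n_caption)]
    sims = torch.cat(columns, 1)   # (n_image, n_caption)
    return sims.t()                # training code transposes to (n_caption, n_image)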
||||
|
||||
def EncoderImage2(data_name, img_dim, embed_size, precomp_enc_type='basic',
|
||||
no_imgnorm=False):
|
||||
"""A wrapper to image encoders. Chooses between an different encoders
|
||||
that uses precomputed image features.
|
||||
"""
|
||||
img_enc = EncoderImagePrecomp2(img_dim, embed_size, no_imgnorm)
|
||||
|
||||
return img_enc
|
||||
|
||||
|
||||
class EncoderImagePrecomp2(nn.Module):
|
||||
|
||||
def __init__(self, img_dim, embed_size, no_imgnorm=False):
|
||||
super(EncoderImagePrecomp2, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_imgnorm = no_imgnorm
|
||||
self.fc = nn.Linear(img_dim, embed_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Xavier initialization for the fully connected layer
|
||||
"""
|
||||
r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
|
||||
self.fc.weight.data.uniform_(-r, r)
|
||||
self.fc.bias.data.fill_(0)
|
||||
|
||||
def forward(self, images):
|
||||
"""Extract image feature vectors."""
|
||||
# assuming that the precomputed features are already l2-normalized
|
||||
#print(images, images.shape)
|
||||
features = self.fc(images)
|
||||
features_mean = torch.mean(features, 1)
|
||||
# normalize in the joint embedding space
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
|
||||
return features, features_mean
|
||||
|
||||
|
||||
def encoder_text2(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False):
|
||||
txt_enc = EncoderText2(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm)
|
||||
|
||||
return txt_enc
|
||||
|
||||
|
||||
class EncoderText2(nn.Module):
|
||||
|
||||
def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers,
|
||||
use_bi_gru=False, no_txtnorm=False):
|
||||
super(EncoderText2, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
# word embedding
|
||||
self.embed = nn.Embedding(vocab_size, word_dim)
|
||||
|
||||
# caption embedding
|
||||
self.use_bi_gru = use_bi_gru
|
||||
self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
|
||||
|
||||
self.init_weights(word2idx)
|
||||
|
||||
def init_weights(self, word2idx):
|
||||
# self.embed.weight.data.uniform_(-0.1, 0.1)
|
||||
|
||||
wemb = torchtext.vocab.GloVe(cache=".vector_cache")
|
||||
|
||||
# quick-and-dirty trick to improve word-hit rate
|
||||
missing_words = []
|
||||
for word, idx in word2idx.items():
|
||||
if word not in wemb.stoi:
|
||||
word = word.replace('-', '').replace('.', '').replace("'", '')
|
||||
if '/' in word:
|
||||
word = word.split('/')[0]
|
||||
if word in wemb.stoi:
|
||||
self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
|
||||
else:
|
||||
missing_words.append(word)
|
||||
print('Words: {}/{} found in vocabulary; {} words missing'.format(
|
||||
len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
#print(x, x.shape, lengths, len(lengths))
|
||||
x = self.embed(x)
|
||||
packed = pack_padded_sequence(x, lengths, batch_first=True)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.rnn.flatten_parameters()
|
||||
# Forward propagate RNN
|
||||
out, _ = self.rnn(packed)
|
||||
#print(out.dtype, out.shape)
|
||||
#print("---")
|
||||
|
||||
# Reshape *final* output to (batch_size, hidden_size)
|
||||
padded = pad_packed_sequence(out, batch_first=True)
|
||||
cap_emb, cap_len = padded
|
||||
|
||||
if self.use_bi_gru:
|
||||
cap_emb = (cap_emb[:, :, :int(cap_emb.size(2) / 2)] + cap_emb[:, :, int(cap_emb.size(2) / 2):]) / 2
|
||||
cap_emb_mean = torch.mean(cap_emb, 1)
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
cap_emb = l2norm(cap_emb, dim=-1)
|
||||
cap_emb_mean = l2norm(cap_emb_mean, dim=1)
|
||||
#print(cap_emb.shape, cap_emb_mean.shape)
|
||||
return cap_emb, cap_len, cap_emb_mean
|
||||
|
||||
class SCAN2(nn.Module):
|
||||
"""
|
||||
Stacked Cross Attention Network (SCAN) model
|
||||
"""
|
||||
|
||||
def __init__(self, word2idx, opt):
|
||||
super(SCAN2, self).__init__()
|
||||
# Build Models
|
||||
self.grad_clip = opt.grad_clip
|
||||
self.img_enc = EncoderImage2(opt.data_name, opt.img_dim, opt.embed_size,
|
||||
precomp_enc_type=opt.precomp_enc_type,
|
||||
no_imgnorm=opt.no_imgnorm)
|
||||
self.txt_enc = encoder_text2(word2idx, opt.vocab_size, opt.word_dim,
|
||||
opt.embed_size, opt.num_layers,
|
||||
use_bi_gru=True,
|
||||
no_txtnorm=opt.no_txtnorm)
|
||||
#self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers)
|
||||
|
||||
self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size)
|
||||
self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size)
|
||||
|
||||
self.opt = opt
|
||||
self.Eiters = 0
|
||||
|
||||
def forward_emb(self, images, captions, lengths):
|
||||
"""Compute the image and caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images)
|
||||
#print(img_emb.shape,img_mean.shape)
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return img_emb, img_mean, cap_emb, cap_lens, cap_mean
|
||||
|
||||
def txt_emb(self, captions, lengths):
|
||||
"""Compute the caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return cap_emb, cap_lens, cap_mean
|
||||
|
||||
def image_emb(self, images):
|
||||
"""Compute the image embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images)
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
return img_emb, img_mean
|
||||
|
||||
def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
|
||||
"""Compute the loss given pairs of image and caption embeddings
|
||||
"""
|
||||
scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)
|
||||
|
||||
return scores
|
||||
|
||||
def forward(self, images, captions, lengths, ids=None, *args):
|
||||
# compute the embeddings
|
||||
lengths = lengths.cpu().numpy().tolist()
|
||||
img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
|
||||
scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
|
||||
return scores
|
||||
|
||||
def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
|
||||
similarities = []
|
||||
n_image = images.size(0)
|
||||
n_caption = captions.size(0)
|
||||
g_sims = cap_mean.mm(img_mean.t())
|
||||
now = time.time()
|
||||
for i in range(n_caption):
|
||||
# Get the i-th text description
|
||||
n_word = cap_lens[i]
|
||||
g_sim = g_sims[i]
|
||||
cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
|
||||
# --> (n_image, n_word, d)
|
||||
cap_i_expand = cap_i.repeat(n_image, 1, 1)
|
||||
|
||||
# t2i process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
|
||||
t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
|
||||
t2i_sim = t2i_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# i2t process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
|
||||
i2t_sim = cosine_similarity(images, weiContext, dim=2)
|
||||
i2t_sim = i2t_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# Overall similarity for image and text
|
||||
sim = t2i_sim + i2t_sim
|
||||
|
||||
similarities.append(sim)
|
||||
|
||||
# (n_image, n_caption)
|
||||
similarities = torch.cat(similarities, 1)
|
||||
|
||||
if self.training:
|
||||
similarities = similarities.transpose(0, 1)
|
||||
#print('Time:{:.4f}'.format(time.time() - now))
|
||||
return similarities
|
|
@ -0,0 +1,630 @@
|
|||
# -----------------------------------------------------------
|
||||
# "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network"
|
||||
# Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu*
|
||||
#
|
||||
# ---------------------------------------------------------------
|
||||
"""BCAN model"""
|
||||
import copy
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init
|
||||
import torchtext
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer, BertModel, BertConfig
|
||||
|
||||
|
||||
def l1norm(X, dim, eps=1e-8):
|
||||
"""L1-normalize columns of X
|
||||
"""
|
||||
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def l2norm(X, dim, eps=1e-8):
|
||||
"""L2-normalize columns of X
|
||||
"""
|
||||
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True)
|
||||
norm = torch.sqrt(norm + eps) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
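# --- Illustrative usage sketch (not part of the original model) ---
# Quick check of the two normalizers above: after l2norm each row of a
# (batch, dim) tensor has roughly unit Euclidean length, and after l1norm
# each row sums to roughly one in absolute value (both add a small eps).
def _demo_norms():
    import torch
    x = torch.randn(4, 8)
    return l2norm(x, dim=-1).norm(dim=-1), l1norm(x, dim=-1).abs().sum(dim=-1)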
||||
|
||||
|
||||
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
|
||||
"""Returns cosine similarity between x1 and x2, computed along dim."""
|
||||
w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
|
||||
w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
|
||||
w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
|
||||
if keep_dim:
|
||||
return w12 / (w1 * w2).clamp(min=eps)
|
||||
else:
|
||||
return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)
|
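# --- Illustrative usage sketch (not part of the original model) ---
# cosine_similarity reduces along `dim`; with keep_dim=True that axis is kept
# with size 1, which is what correct_equal / correct_prob rely on when they
# re-weight attention maps.
def _demo_cosine_similarity():
    import torch
    a = torch.randn(2, 5, 16)
    sim_same = cosine_similarity(a, a, dim=2)                                   # ~1.0, shape (2, 5)
    sim_kept = cosine_similarity(a, torch.randn_like(a), dim=2, keep_dim=True)  # shape (2, 5, 1)
    return sim_same.shape, sim_kept.shape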
||||
|
||||
|
||||
def func_attention(query, context, g_sim, opt, eps=1e-8):
|
||||
"""
|
||||
query: (batch, queryL, d)
|
||||
context: (batch, sourceL, d)
|
||||
opt: parameters
|
||||
"""
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# Step 1: preassign attention
|
||||
# --> (batch, d, queryL)
|
||||
queryT = torch.transpose(query, 1, 2)
|
||||
|
||||
# (batch, sourceL, d)(batch, d, queryL)
|
||||
attn = torch.bmm(context, queryT)
|
||||
attn = nn.LeakyReLU(0.1)(attn)
|
||||
attn = l2norm(attn, 2)
|
||||
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = torch.transpose(attn, 1, 2).contiguous()
|
||||
# --> (batch*queryL, sourceL)
|
||||
attn = attn.view(batch_size * queryL, sourceL)
|
||||
attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = attn.view(batch_size, queryL, sourceL)
|
||||
|
||||
# Step 2: identify irrelevant fragments
|
||||
# Learn an indicator function H: one for relevant fragments, zero for irrelevant ones
|
||||
if opt.correct_type == 'equal':
|
||||
re_attn = correct_equal(attn, query, context, sourceL, g_sim)
|
||||
elif opt.correct_type == 'prob':
|
||||
re_attn = correct_prob(attn, query, context, sourceL, g_sim)
|
||||
# --> (batch, d, sourceL)
|
||||
contextT = torch.transpose(context, 1, 2)
|
||||
# --> (batch, sourceL, queryL)
|
||||
re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
|
||||
# (batch x d x sourceL)(batch x sourceL x queryL)
|
||||
# --> (batch, d, queryL)
|
||||
weightedContext = torch.bmm(contextT, re_attnT)
|
||||
|
||||
# --> (batch, queryL, d)
|
||||
weightedContext = torch.transpose(weightedContext, 1, 2)
|
||||
|
||||
if torch.isnan(weightedContext).any():
|
||||
print('Warning: NaN values detected in weightedContext')
|
||||
return weightedContext, re_attn
|
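# --- Illustrative usage sketch (not part of the original model) ---
# func_attention expects an `opt` carrying lambda_softmax and correct_type
# (the same fields the training scripts parse); the SimpleNamespace below is
# only a stand-in for that options object, and the sizes are made up.
# Shapes follow the docstring: query (batch, queryL, d), context
# (batch, sourceL, d), g_sim (batch,).
def _demo_func_attention():
    import torch
    from types import SimpleNamespace
    opt = SimpleNamespace(lambda_softmax=9.0, correct_type='equal')
    query = torch.randn(4, 12, 1024)        # e.g. caption words
    context = torch.randn(4, 36, 1024)      # e.g. image regions
    g_sim = 0.4 + 0.5 * torch.rand(4)       # global image-caption similarity prior
    weighted, re_attn = func_attention(query, context, g_sim, opt)
    # weighted: (4, 12, 1024) attended context, re_attn: (4, 12, 36)
    return weighted.shape, re_attn.shape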
||||
|
||||
|
||||
def correct_equal(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as equal
|
||||
sum_j (x_i - x_j) = sum_j x_i - sum_j x_j
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_equal(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_equal(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_equal(attn, query, context, sourceL):
|
||||
funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
return re_attn
|
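# --- Worked numeric example (illustrative, not part of the original model) ---
# focal_equal keeps only the attention entries that exceed their row mean
# (attn * sourceL > sum(attn)) and renormalizes them.  For the single row
# [0.5, 0.3, 0.2] with sourceL = 3, only the first entry survives, so the
# re-assigned attention becomes [1.0, 0.0, 0.0].
def _demo_focal_equal():
    import torch
    attn = torch.tensor([[[0.5, 0.3, 0.2]]])   # (batch=1, queryL=1, sourceL=3)
    # query and context are not used inside focal_equal, so placeholders suffice.
    return focal_equal(attn, None, None, sourceL=3)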
||||
|
||||
|
||||
def correct_prob(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as the sqrt
|
||||
of their similarity probability to the query fragment
|
||||
sum_j (x_i - x_j) * g_j = sum_j x_i * g_j - sum_j x_j * g_j
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_prob(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_prob(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_prob(attn, query, context, sourceL):
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# -> (batch, queryL, sourceL, 1)
|
||||
xi = attn.unsqueeze(-1).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj = attn.unsqueeze(2).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj_confi = torch.sqrt(xj)
|
||||
|
||||
xi = xi.view(batch_size * queryL, sourceL, 1)
|
||||
xj = xj.view(batch_size * queryL, 1, sourceL)
|
||||
xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL)
|
||||
|
||||
# -> (batch*queryL, sourceL, sourceL)
|
||||
term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
|
||||
term2 = xj * xj_confi
|
||||
funcF = torch.sum(term1 - term2, dim=-1) # -> (batch*queryL, sourceL)
|
||||
funcF = funcF.view(batch_size, queryL, sourceL)
|
||||
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
if torch.isnan(re_attn).any():
|
||||
print("ddd")
|
||||
return re_attn
|
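# --- Worked numeric example (illustrative, not part of the original model) ---
# focal_prob weights each competing entry x_j by sqrt(x_j) before the
# comparison, so an entry is kept when its confidence-weighted margin
# sum_j (x_i - x_j) * sqrt(x_j) is positive.  With the row [0.5, 0.3, 0.2]
# only the first entry survives here as well; query and context are read for
# their shapes only.
def _demo_focal_prob():
    import torch
    attn = torch.tensor([[[0.5, 0.3, 0.2]]])   # (batch=1, queryL=1, sourceL=3)
    query = torch.zeros(1, 1, 4)
    context = torch.zeros(1, 3, 4)
    return focal_prob(attn, query, context, sourceL=3)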
||||
|
||||
|
||||
def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic',
|
||||
no_imgnorm=False):
|
||||
"""A wrapper to image encoders. Chooses between an different encoders
|
||||
that uses precomputed image features.
|
||||
"""
|
||||
img_enc = EncoderImagePrecomp(img_dim, embed_size, no_imgnorm)
|
||||
|
||||
return img_enc
|
||||
|
||||
|
||||
class EncoderImagePrecomp(nn.Module):
|
||||
|
||||
def __init__(self, img_dim, embed_size, no_imgnorm=False):
|
||||
super(EncoderImagePrecomp, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_imgnorm = no_imgnorm
|
||||
self.fc = nn.Linear(img_dim, embed_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Xavier initialization for the fully connected layer
|
||||
"""
|
||||
r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
|
||||
self.fc.weight.data.uniform_(-r, r)
|
||||
self.fc.bias.data.fill_(0)
|
||||
|
||||
def forward(self, images):
|
||||
"""Extract image feature vectors."""
|
||||
# assuming that the precomputed features are already l2-normalized
|
||||
#print(images, images.shape)
|
||||
features = self.fc(images)
|
||||
features_mean = torch.mean(features, 1)
|
||||
# normalize in the joint embedding space
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
|
||||
return features, features_mean
|
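# --- Illustrative check (not part of the original model) ---
# init_weights above draws the fc weights uniformly in [-r, r] with
# r = sqrt(6) / sqrt(fan_in + fan_out), the same bound that
# torch.nn.init.xavier_uniform_ uses with gain=1.  The img_dim/embed_size
# values below are made up for the illustration.
def _demo_xavier_bound(img_dim=2048, embed_size=1024):
    import numpy as np
    return np.sqrt(6.) / np.sqrt(img_dim + embed_size)   # ~0.044 for 2048 -> 1024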
||||
|
||||
class EncoderTextBERT(nn.Module):
|
||||
def __init__(self, opt, order_embeddings=False, mean=True, post_transformer_layers=0):
|
||||
super().__init__()
|
||||
self.preextracted = opt.text_model_pre_extracted
|
||||
bert_config = BertConfig.from_pretrained(opt.text_model_pretrain,
|
||||
output_hidden_states=True,
|
||||
num_hidden_layers=opt.text_model_extraction_hidden_layer)
|
||||
bert_model = BertModel.from_pretrained(opt.text_model_pretrain, config=bert_config)
|
||||
self.order_embeddings = order_embeddings
|
||||
self.vocab_size = bert_model.config.vocab_size
|
||||
self.hidden_layer = opt.text_model_extraction_hidden_layer
|
||||
if not self.preextracted:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(opt.text_model_pretrain)
|
||||
self.bert_model = bert_model
|
||||
self.word_embeddings = self.bert_model.get_input_embeddings()
|
||||
if post_transformer_layers > 0:
|
||||
transformer_layer = nn.TransformerEncoderLayer(d_model=opt.text_model_word_dim, nhead=4,
|
||||
dim_feedforward=2048,
|
||||
dropout=opt.text_model_dropout, activation='relu')
|
||||
self.transformer_encoder = nn.TransformerEncoder(transformer_layer,
|
||||
num_layers=post_transformer_layers)
|
||||
self.post_transformer_layers = post_transformer_layers
|
||||
self.map = nn.Linear(opt.text_model_word_dim, opt.embed_size)
|
||||
self.mean = mean
|
||||
|
||||
def forward(self, x, lengths):
|
||||
'''
|
||||
x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ?
|
||||
lengths: tensor of lengths (LongTensor) of size B
|
||||
'''
|
||||
# print(x, x.shape)
|
||||
# print(lengths)
|
||||
if not self.preextracted or self.post_transformer_layers > 0:
|
||||
max_len = max(lengths)
|
||||
attention_mask = torch.ones(x.shape[0], max_len)
|
||||
for e, l in zip(attention_mask, lengths):
|
||||
e[l:] = 0
|
||||
attention_mask = attention_mask.to(x.device)
|
||||
|
||||
if self.preextracted:
|
||||
outputs = x
|
||||
else:
|
||||
outputs = self.bert_model(x, attention_mask=attention_mask)
|
||||
outputs = outputs[2][-1]
|
||||
|
||||
if self.post_transformer_layers > 0:
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool())
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
if self.mean:
|
||||
#x = outputs.mean(dim=1)
|
||||
x = torch.mean(outputs, 1)
|
||||
else:
|
||||
x = outputs[:, 0, :] # from the last layer take only the first word
|
||||
|
||||
out = self.map(x)
|
||||
outputs = self.map(outputs)
|
||||
|
||||
# normalization in the joint embedding space
|
||||
# out = l2norm(out)
|
||||
|
||||
# take absolute value, used by order embeddings
|
||||
if self.order_embeddings:
|
||||
out = torch.abs(out)
|
||||
#print(outputs.shape, out.shape)
|
||||
return outputs, lengths, out
|
||||
|
||||
def get_finetuning_params(self):
|
||||
return list(self.bert_model.parameters())
|
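# --- Illustrative sketch (not part of the original model) ---
# EncoderTextBERT.forward builds a padding mask from the caption lengths by
# zeroing the positions past each length; the helper below reproduces that
# loop on plain tensors so the masking logic can be inspected in isolation.
def _demo_attention_mask(lengths=(5, 3, 4), max_len=5):
    import torch
    mask = torch.ones(len(lengths), max_len)
    for row, l in zip(mask, lengths):
        row[l:] = 0
    return mask   # rows: [1,1,1,1,1], [1,1,1,0,0], [1,1,1,1,0]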
||||
|
||||
def encoder_text(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False):
|
||||
txt_enc = EncoderText(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm)
|
||||
|
||||
return txt_enc
|
||||
|
||||
|
||||
class EncoderText(nn.Module):
|
||||
|
||||
def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers,
|
||||
use_bi_gru=False, no_txtnorm=False):
|
||||
super(EncoderText, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
# word embedding
|
||||
self.embed = nn.Embedding(vocab_size, word_dim)
|
||||
|
||||
# caption embedding
|
||||
self.use_bi_gru = use_bi_gru
|
||||
self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
|
||||
|
||||
self.init_weights(word2idx)
|
||||
|
||||
def init_weights(self, word2idx):
|
||||
# self.embed.weight.data.uniform_(-0.1, 0.1)
|
||||
|
||||
wemb = torchtext.vocab.GloVe(cache=".vector_cache")
|
||||
|
||||
# quick-and-dirty trick to improve word-hit rate
|
||||
missing_words = []
|
||||
for word, idx in word2idx.items():
|
||||
if word not in wemb.stoi:
|
||||
word = word.replace('-', '').replace('.', '').replace("'", '')
|
||||
if '/' in word:
|
||||
word = word.split('/')[0]
|
||||
if word in wemb.stoi:
|
||||
self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
|
||||
else:
|
||||
missing_words.append(word)
|
||||
print('Words: {}/{} found in vocabulary; {} words missing'.format(
|
||||
len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
#print(x, x.shape, lengths, len(lengths))
|
||||
x = self.embed(x)
|
||||
packed = pack_padded_sequence(x, lengths, batch_first=True)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.rnn.flatten_parameters()
|
||||
# Forward propagate RNN
|
||||
out, _ = self.rnn(packed)
|
||||
#print(out.dtype, out.shape)
|
||||
#print("---")
|
||||
|
||||
# Reshape *final* output to (batch_size, hidden_size)
|
||||
padded = pad_packed_sequence(out, batch_first=True)
|
||||
cap_emb, cap_len = padded
|
||||
|
||||
if self.use_bi_gru:
|
||||
cap_emb = (cap_emb[:, :, :int(cap_emb.size(2) / 2)] + cap_emb[:, :, int(cap_emb.size(2) / 2):]) / 2
|
||||
cap_emb_mean = torch.mean(cap_emb, 1)
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
cap_emb = l2norm(cap_emb, dim=-1)
|
||||
cap_emb_mean = l2norm(cap_emb_mean, dim=1)
|
||||
#print(cap_emb.shape, cap_emb_mean.shape)
|
||||
return cap_emb, cap_len, cap_emb_mean
|
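# --- Illustrative sketch (not part of the original model) ---
# For the bidirectional GRU, the caption encoder above averages the forward
# and backward halves of the output features.  The snippet below shows that
# split-and-average step on a dummy tensor whose last dimension stands in
# for 2 * embed_size.
def _demo_bigru_average():
    import torch
    out = torch.randn(2, 7, 2 * 1024)     # (batch, seq_len, 2 * embed_size)
    half = out.size(2) // 2
    return ((out[:, :, :half] + out[:, :, half:]) / 2).shape   # (2, 7, 1024)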
||||
|
||||
|
||||
''' Visual self-attention module '''
|
||||
|
||||
|
||||
class V_single_modal_atten(nn.Module):
|
||||
"""
|
||||
Single Visual Modal Attention Network.
|
||||
"""
|
||||
|
||||
def __init__(self, image_dim, embed_dim, dropout_rate=0.4, img_region_num=36):
|
||||
"""
|
||||
param image_dim: dim of visual feature
|
||||
param embed_dim: dim of embedding space
|
||||
"""
|
||||
super(V_single_modal_atten, self).__init__()
|
||||
|
||||
self.fc1 = nn.Linear(image_dim, embed_dim) # embed visual feature to common space
|
||||
|
||||
self.fc2 = nn.Linear(image_dim, embed_dim) # embed memory to common space
|
||||
self.fc2_2 = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights
|
||||
self.fc4 = nn.Linear(image_dim, embed_dim) # embed attentive feature to common space
|
||||
|
||||
self.embedding_1 = nn.Sequential(self.fc1, nn.BatchNorm1d(img_region_num), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2 = nn.Sequential(self.fc2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2_2 = nn.Sequential(self.fc2_2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_3 = nn.Sequential(self.fc3)
|
||||
|
||||
self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
|
||||
|
||||
def forward(self, v_t, m_v):
|
||||
"""
|
||||
Forward propagation.
|
||||
:param v_t: encoded images, shape: (batch_size, num_regions, image_dim)
|
||||
:param m_v: previous visual memory, shape: (batch_size, image_dim)
|
||||
:return: attention weighted encoding, weights
|
||||
"""
|
||||
W_v = self.embedding_1(v_t)
|
||||
|
||||
if m_v.size()[-1] == v_t.size()[-1]:
|
||||
W_v_m = self.embedding_2(m_v)
|
||||
else:
|
||||
W_v_m = self.embedding_2_2(m_v)
|
||||
|
||||
W_v_m = W_v_m.unsqueeze(1).repeat(1, W_v.size()[1], 1)
|
||||
|
||||
h_v = W_v.mul(W_v_m)
|
||||
|
||||
a_v = self.embedding_3(h_v)
|
||||
a_v = a_v.squeeze(2)
|
||||
weights = self.softmax(a_v)
|
||||
|
||||
v_att = ((weights.unsqueeze(2) * v_t)).sum(dim=1)
|
||||
|
||||
# l2 norm
|
||||
v_att = l2norm(v_att, -1)
|
||||
|
||||
return v_att, weights
|
||||
|
||||
|
||||
class T_single_modal_atten(nn.Module):
|
||||
"""
|
||||
Single Textual Modal Attention Network.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim, dropout_rate=0.4):
|
||||
"""
|
||||
param embed_dim: dim of embedding space
|
||||
param dropout_rate: dropout probability
|
||||
"""
|
||||
super(T_single_modal_atten, self).__init__()
|
||||
|
||||
self.fc1 = nn.Linear(embed_dim, embed_dim) # embed visual feature to common space
|
||||
self.fc2 = nn.Linear(embed_dim, embed_dim) # embed memory to common space
|
||||
self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights
|
||||
|
||||
self.embedding_1 = nn.Sequential(self.fc1, nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_2 = nn.Sequential(self.fc2, nn.Tanh(), nn.Dropout(dropout_rate))
|
||||
self.embedding_3 = nn.Sequential(self.fc3)
|
||||
|
||||
self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
|
||||
|
||||
def forward(self, u_t, m_u):
|
||||
"""
|
||||
Forward propagation.
|
||||
:param u_t: encoded captions, shape: (batch_size, num_words, embed_dim)
|
||||
:param m_u: previous textual memory, shape: (batch_size, embed_dim)
|
||||
:return: attention weighted encoding, weights
|
||||
"""
|
||||
W_u = self.embedding_1(u_t)
|
||||
|
||||
W_u_m = self.embedding_2(m_u)
|
||||
W_u_m = W_u_m.unsqueeze(1).repeat(1, W_u.size()[1], 1)
|
||||
|
||||
h_u = W_u.mul(W_u_m)
|
||||
|
||||
a_u = self.embedding_3(h_u)
|
||||
a_u = a_u.squeeze(2)
|
||||
weights = self.softmax(a_u)
|
||||
|
||||
u_att = ((weights.unsqueeze(2) * u_t)).sum(dim=1)
|
||||
|
||||
# l2 norm
|
||||
u_att = l2norm(u_att, -1)
|
||||
|
||||
return u_att, weights
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
"""
|
||||
Compute contrastive loss
|
||||
"""
|
||||
|
||||
def __init__(self, margin=0):
|
||||
super(ContrastiveLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, scores):
|
||||
# compute image-sentence score matrix
|
||||
|
||||
diagonal = scores.diag().view(-1, 1)
|
||||
d1 = diagonal.expand_as(scores)
|
||||
d2 = diagonal.t().expand_as(scores)
|
||||
|
||||
# compare every diagonal score to scores in its column
|
||||
# caption retrieval
|
||||
cost_s = (self.margin + scores - d1).clamp(min=0)
|
||||
# compare every diagonal score to scores in its row
|
||||
# image retrieval
|
||||
cost_im = (self.margin + scores - d2).clamp(min=0)
|
||||
|
||||
# clear diagonals
|
||||
mask = torch.eye(scores.size(0)) > .5
|
||||
I = Variable(mask)
|
||||
if torch.cuda.is_available():
|
||||
I = I.cuda()
|
||||
cost_s = cost_s.masked_fill_(I, 0)
|
||||
cost_im = cost_im.masked_fill_(I, 0)
|
||||
|
||||
# keep the maximum violating negative for each query
|
||||
cost_s = cost_s.max(1)[0]
|
||||
cost_im = cost_im.max(0)[0]
|
||||
|
||||
return cost_s.sum() + cost_im.sum()
|
||||
|
||||
|
||||
class SCAN2(nn.Module):
|
||||
"""
|
||||
Stacked Cross Attention Network (SCAN) model
|
||||
"""
|
||||
|
||||
def __init__(self, word2idx, opt):
|
||||
super(SCAN2, self).__init__()
|
||||
# Build Models
|
||||
self.grad_clip = opt.grad_clip
|
||||
self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size,
|
||||
precomp_enc_type=opt.precomp_enc_type,
|
||||
no_imgnorm=opt.no_imgnorm)
|
||||
self.txt_enc = encoder_text(word2idx, opt.vocab_size, opt.word_dim,
|
||||
opt.embed_size, opt.num_layers,
|
||||
use_bi_gru=True,
|
||||
no_txtnorm=opt.no_txtnorm)
|
||||
#self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers)
|
||||
|
||||
self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size)
|
||||
self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size)
|
||||
|
||||
self.opt = opt
|
||||
self.Eiters = 0
|
||||
|
||||
def forward_emb(self, images, captions, lengths):
|
||||
"""Compute the image and caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images)
|
||||
#print(img_emb.shape,img_mean.shape)
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return img_emb, img_mean, cap_emb, cap_lens, cap_mean
|
||||
|
||||
def txt_emb(self, captions, lengths):
|
||||
"""Compute the caption embeddings
|
||||
"""
|
||||
# Forward
|
||||
cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)
|
||||
|
||||
cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
|
||||
return cap_emb, cap_lens, cap_mean
|
||||
|
||||
def image_emb(self, images):
|
||||
"""Compute the image embeddings
|
||||
"""
|
||||
# Forward
|
||||
img_emb, img_mean = self.img_enc(images)
|
||||
img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
|
||||
return img_emb, img_mean
|
||||
|
||||
def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
|
||||
"""Compute the loss given pairs of image and caption embeddings
|
||||
"""
|
||||
scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)
|
||||
|
||||
return scores
|
||||
|
||||
def forward(self, images, captions, lengths, ids=None, *args):
|
||||
# compute the embeddings
|
||||
lengths = lengths.cpu().numpy().tolist()
|
||||
img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
|
||||
scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
|
||||
return scores
|
||||
|
||||
def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
|
||||
similarities = []
|
||||
n_image = images.size(0)
|
||||
n_caption = captions.size(0)
|
||||
g_sims = cap_mean.mm(img_mean.t())
|
||||
now = time.time()
|
||||
for i in range(n_caption):
|
||||
# Get the i-th text description
|
||||
n_word = cap_lens[i]
|
||||
g_sim = g_sims[i]
|
||||
cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
|
||||
# --> (n_image, n_word, d)
|
||||
cap_i_expand = cap_i.repeat(n_image, 1, 1)
|
||||
|
||||
# t2i process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
|
||||
t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
|
||||
t2i_sim = t2i_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# i2t process
|
||||
# weiContext: (n_image, n_word, d)
|
||||
weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
|
||||
i2t_sim = cosine_similarity(images, weiContext, dim=2)
|
||||
i2t_sim = i2t_sim.mean(dim=1, keepdim=True)
|
||||
|
||||
# Overall similarity for image and text
|
||||
sim = t2i_sim + i2t_sim
|
||||
|
||||
similarities.append(sim)
|
||||
|
||||
# (n_image, n_caption)
|
||||
similarities = torch.cat(similarities, 1)
|
||||
|
||||
if self.training:
|
||||
similarities = similarities.transpose(0, 1)
|
||||
#print('Time:{:.4f}'.format(time.time() - now))
|
||||
return similarities
|
|
@ -0,0 +1,632 @@
|
|||
# -----------------------------------------------------------
|
||||
# "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network"
|
||||
# Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu*
|
||||
#
|
||||
# ---------------------------------------------------------------
|
||||
"""BCAN model"""
|
||||
import copy
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init
|
||||
import torchtext
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer, BertModel, BertConfig
|
||||
|
||||
|
||||
def l1norm(X, dim, eps=1e-8):
|
||||
"""L1-normalize columns of X
|
||||
"""
|
||||
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def l2norm(X, dim, eps=1e-8):
|
||||
"""L2-normalize columns of X
|
||||
"""
|
||||
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True)
|
||||
norm = torch.sqrt(norm + eps) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
|
||||
"""Returns cosine similarity between x1 and x2, computed along dim."""
|
||||
w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
|
||||
w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
|
||||
w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
|
||||
if keep_dim:
|
||||
return w12 / (w1 * w2).clamp(min=eps)
|
||||
else:
|
||||
return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)
|
||||
|
||||
|
||||
def func_attention(query, context, g_sim, opt, eps=1e-8):
|
||||
"""
|
||||
query: (batch, queryL, d)
|
||||
context: (batch, sourceL, d)
|
||||
opt: parameters
|
||||
"""
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# Step 1: preassign attention
|
||||
# --> (batch, d, queryL)
|
||||
queryT = torch.transpose(query, 1, 2)
|
||||
|
||||
# (batch, sourceL, d)(batch, d, queryL)
|
||||
attn = torch.bmm(context, queryT)
|
||||
attn = nn.LeakyReLU(0.1)(attn)
|
||||
attn = l2norm(attn, 2)
|
||||
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = torch.transpose(attn, 1, 2).contiguous()
|
||||
# --> (batch*queryL, sourceL)
|
||||
attn = attn.view(batch_size * queryL, sourceL)
|
||||
attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = attn.view(batch_size, queryL, sourceL)
|
||||
|
||||
# Step 2: identify irrelevant fragments
|
||||
# Learn an indicator function H: one for relevant fragments, zero for irrelevant ones
|
||||
if opt.correct_type == 'equal':
|
||||
re_attn = correct_equal(attn, query, context, sourceL, g_sim)
|
||||
elif opt.correct_type == 'prob':
|
||||
re_attn = correct_prob(attn, query, context, sourceL, g_sim)
|
||||
# --> (batch, d, sourceL)
|
||||
contextT = torch.transpose(context, 1, 2)
|
||||
# --> (batch, sourceL, queryL)
|
||||
re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
|
||||
# (batch x d x sourceL)(batch x sourceL x queryL)
|
||||
# --> (batch, d, queryL)
|
||||
weightedContext = torch.bmm(contextT, re_attnT)
|
||||
|
||||
# --> (batch, queryL, d)
|
||||
weightedContext = torch.transpose(weightedContext, 1, 2)
|
||||
|
||||
if torch.isnan(weightedContext).any():
|
||||
print('Warning: NaN values detected in weightedContext')
|
||||
return weightedContext, re_attn
|
||||
|
||||
|
||||
def correct_equal(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as equal
|
||||
sum_j (x_i - x_j) = sum_j x_i - sum_j x_j
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_equal(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_equal(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_equal(attn, query, context, sourceL):
|
||||
funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
return re_attn
|
||||
|
||||
|
||||
def correct_prob(attn, query, context, sourceL, g_sim):
|
||||
"""
|
||||
consider the confidence g(x) for each fragment as the sqrt
|
||||
of their similarity probability to the query fragment
|
||||
sum_j (x_i - x_j) * g_j = sum_j x_i * g_j - sum_j x_j * g_j
|
||||
attn: (batch, queryL, sourceL)
|
||||
"""
|
||||
# GCU process
|
||||
d = g_sim - 0.3
|
||||
d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
|
||||
re_attn = d.unsqueeze(1).unsqueeze(2) * attn
|
||||
attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
|
||||
re_attn = re_attn / attn_sum
|
||||
cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
|
||||
cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
|
||||
re_attn1 = focal_prob(re_attn, query, context, sourceL)
|
||||
|
||||
# LCU process
|
||||
cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
|
||||
cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
|
||||
delta = cos - cos1
|
||||
delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
|
||||
re_attn2 = delta * re_attn1
|
||||
attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
|
||||
re_attn2 = re_attn2 / attn_sum
|
||||
re_attn2 = focal_prob(re_attn2, query, context, sourceL)
|
||||
return re_attn2
|
||||
|
||||
|
||||
def focal_prob(attn, query, context, sourceL):
|
||||
batch_size, queryL, sourceL = context.size(
|
||||
0), query.size(1), context.size(1)
|
||||
|
||||
# -> (batch, queryL, sourceL, 1)
|
||||
xi = attn.unsqueeze(-1).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj = attn.unsqueeze(2).contiguous()
|
||||
# -> (batch, queryL, 1, sourceL)
|
||||
xj_confi = torch.sqrt(xj)
|
||||
|
||||
xi = xi.view(batch_size * queryL, sourceL, 1)
|
||||
xj = xj.view(batch_size * queryL, 1, sourceL)
|
||||
xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL)
|
||||
|
||||
# -> (batch*queryL, sourceL, sourceL)
|
||||
term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
|
||||
term2 = xj * xj_confi
|
||||
funcF = torch.sum(term1 - term2, dim=-1) # -> (batch*queryL, sourceL)
|
||||
funcF = funcF.view(batch_size, queryL, sourceL)
|
||||
|
||||
fattn = torch.where(funcF > 0, torch.ones_like(attn),
|
||||
torch.zeros_like(attn))
|
||||
|
||||
# Step 3: reassign attention
|
||||
tmp_attn = fattn * attn
|
||||
attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
|
||||
re_attn = tmp_attn / attn_sum
|
||||
|
||||
if torch.isnan(re_attn).any():
|
||||
print("ddd")
|
||||
return re_attn
|
||||
|
||||
|
||||
def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic',
|
||||
no_imgnorm=False):
|
||||
"""A wrapper to image encoders. Chooses between an different encoders
|
||||
that uses precomputed image features.
|
||||
"""
|
||||
img_enc = EncoderImagePrecomp(img_dim, embed_size, no_imgnorm)
|
||||
|
||||
return img_enc
|
||||
|
||||
|
||||
class EncoderImagePrecomp(nn.Module):
|
||||
|
||||
def __init__(self, img_dim, embed_size, no_imgnorm=False):
|
||||
super(EncoderImagePrecomp, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_imgnorm = no_imgnorm
|
||||
self.fc = nn.Linear(img_dim, embed_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Xavier initialization for the fully connected layer
|
||||
"""
|
||||
r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
|
||||
self.fc.weight.data.uniform_(-r, r)
|
||||
self.fc.bias.data.fill_(0)
|
||||
|
||||
def forward(self, images):
|
||||
"""Extract image feature vectors."""
|
||||
# assuming that the precomputed features are already l2-normalized
|
||||
#print(images, images.shape)
|
||||
features = self.fc(images)
|
||||
#features_mean = torch.mean(features, 1)
|
||||
features_mean = torch.max(features, 1)[0]
|
||||
# normalize in the joint embedding space
|
||||
if not self.no_imgnorm:
|
||||
features = l2norm(features, dim=-1)
|
||||
|
||||
return features, features_mean
|
||||
|
||||
class EncoderTextBERT(nn.Module):
|
||||
def __init__(self, opt, order_embeddings=False, mean=True, post_transformer_layers=0):
|
||||
super().__init__()
|
||||
self.preextracted = opt.text_model_pre_extracted
|
||||
bert_config = BertConfig.from_pretrained(opt.text_model_pretrain,
|
||||
output_hidden_states=True,
|
||||
num_hidden_layers=opt.text_model_extraction_hidden_layer)
|
||||
bert_model = BertModel.from_pretrained(opt.text_model_pretrain, config=bert_config)
|
||||
self.order_embeddings = order_embeddings
|
||||
self.vocab_size = bert_model.config.vocab_size
|
||||
self.hidden_layer = opt.text_model_extraction_hidden_layer
|
||||
if not self.preextracted:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(opt.text_model_pretrain)
|
||||
self.bert_model = bert_model
|
||||
self.word_embeddings = self.bert_model.get_input_embeddings()
|
||||
if post_transformer_layers > 0:
|
||||
transformer_layer = nn.TransformerEncoderLayer(d_model=opt.text_model_word_dim, nhead=4,
|
||||
dim_feedforward=2048,
|
||||
dropout=opt.text_model_dropout, activation='relu')
|
||||
self.transformer_encoder = nn.TransformerEncoder(transformer_layer,
|
||||
num_layers=post_transformer_layers)
|
||||
self.post_transformer_layers = post_transformer_layers
|
||||
self.map = nn.Linear(opt.text_model_word_dim, opt.embed_size)
|
||||
self.mean = mean
|
||||
|
||||
def forward(self, x, lengths):
|
||||
'''
|
||||
x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ?
|
||||
lengths: tensor of lengths (LongTensor) of size B
|
||||
'''
|
||||
# print(x, x.shape)
|
||||
# print(lengths)
|
||||
if not self.preextracted or self.post_transformer_layers > 0:
|
||||
max_len = max(lengths)
|
||||
attention_mask = torch.ones(x.shape[0], max_len)
|
||||
for e, l in zip(attention_mask, lengths):
|
||||
e[l:] = 0
|
||||
attention_mask = attention_mask.to(x.device)
|
||||
|
||||
if self.preextracted:
|
||||
outputs = x
|
||||
else:
|
||||
outputs = self.bert_model(x, attention_mask=attention_mask)
|
||||
outputs = outputs[2][-1]
|
||||
|
||||
if self.post_transformer_layers > 0:
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool())
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
if self.mean:
|
||||
#x = outputs.mean(dim=1)
|
||||
x = torch.mean(outputs, 1)
|
||||
else:
|
||||
x = outputs[:, 0, :] # from the last layer take only the first word
|
||||
|
||||
out = self.map(x)
|
||||
outputs = self.map(outputs)
|
||||
|
||||
# normalization in the joint embedding space
|
||||
# out = l2norm(out)
|
||||
|
||||
# take absolute value, used by order embeddings
|
||||
if self.order_embeddings:
|
||||
out = torch.abs(out)
|
||||
#print(outputs.shape, out.shape)
|
||||
return outputs, lengths, out
|
||||
|
||||
def get_finetuning_params(self):
|
||||
return list(self.bert_model.parameters())
|
||||
|
||||
def encoder_text(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False):
|
||||
txt_enc = EncoderText(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm)
|
||||
|
||||
return txt_enc
|
||||
|
||||
|
||||
class EncoderText(nn.Module):
|
||||
|
||||
def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers,
|
||||
use_bi_gru=False, no_txtnorm=False):
|
||||
super(EncoderText, self).__init__()
|
||||
self.embed_size = embed_size
|
||||
self.no_txtnorm = no_txtnorm
|
||||
|
||||
# word embedding
|
||||
self.embed = nn.Embedding(vocab_size, word_dim)
|
||||
|
||||
# caption embedding
|
||||
self.use_bi_gru = use_bi_gru
|
||||
self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
|
||||
|
||||
self.init_weights(word2idx)
|
||||
|
||||
def init_weights(self, word2idx):
|
||||
# self.embed.weight.data.uniform_(-0.1, 0.1)
|
||||
|
||||
wemb = torchtext.vocab.GloVe(cache=".vector_cache")
|
||||
|
||||
# quick-and-dirty trick to improve word-hit rate
|
||||
missing_words = []
|
||||
for word, idx in word2idx.items():
|
||||
if word not in wemb.stoi:
|
||||
word = word.replace('-', '').replace('.', '').replace("'", '')
|
||||
if '/' in word:
|
||||
word = word.split('/')[0]
|
||||
if word in wemb.stoi:
|
||||
self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
|
||||
else:
|
||||
missing_words.append(word)
|
||||
print('Words: {}/{} found in vocabulary; {} words missing'.format(
|
||||
len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
|
||||
|
||||
def forward(self, x, lengths):
|
||||
"""Handles variable size captions
|
||||
"""
|
||||
# Embed word ids to vectors
|
||||
#print(x, x.shape, lengths, len(lengths))
|
||||
x = self.embed(x)
|
||||
packed = pack_padded_sequence(x, lengths, batch_first=True)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.rnn.flatten_parameters()
|
||||
# Forward propagate RNN
|
||||
out, _ = self.rnn(packed)
|
||||
#print(out.dtype, out.shape)
|
||||
#print("---")
|
||||
|
||||
# Reshape *final* output to (batch_size, hidden_size)
|
||||
padded = pad_packed_sequence(out, batch_first=True)
|
||||
cap_emb, cap_len = padded
|
||||
|
||||
if self.use_bi_gru:
|
||||
cap_emb = (cap_emb[:, :, :int(cap_emb.size(2) / 2)] + cap_emb[:, :, int(cap_emb.size(2) / 2):]) / 2
|
||||
#cap_emb_mean = torch.mean(cap_emb, 1)
|
||||
cap_emb_mean = torch.max(cap_emb, 1)[0]
|
||||
# normalization in the joint embedding space
|
||||
if not self.no_txtnorm:
|
||||
cap_emb = l2norm(cap_emb, dim=-1)
|
||||
cap_emb_mean = l2norm(cap_emb_mean, dim=1)
|
||||
#print(cap_emb.shape, cap_emb_mean.shape)
|
||||
return cap_emb, cap_len, cap_emb_mean
|
||||
|
||||
|
||||
''' Visual self-attention module '''


class V_single_modal_atten(nn.Module):
    """
    Single Visual Modal Attention Network.
    """

    def __init__(self, image_dim, embed_dim, dropout_rate=0.4, img_region_num=36):
        """
        :param image_dim: dim of visual feature
        :param embed_dim: dim of embedding space
        """
        super(V_single_modal_atten, self).__init__()

        self.fc1 = nn.Linear(image_dim, embed_dim)    # embed visual feature to common space
        self.fc2 = nn.Linear(image_dim, embed_dim)    # embed memory to common space
        self.fc2_2 = nn.Linear(embed_dim, embed_dim)
        self.fc3 = nn.Linear(embed_dim, 1)            # turn fusion_info to attention weights
        self.fc4 = nn.Linear(image_dim, embed_dim)    # embed attentive feature to common space

        self.embedding_1 = nn.Sequential(self.fc1, nn.BatchNorm1d(img_region_num), nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_2 = nn.Sequential(self.fc2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_2_2 = nn.Sequential(self.fc2_2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_3 = nn.Sequential(self.fc3)

        self.softmax = nn.Softmax(dim=1)              # softmax layer to calculate weights

    def forward(self, v_t, m_v):
        """
        Forward propagation.
        :param v_t: encoded images, shape: (batch_size, num_regions, image_dim)
        :param m_v: previous visual memory, shape: (batch_size, image_dim)
        :return: attention weighted encoding, weights
        """
        W_v = self.embedding_1(v_t)

        if m_v.size()[-1] == v_t.size()[-1]:
            W_v_m = self.embedding_2(m_v)
        else:
            W_v_m = self.embedding_2_2(m_v)

        W_v_m = W_v_m.unsqueeze(1).repeat(1, W_v.size()[1], 1)

        h_v = W_v.mul(W_v_m)

        a_v = self.embedding_3(h_v)
        a_v = a_v.squeeze(2)
        weights = self.softmax(a_v)

        v_att = (weights.unsqueeze(2) * v_t).sum(dim=1)

        # l2 norm
        v_att = l2norm(v_att, -1)

        return v_att, weights

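The attention block above boils down to: project the region features and the global vector into a shared space, score each region, softmax the scores, and take the weighted sum. A stripped-down sketch of just the weighting step with hypothetical shapes (batch of 2, 36 regions, 1024-d features); the real module runs the fused features through the learned `fc1`/`fc2`/`fc3` layers before the softmax, which this sketch replaces with a plain dot product:

```python
import torch
import torch.nn.functional as F

v_t = torch.randn(2, 36, 1024)          # region features: (batch, num_regions, dim)
m_v = torch.randn(2, 1024)              # global visual vector: (batch, dim)

scores = (v_t * m_v.unsqueeze(1)).sum(dim=2)        # (batch, num_regions) fusion scores
weights = F.softmax(scores, dim=1)                  # attention weights over regions
v_att = (weights.unsqueeze(2) * v_t).sum(dim=1)     # (batch, dim) attended feature
v_att = F.normalize(v_att, p=2, dim=-1)             # l2-normalise, as l2norm() does above
```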
class T_single_modal_atten(nn.Module):
    """
    Single Textual Modal Attention Network.
    """

    def __init__(self, embed_dim, dropout_rate=0.4):
        """
        :param embed_dim: dim of embedding space
        """
        super(T_single_modal_atten, self).__init__()

        self.fc1 = nn.Linear(embed_dim, embed_dim)    # embed textual feature to common space
        self.fc2 = nn.Linear(embed_dim, embed_dim)    # embed memory to common space
        self.fc3 = nn.Linear(embed_dim, 1)            # turn fusion_info to attention weights

        self.embedding_1 = nn.Sequential(self.fc1, nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_2 = nn.Sequential(self.fc2, nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_3 = nn.Sequential(self.fc3)

        self.softmax = nn.Softmax(dim=1)              # softmax layer to calculate weights

    def forward(self, u_t, m_u):
        """
        Forward propagation.
        :param u_t: encoded captions, shape: (batch_size, num_words, embed_dim)
        :param m_u: previous textual memory, shape: (batch_size, embed_dim)
        :return: attention weighted encoding, weights
        """
        W_u = self.embedding_1(u_t)

        W_u_m = self.embedding_2(m_u)
        W_u_m = W_u_m.unsqueeze(1).repeat(1, W_u.size()[1], 1)

        h_u = W_u.mul(W_u_m)

        a_u = self.embedding_3(h_u)
        a_u = a_u.squeeze(2)
        weights = self.softmax(a_u)

        u_att = (weights.unsqueeze(2) * u_t).sum(dim=1)

        # l2 norm
        u_att = l2norm(u_att, -1)

        return u_att, weights

class ContrastiveLoss(nn.Module):
    """
    Compute contrastive loss
    """

    def __init__(self, margin=0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, scores):
        # compute image-sentence score matrix
        diagonal = scores.diag().view(-1, 1)
        d1 = diagonal.expand_as(scores)
        d2 = diagonal.t().expand_as(scores)

        # compare every diagonal score to scores in its column
        # caption retrieval
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # compare every diagonal score to scores in its row
        # image retrieval
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # clear diagonals
        mask = torch.eye(scores.size(0)) > .5
        I = Variable(mask)
        if torch.cuda.is_available():
            I = I.cuda()
        cost_s = cost_s.masked_fill_(I, 0)
        cost_im = cost_im.masked_fill_(I, 0)

        # keep the maximum violating negative for each query
        cost_s = cost_s.max(1)[0]
        cost_im = cost_im.max(0)[0]

        return cost_s.sum() + cost_im.sum()

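`ContrastiveLoss` is the VSE++-style max-of-hinges triplet loss over the image-sentence score matrix: for each query only the hardest violating negative contributes. A small standalone sketch of the same computation on a random 4x4 score matrix, assuming a margin of 0.2 (the `Variable` wrapper in the class is a legacy of pre-0.4 PyTorch and is not needed here):

```python
import torch

scores = torch.rand(4, 4)                 # similarity matrix; matched pairs sit on the diagonal
margin = 0.2
diag = scores.diag().view(-1, 1)

cost_s = (margin + scores - diag.expand_as(scores)).clamp(min=0)       # caption retrieval hinges
cost_im = (margin + scores - diag.t().expand_as(scores)).clamp(min=0)  # image retrieval hinges

mask = torch.eye(4, dtype=torch.bool)     # positives do not count as violations
cost_s = cost_s.masked_fill(mask, 0)
cost_im = cost_im.masked_fill(mask, 0)

loss = cost_s.max(1)[0].sum() + cost_im.max(0)[0].sum()   # hardest negative per row / column
```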
class SCAN3(nn.Module):
    """
    Stacked Cross Attention Network (SCAN) model
    """

    def __init__(self, word2idx, opt):
        super(SCAN3, self).__init__()
        # Build Models
        self.grad_clip = opt.grad_clip
        self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size,
                                    precomp_enc_type=opt.precomp_enc_type,
                                    no_imgnorm=opt.no_imgnorm)
        self.txt_enc = encoder_text(word2idx, opt.vocab_size, opt.word_dim,
                                    opt.embed_size, opt.num_layers,
                                    use_bi_gru=True,
                                    no_txtnorm=opt.no_txtnorm)
        # self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers)

        self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size)
        self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size)

        self.opt = opt
        self.Eiters = 0

    def forward_emb(self, images, captions, lengths):
        """Compute the image and caption embeddings
        """
        # Forward
        img_emb, img_mean = self.img_enc(images)
        # print(img_emb.shape, img_mean.shape)
        cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)

        img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
        cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
        return img_emb, img_mean, cap_emb, cap_lens, cap_mean

    def txt_emb(self, captions, lengths):
        """Compute the caption embeddings
        """
        # Forward
        cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)

        cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
        return cap_emb, cap_lens, cap_mean

    def image_emb(self, images):
        """Compute the image embeddings
        """
        # Forward
        img_emb, img_mean = self.img_enc(images)
        img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
        return img_emb, img_mean

    def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
        """Compute the similarity scores given pairs of image and caption embeddings
        """
        scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)

        return scores

    def forward(self, images, captions, lengths, ids=None, *args):
        # compute the embeddings
        lengths = lengths.cpu().numpy().tolist()
        img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
        scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
        return scores

    def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
        similarities = []
        n_image = images.size(0)
        n_caption = captions.size(0)
        g_sims = cap_mean.mm(img_mean.t())
        now = time.time()
        for i in range(n_caption):
            # Get the i-th text description
            n_word = cap_lens[i]
            g_sim = g_sims[i]
            cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
            # --> (n_image, n_word, d)
            cap_i_expand = cap_i.repeat(n_image, 1, 1)

            # t2i process
            # weiContext: (n_image, n_word, d)
            weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
            t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
            t2i_sim = t2i_sim.mean(dim=1, keepdim=True)

            # i2t process
            # weiContext: (n_image, n_word, d)
            weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
            i2t_sim = cosine_similarity(images, weiContext, dim=2)
            i2t_sim = i2t_sim.mean(dim=1, keepdim=True)

            # Overall similarity for image and text
            sim = t2i_sim + i2t_sim

            similarities.append(sim)

        # (n_image, n_caption)
        similarities = torch.cat(similarities, 1)

        if self.training:
            similarities = similarities.transpose(0, 1)
        # print('Time:{:.4f}'.format(time.time() - now))
        return similarities

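The training loop itself is not part of this file. A rough, hedged sketch of how `SCAN3` and `ContrastiveLoss` are typically combined in one step; the option names (`opt.margin`, `opt.grad_clip`) follow the constructor above, while the optimizer choice and the `train_step` helper are assumptions for illustration only:

```python
import torch
from torch.nn.utils import clip_grad_norm_

# model = SCAN3(word2idx, opt)
# criterion = ContrastiveLoss(margin=opt.margin)
# optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

def train_step(model, criterion, optimizer, images, captions, lengths, grad_clip):
    # lengths: 1-D LongTensor of caption lengths; forward() converts it to a Python list
    scores = model(images, captions, lengths)   # (n_caption, n_image) similarity matrix in training mode
    loss = criterion(scores)
    optimizer.zero_grad()
    loss.backward()
    if grad_clip > 0:
        clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    return loss.item()
```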
@ -0,0 +1,9 @@
from .bua import add_bottom_up_attention_config


def add_config(args, cfg):
    if args.mode == "caffe":
        add_bottom_up_attention_config(cfg, True)
    elif args.mode == "detectron2":
        add_bottom_up_attention_config(cfg)
    else:
        raise Exception("detection mode not supported: {}".format(args.mode))
@ -0,0 +1,5 @@
from .config import add_bottom_up_attention_config
from .backbone import build_bua_resnet_backbone
from .rcnn import GeneralizedBUARCNN
from .roi_heads import BUACaffeRes5ROIHeads
from .rpn import StandardBUARPNHead, BUARPN
@ -0,0 +1,276 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import fvcore.nn.weight_init as weight_init
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm, BatchNorm2d
|
||||
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, make_stage
|
||||
from detectron2.modeling.backbone.resnet import BottleneckBlock, DeformBottleneckBlock, ResNetBlockBase
|
||||
|
||||
from .layers.wrappers import Conv2dv2
|
||||
|
||||
__all__ = ["BUABasicStem", "BUABasicStemv2", "build_bua_resnet_backbone"]
|
||||
|
||||
class BUABasicStem(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=64, norm="BN"):
|
||||
"""
|
||||
Args:
|
||||
norm (str or callable): a callable that takes the number of
|
||||
channels and return a `nn.Module`, or a pre-defined string
|
||||
(one of {"FrozenBN", "BN", "GN"}).
|
||||
"""
|
||||
super().__init__()
|
||||
self.conv1 = Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False,
|
||||
norm=get_norm(norm, out_channels),
|
||||
)
|
||||
weight_init.c2_msra_fill(self.conv1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu_(x)
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||
return x
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
return self.conv1.out_channels
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return 4 # = stride 2 conv -> stride 2 max pool
|
||||
|
||||
class BUABasicStemv2(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=64, norm="BN"):
|
||||
"""
|
||||
Args:
|
||||
norm (str or callable): a callable that takes the number of
|
||||
channels and return a `nn.Module`, or a pre-defined string
|
||||
(one of {"FrozenBN", "BN", "GN"}).
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm = BatchNorm2d(in_channels, eps=2e-5)
|
||||
self.conv1 = Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False,
|
||||
norm=BatchNorm2d(out_channels, eps=2e-5),
|
||||
)
|
||||
# weight_init.c2_msra_fill(self.norm)
|
||||
weight_init.c2_msra_fill(self.conv1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.conv1(x)
|
||||
x = F.relu_(x)
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||
return x
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
return self.conv1.out_channels
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return 4 # = stride 2 conv -> stride 2 max pool
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def build_bua_resnet_backbone(cfg, input_shape):
|
||||
"""
|
||||
Create a ResNet instance from config.
|
||||
|
||||
Returns:
|
||||
ResNet: a :class:`ResNet` instance.
|
||||
"""
|
||||
# need registration of new blocks/stems?
|
||||
norm = cfg.MODEL.RESNETS.NORM
|
||||
if cfg.MODEL.BUA.RESNET_VERSION == 2:
|
||||
stem = BUABasicStemv2(
|
||||
in_channels=input_shape.channels,
|
||||
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
|
||||
)
|
||||
else:
|
||||
stem = BUABasicStem(
|
||||
in_channels=input_shape.channels,
|
||||
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
|
||||
norm=norm,
|
||||
)
|
||||
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
|
||||
|
||||
if freeze_at >= 1:
|
||||
for p in stem.parameters():
|
||||
p.requires_grad = False
|
||||
stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)
|
||||
|
||||
# fmt: off
|
||||
out_features = cfg.MODEL.RESNETS.OUT_FEATURES
|
||||
depth = cfg.MODEL.RESNETS.DEPTH
|
||||
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
||||
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
||||
bottleneck_channels = num_groups * width_per_group
|
||||
in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
|
||||
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
|
||||
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
||||
res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
||||
deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
|
||||
deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
|
||||
deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
|
||||
# fmt: on
|
||||
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
|
||||
|
||||
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
|
||||
|
||||
stages = []
|
||||
|
||||
# Avoid creating variables without gradients
|
||||
# It consumes extra memory and may cause allreduce to fail
|
||||
out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
|
||||
max_stage_idx = max(out_stage_idx)
|
||||
for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
|
||||
dilation = res5_dilation if stage_idx == 5 else 1
|
||||
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
|
||||
stage_kargs = {
|
||||
"num_blocks": num_blocks_per_stage[idx],
|
||||
"first_stride": first_stride,
|
||||
"in_channels": in_channels,
|
||||
"bottleneck_channels": bottleneck_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_groups": num_groups,
|
||||
"norm": norm,
|
||||
"stride_in_1x1": stride_in_1x1,
|
||||
"dilation": dilation,
|
||||
}
|
||||
if deform_on_per_stage[idx]:
|
||||
stage_kargs["block_class"] = DeformBottleneckBlock
|
||||
stage_kargs["deform_modulated"] = deform_modulated
|
||||
stage_kargs["deform_num_groups"] = deform_num_groups
|
||||
else:
|
||||
stage_kargs["block_class"] = BottleneckBlock if cfg.MODEL.BUA.RESNET_VERSION == 1 else BottleneckBlockv2
|
||||
blocks = make_stage(**stage_kargs)
|
||||
in_channels = out_channels
|
||||
out_channels *= 2
|
||||
bottleneck_channels *= 2
|
||||
|
||||
if freeze_at >= stage_idx:
|
||||
for block in blocks:
|
||||
block.freeze()
|
||||
stages.append(blocks)
|
||||
return ResNet(stem, stages, out_features=out_features)
|
||||
|
||||
class BottleneckBlockv2(ResNetBlockBase):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
*,
|
||||
bottleneck_channels,
|
||||
stride=1,
|
||||
num_groups=1,
|
||||
norm="BN",
|
||||
stride_in_1x1=False,
|
||||
dilation=1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
norm (str or callable): a callable that takes the number of
|
||||
channels and return a `nn.Module`, or a pre-defined string
|
||||
(one of {"FrozenBN", "BN", "GN"}).
|
||||
stride_in_1x1 (bool): when stride==2, whether to put stride in the
|
||||
first 1x1 convolution or the bottleneck 3x3 convolution.
|
||||
"""
|
||||
super().__init__(in_channels, out_channels, stride)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = Conv2dv2(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
norm=None,
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
# The original MSRA ResNet models have stride in the first 1x1 conv
|
||||
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
|
||||
# stride in the 3x3 conv
|
||||
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
||||
|
||||
self.conv1 = Conv2dv2(
|
||||
in_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=1,
|
||||
stride=stride_1x1,
|
||||
bias=False,
|
||||
norm=None,
|
||||
)
|
||||
|
||||
self.conv2 = Conv2dv2(
|
||||
bottleneck_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=3,
|
||||
stride=stride_3x3,
|
||||
padding=1 * dilation,
|
||||
bias=False,
|
||||
groups=num_groups,
|
||||
dilation=dilation,
|
||||
norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
|
||||
activation=F.relu_,
|
||||
)
|
||||
|
||||
self.conv3 = Conv2dv2(
|
||||
bottleneck_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
|
||||
activation=F.relu_,
|
||||
)
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
||||
if layer is not None: # shortcut can be None
|
||||
weight_init.c2_msra_fill(layer)
|
||||
|
||||
self.norm = BatchNorm2d(in_channels, eps=2e-5)
|
||||
|
||||
# Zero-initialize the last normalization in each residual branch,
|
||||
# so that at the beginning, the residual branch starts with zeros,
|
||||
# and each residual block behaves like an identity.
|
||||
# See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
||||
# "For BN layers, the learnable scaling coefficient γ is initialized
|
||||
# to be 1, except for each residual block's last BN
|
||||
# where γ is initialized to be 0."
|
||||
|
||||
# nn.init.constant_(self.conv3.norm.weight, 0)
|
||||
# TODO this somehow hurts performance when training GN models from scratch.
|
||||
# Add it as an option when we need to use this code to train a backbone.
|
||||
|
||||
def forward(self, x):
|
||||
x_2 = self.norm(x)
|
||||
x_2 = F.relu_(x_2)
|
||||
|
||||
out = self.conv1(x_2)
|
||||
# out = F.relu_(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
# out = F.relu_(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
|
||||
if self.shortcut is not None:
|
||||
shortcut = self.shortcut(x_2)
|
||||
else:
|
||||
shortcut = x
|
||||
|
||||
out += shortcut
|
||||
# out = F.relu_(out)
|
||||
return out
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import math
|
||||
import torch
|
||||
from detectron2.structures import Boxes
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
# Value for clamping large dw and dh predictions. The heuristic is that we clamp
|
||||
# such that dw and dh are no larger than what would transform a 16px box into a
|
||||
# 1000px box (based on a small anchor, 16px, and a typical image size, 1000px).
|
||||
_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
|
||||
|
||||
|
||||
__all__ = ["BUABoxes", "BUABox2BoxTransform"]
|
||||
|
||||
class BUABoxes(Boxes):
|
||||
"""
|
||||
This structure stores a list of boxes as a Nx4 torch.Tensor.
|
||||
It supports some common methods about boxes
|
||||
(`area`, `clip`, `nonempty`, etc),
|
||||
and also behaves like a Tensor
|
||||
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
|
||||
|
||||
Attributes:
|
||||
tensor: float matrix of Nx4.
|
||||
"""
|
||||
|
||||
BoxSizeType = Union[List[int], Tuple[int, int]]
|
||||
def __init__(self, tensor: torch.Tensor):
|
||||
super().__init__(tensor)
|
||||
|
||||
def clip(self, box_size: BoxSizeType) -> None:
|
||||
"""
|
||||
NOTE: In order to be the same as bottom-up-attention network, we have
|
||||
defined the new clip function.
|
||||
|
||||
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
||||
and y coordinates to the range [0, height].
|
||||
|
||||
Args:
|
||||
box_size (height, width): The clipping box's size.
|
||||
"""
|
||||
assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
|
||||
TO_REMOVE = 1
|
||||
h, w = box_size
|
||||
self.tensor[:, 0].clamp_(min=0, max=w - TO_REMOVE)
|
||||
self.tensor[:, 1].clamp_(min=0, max=h - TO_REMOVE)
|
||||
self.tensor[:, 2].clamp_(min=0, max=w - TO_REMOVE)
|
||||
self.tensor[:, 3].clamp_(min=0, max=h - TO_REMOVE)
|
||||
|
||||
def nonempty(self, threshold: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
NOTE: In order to be the same as bottom-up-attention network, we have
|
||||
defined the new nonempty function.
|
||||
|
||||
Find boxes that are non-empty.
|
||||
A box is considered empty, if either of its side is no larger than threshold.
|
||||
|
||||
Returns:
|
||||
Tensor:
|
||||
a binary vector which represents whether each box is empty
|
||||
(False) or non-empty (True).
|
||||
"""
|
||||
TO_REMOVE = 1
|
||||
box = self.tensor
|
||||
widths = box[:, 2] - box[:, 0] + TO_REMOVE
|
||||
heights = box[:, 3] - box[:, 1] + TO_REMOVE
|
||||
keep = (widths > threshold) & (heights > threshold)
|
||||
return keep
|
||||
|
||||
def filter_boxes(self):
|
||||
box = self.tensor
|
||||
keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
|
||||
return keep
|
||||
|
||||
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Boxes":
|
||||
"""
|
||||
Returns:
|
||||
BUABoxes: Create a new :class:`BUABoxes` by indexing.
|
||||
|
||||
The following usage are allowed:
|
||||
1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
|
||||
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
||||
3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
|
||||
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
||||
|
||||
Note that the returned Boxes might share storage with this Boxes,
|
||||
subject to Pytorch's indexing semantics.
|
||||
"""
|
||||
if isinstance(item, int):
|
||||
return BUABoxes(self.tensor[item].view(1, -1))
|
||||
b = self.tensor[item]
|
||||
assert b.dim() == 2, "Indexing on Boxes with {} failed to return a matrix!".format(item)
|
||||
return BUABoxes(b)
|
||||
|
||||
class BUABox2BoxTransform(object):
|
||||
"""
|
||||
The box-to-box transform defined in R-CNN. The transformation is parameterized
|
||||
by 4 deltas: (dx, dy, dw, dh). The transformation scales the box's width and height
|
||||
by exp(dw), exp(dh) and shifts a box's center by the offset (dx * width, dy * height).
|
||||
"""
|
||||
|
||||
def __init__(self, weights, scale_clamp=_DEFAULT_SCALE_CLAMP):
|
||||
"""
|
||||
Args:
|
||||
weights (4-element tuple): Scaling factors that are applied to the
|
||||
(dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
|
||||
such that the deltas have unit variance; now they are treated as
|
||||
hyperparameters of the system.
|
||||
scale_clamp (float): When predicting deltas, the predicted box scaling
|
||||
factors (dw and dh) are clamped such that they are <= scale_clamp.
|
||||
"""
|
||||
self.weights = weights
|
||||
self.scale_clamp = scale_clamp
|
||||
|
||||
def get_deltas(self, src_boxes, target_boxes):
|
||||
"""
|
||||
Get box regression transformation deltas (dx, dy, dw, dh) that can be used
|
||||
to transform the `src_boxes` into the `target_boxes`. That is, the relation
|
||||
``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
|
||||
any delta is too large and is clamped).
|
||||
|
||||
Args:
|
||||
src_boxes (Tensor): source boxes, e.g., object proposals
|
||||
target_boxes (Tensor): target of the transformation, e.g., ground-truth
|
||||
boxes.
|
||||
"""
|
||||
assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
|
||||
assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
|
||||
|
||||
TO_REMOVE = 1 # TODO remove
|
||||
src_widths = src_boxes[:, 2] - src_boxes[:, 0] + TO_REMOVE
|
||||
src_heights = src_boxes[:, 3] - src_boxes[:, 1] + TO_REMOVE
|
||||
src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
|
||||
src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights
|
||||
|
||||
target_widths = target_boxes[:, 2] - target_boxes[:, 0] + TO_REMOVE
|
||||
target_heights = target_boxes[:, 3] - target_boxes[:, 1] + TO_REMOVE
|
||||
target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
|
||||
target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights
|
||||
|
||||
wx, wy, ww, wh = self.weights
|
||||
dx = wx * (target_ctr_x - src_ctr_x) / src_widths
|
||||
dy = wy * (target_ctr_y - src_ctr_y) / src_heights
|
||||
dw = ww * torch.log(target_widths / src_widths)
|
||||
dh = wh * torch.log(target_heights / src_heights)
|
||||
|
||||
deltas = torch.stack((dx, dy, dw, dh), dim=1)
|
||||
assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
|
||||
return deltas
|
||||
|
||||
def apply_deltas(self, deltas, boxes):
|
||||
"""
|
||||
Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
|
||||
|
||||
Args:
|
||||
deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
|
||||
deltas[i] represents k potentially different class-specific
|
||||
box transformations for the single box boxes[i].
|
||||
boxes (Tensor): boxes to transform, of shape (N, 4)
|
||||
"""
|
||||
assert torch.isfinite(deltas).all().item(), "Box regression deltas become infinite or NaN!"
|
||||
boxes = boxes.to(deltas.dtype)
|
||||
|
||||
TO_REMOVE = 1 # TODO remove
|
||||
widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
|
||||
heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
|
||||
ctr_x = boxes[:, 0] + 0.5 * widths
|
||||
ctr_y = boxes[:, 1] + 0.5 * heights
|
||||
|
||||
wx, wy, ww, wh = self.weights
|
||||
dx = deltas[:, 0::4] / wx
|
||||
dy = deltas[:, 1::4] / wy
|
||||
dw = deltas[:, 2::4] / ww
|
||||
dh = deltas[:, 3::4] / wh
|
||||
|
||||
# Prevent sending too large values into torch.exp()
|
||||
dw = torch.clamp(dw, max=self.scale_clamp)
|
||||
dh = torch.clamp(dh, max=self.scale_clamp)
|
||||
|
||||
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
|
||||
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
|
||||
pred_w = torch.exp(dw) * widths[:, None]
|
||||
pred_h = torch.exp(dh) * heights[:, None]
|
||||
|
||||
pred_boxes = torch.zeros_like(deltas)
|
||||
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
|
||||
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
|
||||
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
|
||||
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
|
||||
return pred_boxes
|
|
@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from detectron2.config import CfgNode as CN


def add_bottom_up_attention_config(cfg, caffe=False):
    """
    Add config for the bottom-up attention detector.
    """
    _C = cfg

    _C.MODEL.BUA = CN()
    _C.MODEL.BUA.CAFFE = caffe
    _C.MODEL.BUA.RESNET_VERSION = 1
    _C.MODEL.BUA.ATTRIBUTE_ON = False
    _C.MODEL.BUA.EXTRACT_FEATS = False

    _C.MODEL.BUA.RPN = CN()
    # out_channels of conv for bottom-up-attention RPN.
    _C.MODEL.BUA.RPN.CONV_OUT_CHANNELS = 512

    _C.MODEL.BUA.EXTRACTOR = CN()

    # EXTRACTOR.MODE {1: extract roi features, 2: extract bbox only, 3: extract roi features by gt_bbox}
    _C.MODEL.BUA.EXTRACTOR.MODE = 1

    # config of postprocessing in extractor
    _C.MODEL.BUA.EXTRACTOR.MIN_BOXES = 10
    _C.MODEL.BUA.EXTRACTOR.MAX_BOXES = 100
    _C.MODEL.BUA.EXTRACTOR.CONF_THRESH = 0.2
    _C.MODEL.BUA.EXTRACTOR.OUTPUT_DIR = ".output/"

    _C.MODEL.BUA.ATTRIBUTE = CN()
    _C.MODEL.BUA.ATTRIBUTE.NUM_CLASSES = 401
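A brief sketch of how this hook is used with a detectron2 `CfgNode`: start from the stock defaults, attach the `MODEL.BUA` node, then optionally merge a YAML config on top. The import path of `add_bottom_up_attention_config` is a hypothetical placeholder and depends on where these files live in the repository:

```python
from detectron2.config import get_cfg
from models.bua import add_bottom_up_attention_config  # hypothetical package path; adjust to the repo layout

cfg = get_cfg()                              # stock detectron2 defaults
add_bottom_up_attention_config(cfg, caffe=True)

print(cfg.MODEL.BUA.EXTRACTOR.MIN_BOXES)     # 10
print(cfg.MODEL.BUA.ATTRIBUTE.NUM_CLASSES)   # 401
# cfg.merge_from_file("path/to/config.yaml") would then override these defaults
```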
@ -0,0 +1,594 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
from fvcore.nn import smooth_l1_loss
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.layers import cat
|
||||
from detectron2.structures import Instances
|
||||
from detectron2.utils.events import get_event_storage
|
||||
from detectron2.modeling.roi_heads import select_foreground_proposals
|
||||
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference, fast_rcnn_inference_single_image, FastRCNNOutputs
|
||||
|
||||
from .layers.nms import batched_nms
|
||||
from .box_regression import BUABoxes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
Shape shorthand in this module:
|
||||
|
||||
N: number of images in the minibatch
|
||||
R: number of ROIs, combined over all images, in the minibatch
|
||||
Ri: number of ROIs in image i
|
||||
K: number of foreground classes. E.g.,there are 80 foreground classes in COCO.
|
||||
|
||||
Naming convention:
|
||||
|
||||
deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box
|
||||
transform (see :class:`box_regression.Box2BoxTransform`).
|
||||
|
||||
pred_class_logits: predicted class scores in [-inf, +inf]; use
|
||||
softmax(pred_class_logits) to estimate P(class).
|
||||
|
||||
gt_classes: ground-truth classification labels in [0, K], where [0, K) represent
|
||||
foreground object classes and K represents the background class.
|
||||
|
||||
pred_proposal_deltas: predicted box2box transform deltas for transforming proposals
|
||||
to detection box predictions.
|
||||
|
||||
gt_proposal_deltas: ground-truth box2box transform deltas
|
||||
"""
|
||||
|
||||
class BUACaffeFastRCNNOutputs(object):
|
||||
"""
|
||||
A class that stores information about outputs of a Fast R-CNN head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, box2box_transform, pred_class_logits, pred_proposal_deltas, proposals, smooth_l1_beta, image_scales, attr_on=False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
box2box_transform (Box2BoxTransform/Box2BoxTransformRotated):
|
||||
box2box transform instance for proposal-to-detection transformations.
|
||||
pred_class_logits (Tensor): A tensor of shape (R, K + 1) storing the predicted class
|
||||
logits for all R predicted object instances.
|
||||
Each row corresponds to a predicted object instance.
|
||||
pred_proposal_deltas (Tensor): A tensor of shape (R, K * B) or (R, B) for
|
||||
class-specific or class-agnostic regression. It stores the predicted deltas that
|
||||
transform proposals into final box detections.
|
||||
B is the box dimension (4 or 5).
|
||||
When B is 4, each row is [dx, dy, dw, dh (, ....)].
|
||||
When B is 5, each row is [dx, dy, dw, dh, da (, ....)].
|
||||
proposals (list[Instances]): A list of N Instances, where Instances i stores the
|
||||
proposals for image i, in the field "proposal_boxes".
|
||||
When training, each Instances must have ground-truth labels
|
||||
stored in the field "gt_classes" and "gt_boxes".
|
||||
smooth_l1_beta (float): The transition point between L1 and L2 loss in
|
||||
the smooth L1 loss function. When set to 0, the loss becomes L1. When
|
||||
set to +inf, the loss becomes constant 0.
|
||||
"""
|
||||
self.box2box_transform = box2box_transform
|
||||
self.num_preds_per_image = [len(p) for p in proposals]
|
||||
self.pred_class_logits = pred_class_logits
|
||||
self.pred_proposal_deltas = pred_proposal_deltas
|
||||
self.smooth_l1_beta = smooth_l1_beta
|
||||
self.image_scales = image_scales
|
||||
self.attr_on = attr_on
|
||||
|
||||
box_type = type(proposals[0].proposal_boxes)
|
||||
# cat(..., dim=0) concatenates over all images in the batch
|
||||
self.proposals = box_type.cat([p.proposal_boxes for p in proposals])
|
||||
assert not self.proposals.tensor.requires_grad, "Proposals should not require gradients!"
|
||||
self.image_shapes = [x.image_size for x in proposals]
|
||||
|
||||
# The following fields should exist only when training.
|
||||
if proposals[0].has("gt_boxes"):
|
||||
self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
|
||||
assert proposals[0].has("gt_classes")
|
||||
self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
|
||||
|
||||
def fast_rcnn_inference(self, boxes, scores, image_shapes, image_scales, score_thresh, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Call `fast_rcnn_inference_single_image` for all images.
|
||||
|
||||
Args:
|
||||
boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic
|
||||
boxes for each image. Element i has shape (Ri, K * 4) if doing
|
||||
class-specific regression, or (Ri, 4) if doing class-agnostic
|
||||
regression, where Ri is the number of predicted objects for image i.
|
||||
This is compatible with the output of :meth:`FastRCNNOutputs.predict_boxes`.
|
||||
scores (list[Tensor]): A list of Tensors of predicted class scores for each image.
|
||||
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
|
||||
for image i. Compatible with the output of :meth:`FastRCNNOutputs.predict_probs`.
|
||||
image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch.
|
||||
score_thresh (float): Only return detections with a confidence score exceeding this
|
||||
threshold.
|
||||
nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
|
||||
topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
|
||||
all detections.
|
||||
|
||||
Returns:
|
||||
instances: (list[Instances]): A list of N instances, one for each image in the batch,
|
||||
that stores the topk most confidence detections.
|
||||
kept_indices: (list[Tensor]): A list of 1D tensor of length of N, each element indicates
|
||||
the corresponding boxes/scores index in [0, Ri) from the input, for image i.
|
||||
"""
|
||||
result_per_image = [
|
||||
self.fast_rcnn_inference_single_image(
|
||||
boxes_per_image, scores_per_image, image_shape, image_scale, score_thresh, nms_thresh, topk_per_image
|
||||
)
|
||||
for scores_per_image, boxes_per_image, image_shape, image_scale in zip(scores, boxes, image_shapes, image_scales)
|
||||
]
|
||||
return tuple(list(x) for x in zip(*result_per_image))
|
||||
|
||||
def fast_rcnn_inference_single_image(
|
||||
self, boxes, scores, image_shape, image_scale, score_thresh, nms_thresh, topk_per_image
|
||||
):
|
||||
"""
|
||||
Single-image inference. Return bounding-box detection results by thresholding
|
||||
on scores and applying non-maximum suppression (NMS).
|
||||
|
||||
Args:
|
||||
Same as `fast_rcnn_inference`, but with boxes, scores, and image shapes
|
||||
per image.
|
||||
|
||||
Returns:
|
||||
Same as `fast_rcnn_inference`, but for only one image.
|
||||
"""
|
||||
scores = scores[:, 1:]
|
||||
boxes = boxes[:, 4:]
|
||||
num_bbox_reg_classes = boxes.shape[1] // 4
|
||||
# Convert to Boxes to use the `clip` function ...
|
||||
boxes = BUABoxes(boxes.reshape(-1, 4))
|
||||
boxes.clip((image_shape[0]/image_scale, image_shape[1]/image_scale))
|
||||
boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4
|
||||
|
||||
# Filter results based on detection scores
|
||||
filter_mask = scores > score_thresh # R x K
|
||||
# R' x 2. First column contains indices of the R predictions;
|
||||
# Second column contains indices of classes.
|
||||
filter_inds = filter_mask.nonzero()
|
||||
if num_bbox_reg_classes == 1:
|
||||
boxes = boxes[filter_inds[:, 0], 0]
|
||||
else:
|
||||
boxes = boxes[filter_mask]
|
||||
scores = scores[filter_mask]
|
||||
|
||||
# Apply per-class NMS
|
||||
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
|
||||
if topk_per_image >= 0:
|
||||
keep = keep[:topk_per_image]
|
||||
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
|
||||
|
||||
result = Instances(image_shape)
|
||||
result.pred_boxes = BUABoxes(boxes)
|
||||
result.scores = scores
|
||||
result.pred_classes = filter_inds[:, 1]
|
||||
return result, filter_inds[:, 0]
|
||||
|
||||
def predict_boxes(self):
|
||||
"""
|
||||
Returns:
|
||||
list[Tensor]: A list of Tensors of predicted class-specific or class-agnostic boxes
|
||||
for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
|
||||
the number of predicted objects for image i and B is the box dimension (4 or 5)
|
||||
"""
|
||||
# Always use 1 image per worker during inference since this is the
|
||||
# standard when reporting inference time in papers.
|
||||
self.proposals.scale(1.0/self.image_scales[0], 1.0/self.image_scales[0])
|
||||
num_pred = len(self.proposals)
|
||||
B = self.proposals.tensor.shape[1]
|
||||
K = self.pred_proposal_deltas.shape[1] // B
|
||||
boxes = self.box2box_transform.apply_deltas(
|
||||
self.pred_proposal_deltas,
|
||||
self.proposals.tensor,
|
||||
)
|
||||
return boxes.view(num_pred, K * B).split(self.num_preds_per_image, dim=0)
|
||||
|
||||
def predict_probs(self):
|
||||
"""
|
||||
Returns:
|
||||
list[Tensor]: A list of Tensors of predicted class probabilities for each image.
|
||||
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
|
||||
for image i.
|
||||
"""
|
||||
probs = F.softmax(self.pred_class_logits, dim=-1)
|
||||
return probs.split(self.num_preds_per_image, dim=0)
|
||||
|
||||
def inference(self, score_thresh, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Args:
|
||||
score_thresh (float): same as fast_rcnn_inference.
|
||||
nms_thresh (float): same as fast_rcnn_inference.
|
||||
topk_per_image (int): same as fast_rcnn_inference.
|
||||
Returns:
|
||||
list[Instances]: same as fast_rcnn_inference.
|
||||
list[Tensor]: same as fast_rcnn_inference.
|
||||
"""
|
||||
boxes = self.predict_boxes()
|
||||
scores = self.predict_probs()
|
||||
image_shapes = self.image_shapes
|
||||
image_scales = self.image_scales
|
||||
|
||||
return self.fast_rcnn_inference(
|
||||
boxes, scores, image_shapes, image_scales, score_thresh, nms_thresh, topk_per_image
|
||||
)
|
||||
|
||||
class BUACaffeFastRCNNOutputLayers(nn.Module):
|
||||
"""
|
||||
Two linear layers for predicting Fast R-CNN outputs:
|
||||
(1) proposal-to-detection box regression deltas
|
||||
(2) classification scores
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_classes, cls_agnostic_bbox_reg, box_dim=4, attr_on=False, num_attr_classes=401):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): channels, or (channels, height, width)
|
||||
num_classes (int): number of foreground classes
|
||||
cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression
|
||||
box_dim (int): the dimension of bounding boxes.
|
||||
Example box dimensions: 4 for regular XYXY boxes and 5 for rotated XYWHA boxes
|
||||
"""
|
||||
super(BUACaffeFastRCNNOutputLayers, self).__init__()
|
||||
|
||||
if not isinstance(input_size, int):
|
||||
input_size = np.prod(input_size)
|
||||
self.attr_on = attr_on
|
||||
|
||||
# The prediction layer for num_classes foreground classes and one background class
|
||||
# (hence + 1)
|
||||
self.cls_score = nn.Linear(input_size, num_classes)
|
||||
num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
|
||||
self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim)
|
||||
|
||||
nn.init.normal_(self.cls_score.weight, std=0.01)
|
||||
nn.init.normal_(self.bbox_pred.weight, std=0.001)
|
||||
for l in [self.cls_score, self.bbox_pred]:
|
||||
nn.init.constant_(l.bias, 0)
|
||||
|
||||
if self.attr_on:
|
||||
self.cls_embed = nn.Embedding(num_classes, 256)
|
||||
self.attr_linear1 = nn.Linear(input_size + 256, 512)
|
||||
self.attr_linear2 = nn.Linear(512, num_attr_classes)
|
||||
|
||||
nn.init.normal_(self.cls_embed.weight, std=0.01)
|
||||
nn.init.normal_(self.attr_linear1.weight, std=0.01)
|
||||
nn.init.normal_(self.attr_linear2.weight, std=0.01)
|
||||
nn.init.constant_(self.attr_linear1.bias, 0)
|
||||
nn.init.constant_(self.attr_linear2.bias, 0)
|
||||
|
||||
def forward(self, x, proposal_boxes=None):
|
||||
if x.dim() > 2:
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
scores = self.cls_score(x)
|
||||
proposal_deltas = self.bbox_pred(x)
|
||||
|
||||
if self.attr_on:
|
||||
|
||||
# get labels and indices of proposals with foreground
|
||||
all_labels = torch.argmax(scores, dim=1)
|
||||
|
||||
# get embeddings of indices using gt cls labels
|
||||
cls_embed_out = self.cls_embed(all_labels)
|
||||
|
||||
# concat with fc7 feats
|
||||
concat_attr = cat([x, cls_embed_out], dim=1)
|
||||
|
||||
# pass through attr head layers
|
||||
fc_attr = self.attr_linear1(concat_attr)
|
||||
attr_score = F.softmax(self.attr_linear2(F.relu(fc_attr)), dim=-1)
|
||||
return scores, proposal_deltas, attr_score
|
||||
|
||||
return scores, proposal_deltas
|
||||
|
||||
class BUADetection2FastRCNNOutputs(FastRCNNOutputs):
|
||||
"""
|
||||
A class that stores information about outputs of a Fast R-CNN head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
self, box2box_transform, pred_class_logits, pred_proposal_deltas, proposals, smooth_l1_beta, attr_on=False, pred_attribute_logits=None, num_attr_classes=400, gt_attributes=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
box2box_transform (Box2BoxTransform/Box2BoxTransformRotated):
|
||||
box2box transform instance for proposal-to-detection transformations.
|
||||
pred_class_logits (Tensor): A tensor of shape (R, K + 1) storing the predicted class
|
||||
logits for all R predicted object instances.
|
||||
Each row corresponds to a predicted object instance.
|
||||
pred_proposal_deltas (Tensor): A tensor of shape (R, K * B) or (R, B) for
|
||||
class-specific or class-agnostic regression. It stores the predicted deltas that
|
||||
transform proposals into final box detections.
|
||||
B is the box dimension (4 or 5).
|
||||
When B is 4, each row is [dx, dy, dw, dh (, ....)].
|
||||
When B is 5, each row is [dx, dy, dw, dh, da (, ....)].
|
||||
pred_attribute_logits (Tensor:) A tensor of shape (R, C) storing the predicted attribute
|
||||
logits for all R predicted object instances.
|
||||
proposals (list[Instances]): A list of N Instances, where Instances i stores the
|
||||
proposals for image i, in the field "proposal_boxes".
|
||||
When training, each Instances must have ground-truth labels
|
||||
stored in the field "gt_classes" and "gt_boxes".
|
||||
smooth_l1_beta (float): The transition point between L1 and L2 loss in
|
||||
the smooth L1 loss function. When set to 0, the loss becomes L1. When
|
||||
set to +inf, the loss becomes constant 0.
|
||||
"""
|
||||
self.attr_on = attr_on
|
||||
self.box2box_transform = box2box_transform
|
||||
self.num_preds_per_image = [len(p) for p in proposals]
|
||||
self.pred_class_logits = pred_class_logits
|
||||
self.pred_proposal_deltas = pred_proposal_deltas
|
||||
|
||||
if self.attr_on:
|
||||
self.pred_attribute_logits = pred_attribute_logits
|
||||
self.gt_attributes = gt_attributes
|
||||
self.smooth_l1_beta = smooth_l1_beta
|
||||
|
||||
box_type = type(proposals[0].proposal_boxes)
|
||||
# cat(..., dim=0) concatenates over all images in the batch
|
||||
self.proposals = box_type.cat([p.proposal_boxes for p in proposals])
|
||||
assert not self.proposals.tensor.requires_grad, "Proposals should not require gradients!"
|
||||
self.image_shapes = [x.image_size for x in proposals]
|
||||
self.num_attr_classes = num_attr_classes
|
||||
|
||||
# The following fields should exist only when training.
|
||||
if proposals[0].has("gt_boxes"):
|
||||
self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
|
||||
assert proposals[0].has("gt_classes")
|
||||
self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
|
||||
|
||||
def _log_accuracy(self):
|
||||
"""
|
||||
Log the accuracy metrics to EventStorage.
|
||||
"""
|
||||
num_instances = self.gt_classes.numel()
|
||||
pred_classes = self.pred_class_logits.argmax(dim=1)
|
||||
bg_class_ind = self.pred_class_logits.shape[1] - 1
|
||||
|
||||
fg_inds = (self.gt_classes >= 0) & (self.gt_classes < bg_class_ind)
|
||||
num_fg = fg_inds.nonzero().numel()
|
||||
fg_gt_classes = self.gt_classes[fg_inds]
|
||||
fg_pred_classes = pred_classes[fg_inds]
|
||||
|
||||
num_false_negative = (fg_pred_classes == bg_class_ind).nonzero().numel()
|
||||
num_accurate = (pred_classes == self.gt_classes).nonzero().numel()
|
||||
fg_num_accurate = (fg_pred_classes == fg_gt_classes).nonzero().numel()
|
||||
|
||||
storage = get_event_storage()
|
||||
storage.put_scalar("fast_rcnn/cls_accuracy", num_accurate / num_instances)
|
||||
if num_fg > 0:
|
||||
storage.put_scalar("fast_rcnn/fg_cls_accuracy", fg_num_accurate / num_fg)
|
||||
storage.put_scalar("fast_rcnn/false_negative", num_false_negative / num_fg)
|
||||
|
||||
def softmax_cross_entropy_loss(self):
|
||||
"""
|
||||
Compute the softmax cross entropy loss for box classification.
|
||||
|
||||
Returns:
|
||||
scalar Tensor
|
||||
"""
|
||||
self._log_accuracy()
|
||||
return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")
|
||||
|
||||
def smooth_l1_loss(self):
|
||||
"""
|
||||
Compute the smooth L1 loss for box regression.
|
||||
|
||||
Returns:
|
||||
scalar Tensor
|
||||
"""
|
||||
gt_proposal_deltas = self.box2box_transform.get_deltas(
|
||||
self.proposals.tensor, self.gt_boxes.tensor
|
||||
)
|
||||
box_dim = gt_proposal_deltas.size(1) # 4 or 5
|
||||
cls_agnostic_bbox_reg = self.pred_proposal_deltas.size(1) == box_dim
|
||||
device = self.pred_proposal_deltas.device
|
||||
|
||||
bg_class_ind = self.pred_class_logits.shape[1] - 1
|
||||
|
||||
# Box delta loss is only computed between the prediction for the gt class k
|
||||
# (if 0 <= k < bg_class_ind) and the target; there is no loss defined on predictions
|
||||
# for non-gt classes and background.
|
||||
# Empty fg_inds produces a valid loss of zero as long as the size_average
|
||||
# arg to smooth_l1_loss is False (otherwise it uses torch.mean internally
|
||||
# and would produce a nan loss).
|
||||
fg_inds = torch.nonzero((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind)).squeeze(
|
||||
1
|
||||
)
|
||||
if cls_agnostic_bbox_reg:
|
||||
# pred_proposal_deltas only corresponds to foreground class for agnostic
|
||||
gt_class_cols = torch.arange(box_dim, device=device)
|
||||
else:
|
||||
fg_gt_classes = self.gt_classes[fg_inds]
|
||||
# pred_proposal_deltas for class k are located in columns [b * k : b * k + b],
|
||||
# where b is the dimension of box representation (4 or 5)
|
||||
# Note that compared to Detectron1,
|
||||
# we do not perform bounding box regression for background classes.
|
||||
gt_class_cols = box_dim * fg_gt_classes[:, None] + torch.arange(box_dim, device=device)
|
||||
|
||||
loss_box_reg = smooth_l1_loss(
|
||||
self.pred_proposal_deltas[fg_inds[:, None], gt_class_cols],
|
||||
gt_proposal_deltas[fg_inds],
|
||||
self.smooth_l1_beta,
|
||||
reduction="sum",
|
||||
)
|
||||
# The loss is normalized using the total number of regions (R), not the number
|
||||
# of foreground regions even though the box regression loss is only defined on
|
||||
# foreground regions. Why? Because doing so gives equal training influence to
|
||||
# each foreground example. To see how, consider two different minibatches:
|
||||
# (1) Contains a single foreground region
|
||||
# (2) Contains 100 foreground regions
|
||||
# If we normalize by the number of foreground regions, the single example in
|
||||
# minibatch (1) will be given 100 times as much influence as each foreground
|
||||
# example in minibatch (2). Normalizing by the total number of regions, R,
|
||||
# means that the single example in minibatch (1) and each of the 100 examples
|
||||
# in minibatch (2) are given equal influence.
|
||||
loss_box_reg = loss_box_reg / self.gt_classes.numel()
|
||||
return loss_box_reg
|
||||
|
||||
def attribute_loss(self):
|
||||
fg_gt_attributes = self.gt_attributes
|
||||
n_boxes = self.pred_attribute_logits.shape[0]
|
||||
self.pred_attribute_logits = self.pred_attribute_logits.unsqueeze(1)
|
||||
self.pred_attribute_logits = self.pred_attribute_logits.expand(n_boxes, 16, self.num_attr_classes).contiguous().view(-1, self.num_attr_classes)
|
||||
|
||||
inv_per_box_weights = (
|
||||
(fg_gt_attributes >= 0).sum(dim=1).repeat(16, 1).transpose(0, 1).flatten()
|
||||
)
|
||||
per_box_weights = inv_per_box_weights.float().reciprocal()
|
||||
per_box_weights[per_box_weights > 1] = 0.0
|
||||
|
||||
fg_gt_attributes = fg_gt_attributes.view(-1)
|
||||
attributes_loss = 0.5 * F.cross_entropy(
|
||||
self.pred_attribute_logits, fg_gt_attributes, reduction="none", ignore_index=-1
|
||||
)
|
||||
|
||||
attributes_loss = (attributes_loss * per_box_weights).view(n_boxes, -1).sum(dim=1)
|
||||
|
||||
n_valid_boxes = len(attributes_loss.nonzero())
|
||||
|
||||
if n_valid_boxes > 0:
|
||||
attributes_loss = (attributes_loss / n_valid_boxes).sum()
|
||||
else:
|
||||
attributes_loss = (attributes_loss * 0.0).sum()
|
||||
return attributes_loss
|
||||
|
||||
def losses(self):
|
||||
"""
|
||||
Compute the default losses for box head in Fast(er) R-CNN,
|
||||
with softmax cross entropy loss and smooth L1 loss.
|
||||
|
||||
Returns:
|
||||
A dict of losses (scalar tensors) containing keys "loss_cls" and "loss_box_reg".
|
||||
"""
|
||||
return {
|
||||
"loss_cls": self.softmax_cross_entropy_loss(),
|
||||
"loss_box_reg": self.smooth_l1_loss(),
|
||||
"loss_attr": self.attribute_loss() if self.attr_on else 0.,
|
||||
}
|
||||
|
||||
def predict_boxes(self):
|
||||
"""
|
||||
Returns:
|
||||
list[Tensor]: A list of Tensors of predicted class-specific or class-agnostic boxes
|
||||
for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
|
||||
the number of predicted objects for image i and B is the box dimension (4 or 5)
|
||||
"""
|
||||
num_pred = len(self.proposals)
|
||||
B = self.proposals.tensor.shape[1]
|
||||
K = self.pred_proposal_deltas.shape[1] // B
|
||||
boxes = self.box2box_transform.apply_deltas(
|
||||
self.pred_proposal_deltas.view(num_pred * K, B),
|
||||
self.proposals.tensor.unsqueeze(1).expand(num_pred, K, B).reshape(-1, B),
|
||||
)
|
||||
return boxes.view(num_pred, K * B).split(self.num_preds_per_image, dim=0)
|
||||
|
||||
def predict_probs(self):
|
||||
"""
|
||||
Returns:
|
||||
list[Tensor]: A list of Tensors of predicted class probabilities for each image.
|
||||
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
|
||||
for image i.
|
||||
"""
|
||||
probs = F.softmax(self.pred_class_logits, dim=-1)
|
||||
return probs.split(self.num_preds_per_image, dim=0)
|
||||
|
||||
def inference(self, score_thresh, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Args:
|
||||
score_thresh (float): same as fast_rcnn_inference.
|
||||
nms_thresh (float): same as fast_rcnn_inference.
|
||||
topk_per_image (int): same as fast_rcnn_inference.
|
||||
Returns:
|
||||
list[Instances]: same as fast_rcnn_inference.
|
||||
list[Tensor]: same as fast_rcnn_inference.
|
||||
"""
|
||||
boxes = self.predict_boxes()
|
||||
scores = self.predict_probs()
|
||||
image_shapes = self.image_shapes
|
||||
|
||||
return fast_rcnn_inference(
|
||||
boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image
|
||||
)
|
||||
|
||||
class BUADetectron2FastRCNNOutputLayers(nn.Module):
|
||||
"""
|
||||
Two linear layers for predicting Fast R-CNN outputs:
|
||||
(1) proposal-to-detection box regression deltas
|
||||
(2) classification scores
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_classes, cls_agnostic_bbox_reg, box_dim=4, attr_on=False, num_attr_classes=400):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): channels, or (channels, height, width)
|
||||
num_classes (int): number of foreground classes
|
||||
cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression
|
||||
box_dim (int): the dimension of bounding boxes.
|
||||
Example box dimensions: 4 for regular XYXY boxes and 5 for rotated XYWHA boxes
|
||||
"""
|
||||
super(BUADetectron2FastRCNNOutputLayers, self).__init__()
|
||||
self.attr_on = attr_on
|
||||
self.num_classes = num_classes
|
||||
self.num_attr_classes = num_attr_classes
|
||||
|
||||
if not isinstance(input_size, int):
|
||||
input_size = np.prod(input_size)
|
||||
|
||||
# The prediction layer for num_classes foreground classes and one background class
|
||||
# (hence + 1)
|
||||
self.cls_score = nn.Linear(input_size, num_classes + 1)
|
||||
num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
|
||||
self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim)
|
||||
|
||||
nn.init.normal_(self.cls_score.weight, std=0.01)
|
||||
nn.init.normal_(self.bbox_pred.weight, std=0.001)
|
||||
for l in [self.cls_score, self.bbox_pred]:
|
||||
nn.init.constant_(l.bias, 0)
|
||||
|
||||
if self.attr_on:
|
||||
self.cls_embed = nn.Embedding(num_classes+1, 256)
|
||||
self.attr_linear1 = nn.Linear(input_size + 256, 512)
|
||||
self.attr_linear2 = nn.Linear(512, num_attr_classes)
|
||||
|
||||
# nn.init.normal_(self.cls_embed.weight, std=0.01)
|
||||
nn.init.normal_(self.attr_linear1.weight, std=0.01)
|
||||
nn.init.normal_(self.attr_linear2.weight, std=0.01)
|
||||
nn.init.constant_(self.attr_linear1.bias, 0)
|
||||
nn.init.constant_(self.attr_linear2.bias, 0)
|
||||
|
||||
def forward(self, x, proposal_boxes=None):
|
||||
if x.dim() > 2:
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
scores = self.cls_score(x)
|
||||
proposal_deltas = self.bbox_pred(x)
|
||||
|
||||
if self.attr_on:
|
||||
if self.training:
|
||||
assert proposal_boxes is not None, "Proposals are None while attr=True"
|
||||
proposals, fg_selection_atrributes = select_foreground_proposals(proposal_boxes, self.num_classes)
|
||||
attribute_features = x[torch.cat(fg_selection_atrributes, dim=0)]
|
||||
cls_labels = torch.cat([prop.gt_classes for prop in proposals])
|
||||
|
||||
else:
|
||||
# get labels and indices of proposals with foreground
|
||||
cls_labels = torch.argmax(scores, dim=1)
|
||||
attribute_features = x
|
||||
|
||||
# get embeddings of indices using gt cls labels
|
||||
cls_embed_out = self.cls_embed(cls_labels)
|
||||
|
||||
# concat with fc7 feats
|
||||
concat_attr = cat([attribute_features, cls_embed_out], dim=1)
|
||||
|
||||
# pass through attr head layers
|
||||
fc_attr = self.attr_linear1(concat_attr)
|
||||
attr_score = self.attr_linear2(F.relu(fc_attr))
|
||||
return scores, proposal_deltas, attr_score, cat([p.gt_attributes for p in proposals], dim=0) if self.training else None
|
||||
|
||||
return scores, proposal_deltas
|
|
@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .nms import SwapAlign2Nat, swap_align2nat

__all__ = [k for k in globals().keys() if not k.startswith("_")]
@ -0,0 +1,131 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
#include <iostream>

int const threadsPerBlock = sizeof(unsigned long long) * 8;

__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
  float interS = width * height;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

// boxes is a N x 5 tensor
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
  using scalar_t = float;
  AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
  auto scores = boxes.select(1, 4);
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto boxes_sorted = boxes.index_select(0, order_t);

  int boxes_num = boxes.size(0);

  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);

  scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

  unsigned long long* mask_dev = NULL;
  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
  //                      boxes_num * col_blocks * sizeof(unsigned long long)));

  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));

  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
              THCCeilDiv(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);

  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
  THCudaCheck(cudaMemcpy(&mask_host[0],
                        mask_dev,
                        sizeof(unsigned long long) * boxes_num * col_blocks,
                        cudaMemcpyDeviceToHost));

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      unsigned long long *p = &mask_host[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  THCudaFree(state, mask_dev);
  // TODO improve this part
  return std::get<0>(order_t.index({
                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
                         order_t.device(), keep.scalar_type())
                     }).sort(0, false));
}
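For reference, a minimal NumPy sketch of the greedy suppression that `nms_kernel` parallelises on the GPU (same "+1" pixel convention as `devIoU` above); names and the dummy signature are illustrative only, not part of the repo.

```python
# Reference-only sketch of greedy NMS; boxes is (N, 4) as x1, y1, x2, y2.
import numpy as np

def nms_reference(boxes, scores, thresh):
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]          # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the current best box against the rest
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # drop everything that overlaps the kept box above the threshold
        order = order[1:][iou <= thresh]
    return keep
```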
@ -0,0 +1,28 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "vision_cpu.h"

#ifdef WITH_CUDA
#include "vision_cuda.h"
#endif


at::Tensor nms(const at::Tensor& dets,
               const at::Tensor& scores,
               const float threshold) {

  if (dets.type().is_cuda()) {
#ifdef WITH_CUDA
    // TODO raise error if not compiled with CUDA
    if (dets.numel() == 0)
      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
    auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
    return nms_cuda(b, threshold);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }

  at::Tensor result = nms_cpu(dets, scores, threshold);
  return result;
}
@ -0,0 +1,75 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include "vision_cpu.h"


template <typename scalar_t>
at::Tensor nms_cpu_kernel(const at::Tensor& dets,
                          const at::Tensor& scores,
                          const float threshold) {
  AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
  AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
  AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");

  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
  }

  auto x1_t = dets.select(1, 0).contiguous();
  auto y1_t = dets.select(1, 1).contiguous();
  auto x2_t = dets.select(1, 2).contiguous();
  auto y2_t = dets.select(1, 3).contiguous();

  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);

  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));

  auto ndets = dets.size(0);
  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));

  auto suppressed = suppressed_t.data<uint8_t>();
  auto order = order_t.data<int64_t>();
  auto x1 = x1_t.data<scalar_t>();
  auto y1 = y1_t.data<scalar_t>();
  auto x2 = x2_t.data<scalar_t>();
  auto y2 = y2_t.data<scalar_t>();
  auto areas = areas_t.data<scalar_t>();

  for (int64_t _i = 0; _i < ndets; _i++) {
    auto i = order[_i];
    if (suppressed[i] == 1)
      continue;
    auto ix1 = x1[i];
    auto iy1 = y1[i];
    auto ix2 = x2[i];
    auto iy2 = y2[i];
    auto iarea = areas[i];

    for (int64_t _j = _i + 1; _j < ndets; _j++) {
      auto j = order[_j];
      if (suppressed[j] == 1)
        continue;
      auto xx1 = std::max(ix1, x1[j]);
      auto yy1 = std::max(iy1, y1[j]);
      auto xx2 = std::min(ix2, x2[j]);
      auto yy2 = std::min(iy2, y2[j]);

      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
      auto inter = w * h;
      auto ovr = inter / (iarea + areas[j] - inter);
      if (ovr >= threshold)
        suppressed[j] = 1;
    }
  }
  return at::nonzero(suppressed_t == 0).squeeze(1);
}

at::Tensor nms_cpu(const at::Tensor& dets,
                   const at::Tensor& scores,
                   const float threshold) {
  at::Tensor result;
  AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
    result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
  });
  return result;
}
@ -0,0 +1,16 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>


at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
                                const at::Tensor& rois,
                                const float spatial_scale,
                                const int pooled_height,
                                const int pooled_width,
                                const int sampling_ratio);


at::Tensor nms_cpu(const at::Tensor& dets,
                   const at::Tensor& scores,
                   const float threshold);
@ -0,0 +1,10 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>

at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);


at::Tensor compute_flow_cuda(const at::Tensor& boxes,
                             const int height,
                             const int width);
@ -0,0 +1,12 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

#include <torch/extension.h>
#include "nms/nms.h"

namespace bottom_up_attention {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nms", &nms, "non-maximum suppression");
}

} // namespace bottom_up_attention
@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# from ._utils import _C
from models.bua import _C

from apex import amp
import torch

# Only valid with fp32 inputs - give AMP the hint
nms = amp.float_function(_C.nms)

# nms.__doc__ = """
# This function performs non-maximum suppression"""

# NOTE: In order to be consistent with bottom-up-attention, we use the NMS core function from maskrcnn-benchmark

def batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Same as torchvision.ops.boxes.batched_nms, but safer.
    """
    assert boxes.shape[-1] == 4
    boxes = boxes.cpu()
    scores = scores.cpu()
    # TODO may need better strategy.
    # Investigate after having a fully-cuda NMS op.
    if len(boxes) < 40000:
        return box_ops_batched_nms(boxes, scores, idxs, iou_threshold)

    result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
    for id in torch.unique(idxs).cpu().tolist():
        # if id == 0:
        #     continue
        mask = (idxs == id).nonzero().view(-1)
        keep = nms(boxes[mask], scores[mask], iou_threshold)
        result_mask[mask[keep]] = True
    keep = result_mask.nonzero().view(-1)
    keep = keep[scores[keep].argsort(descending=True)]
    return keep


def box_ops_batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
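The per-class offset trick used in `box_ops_batched_nms` can be seen in isolation with a tiny example. This is an illustration only (it uses `torchvision.ops.nms` rather than the compiled `_C.nms` above, and the boxes/classes are dummy values):

```python
# Illustration of the class-offset trick: shift each class into its own
# coordinate range so a single-class NMS call never mixes categories.
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [0., 0., 10., 10.]])
scores = torch.tensor([0.9, 0.8, 0.7])
idxs = torch.tensor([0, 0, 1])                     # third box belongs to another class

offsets = idxs.to(boxes) * (boxes.max() + 1)       # 0 for class 0, 12 for class 1
keep = nms(boxes + offsets[:, None], scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): box 1 overlaps box 0 and is suppressed; box 2 survives in its own class
```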
@ -0,0 +1,38 @@
import math
import torch
from torch.nn.modules.utils import _ntuple

class Conv2dv2(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        if x.numel() == 0 and self.training:
            # https://github.com/pytorch/pytorch/issues/12013
            assert not isinstance(
                self.norm, torch.nn.SyncBatchNorm
            ), "SyncBatchNorm does not support empty inputs!"
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        x = super().forward(x)
        return x
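Note that `Conv2dv2` applies norm and activation to the *input* before the convolution (pre-activation, ResNet-v2 style), unlike the usual conv-norm-activation ordering. A small usage sketch, assuming `Conv2dv2` above is importable and using dummy shapes:

```python
# Hypothetical usage: pre-activation conv (norm -> relu -> conv).
import torch
import torch.nn.functional as F

conv = Conv2dv2(64, 64, kernel_size=3, padding=1,
                norm=torch.nn.BatchNorm2d(64),   # normalizes the 64 input channels
                activation=F.relu_)
y = conv(torch.randn(2, 64, 16, 16))             # output shape (2, 64, 16, 16)
```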
@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import torch

from detectron2.structures import Instances
from .layers.nms import nms  # BC-compat

def extractor_postprocess(boxes, scores, features_pooled, input_per_image, extractor):
    """
    Select the final set of boxes and their pooled features for feature extraction.

    Per-class NMS is run on the detections (rescaled back to the original image
    resolution), each box's confidence is the maximum class score that survives NMS,
    and the boxes are then filtered by CONF_THRESH while keeping between MIN_BOXES
    and MAX_BOXES detections.

    Args:
        boxes (Tensor): detected boxes in the resized input image coordinates.
        scores (Tensor): per-class scores for each box.
        features_pooled (Tensor): pooled region features, one row per box.
        input_per_image (dict): per-image input dict; "im_scale" is used to map
            boxes back to the original image resolution.
        extractor: config node providing MIN_BOXES, MAX_BOXES and CONF_THRESH.

    Returns:
        tuple: (image_feat, image_bboxes) for the kept boxes.
    """
    MIN_BOXES = extractor.MIN_BOXES
    MAX_BOXES = extractor.MAX_BOXES
    CONF_THRESH = extractor.CONF_THRESH

    cur_device = scores.device

    dets = boxes / input_per_image["im_scale"]

    max_conf = torch.zeros((scores.shape[0])).to(cur_device)

    for cls_ind in range(1, scores.shape[1]):
        cls_scores = scores[:, cls_ind]
        keep = nms(dets, cls_scores, 0.3)
        max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                     cls_scores[keep],
                                     max_conf[keep])

    keep_boxes = torch.nonzero(max_conf >= CONF_THRESH).flatten()
    if len(keep_boxes) < MIN_BOXES:
        keep_boxes = torch.argsort(max_conf, descending=True)[:MIN_BOXES]
    elif len(keep_boxes) > MAX_BOXES:
        keep_boxes = torch.argsort(max_conf, descending=True)[:MAX_BOXES]
    # keep_boxes = torch.argsort(max_conf, descending=True)[:100]
    # feat_list.append(feats[i][keep_boxes])
    image_feat = features_pooled[keep_boxes]
    image_bboxes = dets[keep_boxes]

    return image_feat, image_bboxes
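A tiny sketch of the per-box confidence accumulation above, with 3 dummy boxes and 3 class columns; the first column is skipped as in the loop above, and the per-class NMS step is omitted so only the `torch.where` update is shown:

```python
# Illustration of accumulating the maximum surviving class score per box.
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.2, 0.3, 0.6],
                       [0.8, 0.1, 0.1]])
max_conf = torch.zeros(scores.shape[0])
for cls_ind in range(1, scores.shape[1]):
    cls_scores = scores[:, cls_ind]
    # the real code only updates the indices kept by per-class NMS
    max_conf = torch.where(cls_scores > max_conf, cls_scores, max_conf)
print(max_conf)  # tensor([0.7000, 0.6000, 0.1000])
```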
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import logging, os
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from detectron2.structures import ImageList
|
||||
from detectron2.utils.logger import log_first_n
|
||||
|
||||
from detectron2.modeling.backbone import build_backbone
|
||||
from detectron2.modeling.postprocessing import detector_postprocess
|
||||
from detectron2.modeling.proposal_generator import build_proposal_generator
|
||||
from detectron2.modeling.roi_heads import build_roi_heads
|
||||
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY
|
||||
|
||||
# from models.bua_caffe.postprocessing import extractor_postprocess
|
||||
#from utils import save_features
|
||||
|
||||
__all__ = ["GeneralizedBUARCNN"]
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
class GeneralizedBUARCNN(nn.Module):
|
||||
"""
|
||||
    Generalized R-CNN. Any model that contains the following three components:
|
||||
1. Per-image feature extraction (aka backbone)
|
||||
2. Region proposal generation
|
||||
3. Per-region feature extraction and prediction
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
self.device = torch.device(cfg.MODEL.DEVICE)
|
||||
self.bua_caffe = cfg.MODEL.BUA.CAFFE
|
||||
self.resnet_version = cfg.MODEL.BUA.RESNET_VERSION
|
||||
self.backbone = build_backbone(cfg)
|
||||
self.in_features = cfg.MODEL.RPN.IN_FEATURES
|
||||
self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
|
||||
self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
|
||||
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
self.extract_on = cfg.MODEL.BUA.EXTRACT_FEATS
|
||||
self.extractor = cfg.MODEL.BUA.EXTRACTOR
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, batched_inputs):
|
||||
"""
|
||||
Args:
|
||||
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
|
||||
Each item in the list contains the inputs for one image.
|
||||
For now, each item in the list is a dict that contains:
|
||||
|
||||
* image: Tensor, image in (C, H, W) format.
|
||||
* instances (optional): groundtruth :class:`Instances`
|
||||
* proposals (optional): :class:`Instances`, precomputed proposals.
|
||||
|
||||
Other information that's included in the original dicts, such as:
|
||||
|
||||
* "height", "width" (int): the output resolution of the model, used in inference.
|
||||
See :meth:`postprocess` for details.
|
||||
|
||||
Returns:
|
||||
list[dict]:
|
||||
Each dict is the output for one input image.
|
||||
The dict contains one key "instances" whose value is a :class:`Instances`.
|
||||
The :class:`Instances` object has the following keys:
|
||||
"pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
|
||||
"""
|
||||
if not self.training:
|
||||
return self.inference(batched_inputs)
|
||||
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
if "instances" in batched_inputs[0]:
|
||||
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
||||
elif "targets" in batched_inputs[0]:
|
||||
log_first_n(
|
||||
logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
|
||||
)
|
||||
gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
|
||||
else:
|
||||
gt_instances = None
|
||||
|
||||
features = self.backbone(images.tensor)
|
||||
|
||||
if self.resnet_version == 2:
|
||||
for f in features:
|
||||
out = self.roi_heads.res5[0].norm(features[f])
|
||||
features[f] = F.relu_(out)
|
||||
|
||||
if self.proposal_generator:
|
||||
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
|
||||
else:
|
||||
assert "proposals" in batched_inputs[0]
|
||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||
proposal_losses = {}
|
||||
|
||||
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
|
||||
|
||||
losses = {}
|
||||
losses.update(detector_losses)
|
||||
losses.update(proposal_losses)
|
||||
return losses
|
||||
|
||||
def inference(self, batched_inputs, detected_instances=None, do_postprocess=True):
|
||||
"""
|
||||
Run inference on the given inputs.
|
||||
|
||||
Args:
|
||||
batched_inputs (list[dict]): same as in :meth:`forward`
|
||||
detected_instances (None or list[Instances]): if not None, it
|
||||
contains an `Instances` object per image. The `Instances`
|
||||
object contains "pred_boxes" and "pred_classes" which are
|
||||
known boxes in the image.
|
||||
The inference will then skip the detection of bounding boxes,
|
||||
and only predict other per-ROI outputs.
|
||||
do_postprocess (bool): whether to apply post-processing on the outputs.
|
||||
|
||||
Returns:
|
||||
same as in :meth:`forward`.
|
||||
"""
|
||||
assert not self.training
|
||||
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
features = self.backbone(images.tensor)
|
||||
|
||||
if self.resnet_version == 2:
|
||||
for f in features:
|
||||
out = self.roi_heads.res5[0].norm(features[f])
|
||||
features[f] = F.relu_(out)
|
||||
|
||||
if detected_instances is None:
|
||||
if self.proposal_generator:
|
||||
proposals, _ = self.proposal_generator(images, features, None)
|
||||
else:
|
||||
assert "proposals" in batched_inputs[0]
|
||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||
|
||||
if self.extract_on:
|
||||
return self.roi_heads(images, features, proposals, None)
|
||||
else:
|
||||
results, _ = self.roi_heads(images, features, proposals, None)
|
||||
else:
|
||||
detected_instances = [x.to(self.device) for x in detected_instances]
|
||||
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
|
||||
|
||||
if do_postprocess:
|
||||
processed_results = []
|
||||
for results_per_image, input_per_image, image_size in zip(
|
||||
results, batched_inputs, images.image_sizes
|
||||
):
|
||||
height = input_per_image.get("height", image_size[0])
|
||||
width = input_per_image.get("width", image_size[1])
|
||||
if not self.bua_caffe:
|
||||
results_per_image = detector_postprocess(results_per_image, height, width)
|
||||
processed_results.append({"instances": results_per_image})
|
||||
return processed_results
|
||||
else:
|
||||
return results
|
||||
|
||||
def preprocess_image(self, batched_inputs):
|
||||
"""
|
||||
Normalize, pad and batch the input images.
|
||||
"""
|
||||
images = [x["image"].to(self.device) for x in batched_inputs]
|
||||
image_scales = [x["im_scale"] for x in batched_inputs]
|
||||
images = ImageList.from_tensors(images, self.backbone.size_divisibility)
|
||||
images.image_scales = image_scales
|
||||
return images
|
|
@ -0,0 +1,469 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.utils.events import get_event_storage
|
||||
from detectron2.modeling import ROI_HEADS_REGISTRY, ROIHeads
|
||||
from detectron2.structures import Boxes, Instances, pairwise_iou
|
||||
from detectron2.modeling.sampling import subsample_labels
|
||||
from detectron2.modeling.poolers import ROIPooler
|
||||
from detectron2.modeling.backbone.resnet import BottleneckBlock
|
||||
from detectron2.modeling.proposal_generator.proposal_utils import add_ground_truth_to_proposals
|
||||
from detectron2.layers import get_norm, BatchNorm2d
|
||||
|
||||
from .fast_rcnn import BUACaffeFastRCNNOutputs, BUACaffeFastRCNNOutputLayers, BUADetection2FastRCNNOutputs, BUADetectron2FastRCNNOutputLayers
|
||||
from .box_regression import BUABox2BoxTransform
|
||||
from .backbone import BottleneckBlockv2
|
||||
|
||||
def make_stage(block_class, num_blocks, first_stride, **kwargs):
|
||||
"""
|
||||
Create a resnet stage by creating many blocks.
|
||||
Args:
|
||||
block_class (class): a subclass of ResNetBlockBase
|
||||
num_blocks (int):
|
||||
first_stride (int): the stride of the first block. The other blocks will have stride=1.
|
||||
A `stride` argument will be passed to the block constructor.
|
||||
kwargs: other arguments passed to the block constructor.
|
||||
|
||||
Returns:
|
||||
        list[nn.Module]: a list of block modules.
|
||||
"""
|
||||
blocks = []
|
||||
for i in range(num_blocks):
|
||||
if kwargs["dilation"] > 1:
|
||||
first_stride = 1
|
||||
blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs))
|
||||
kwargs["in_channels"] = kwargs["out_channels"]
|
||||
return blocks
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class BUACaffeRes5ROIHeads(ROIHeads):
|
||||
"""
|
||||
The ROIHeads in a typical "C4" R-CNN model, where
|
||||
the box and mask head share the cropping and
|
||||
the per-region feature computation by a Res5 block.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
# super().__init__(cfg, input_shape)
|
||||
super().__init__(cfg)
|
||||
|
||||
self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
|
||||
self.feature_strides = {k: v.stride for k, v in input_shape.items()}
|
||||
self.cls_agnostic_bbox_reg = cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
|
||||
self.smooth_l1_beta = cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA
|
||||
assert len(self.in_features) == 1
|
||||
|
||||
# fmt: off
|
||||
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
|
||||
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
|
||||
pooler_scales = (1.0 / self.feature_strides[self.in_features[0]], )
|
||||
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
||||
self.resnet_version = cfg.MODEL.BUA.RESNET_VERSION
|
||||
self.attr_on = cfg.MODEL.BUA.ATTRIBUTE_ON
|
||||
self.extract_on = cfg.MODEL.BUA.EXTRACT_FEATS
|
||||
self.num_attr_classes = cfg.MODEL.BUA.ATTRIBUTE.NUM_CLASSES
|
||||
self.extractor_mode = cfg.MODEL.BUA.EXTRACTOR.MODE
|
||||
|
||||
self.test_score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
|
||||
self.test_nms_thresh = cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST
|
||||
self.test_detections_per_img = cfg.TEST.DETECTIONS_PER_IMAGE
|
||||
|
||||
self.pooler = ROIPooler(
|
||||
output_size=pooler_resolution,
|
||||
scales=pooler_scales,
|
||||
sampling_ratio=sampling_ratio,
|
||||
pooler_type=pooler_type,
|
||||
)
|
||||
|
||||
self.box2box_transform = BUABox2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)
|
||||
|
||||
self.res5, out_channels = self._build_res5_block(cfg)
|
||||
if self.resnet_version == 2:
|
||||
self.res5_bn = BatchNorm2d(out_channels, eps=2e-5)
|
||||
self.box_predictor = BUACaffeFastRCNNOutputLayers(
|
||||
out_channels, self.num_classes, self.cls_agnostic_bbox_reg, attr_on=self.attr_on, num_attr_classes=self.num_attr_classes
|
||||
)
|
||||
|
||||
def _build_res5_block(self, cfg):
|
||||
# fmt: off
|
||||
stage_channel_factor = 2 ** 3 # res5 is 8x res2
|
||||
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
||||
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
||||
bottleneck_channels = num_groups * width_per_group * stage_channel_factor
|
||||
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
|
||||
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
||||
norm = cfg.MODEL.RESNETS.NORM
|
||||
dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
||||
assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \
|
||||
"Deformable conv is not yet supported in res5 head."
|
||||
# fmt: on
|
||||
blocks = make_stage(
|
||||
BottleneckBlock if self.resnet_version == 1 else BottleneckBlockv2,
|
||||
3,
|
||||
first_stride=2,
|
||||
in_channels=out_channels // 2,
|
||||
bottleneck_channels=bottleneck_channels,
|
||||
out_channels=out_channels,
|
||||
num_groups=num_groups,
|
||||
norm=norm,
|
||||
stride_in_1x1=stride_in_1x1,
|
||||
dilation=dilation,
|
||||
)
|
||||
return nn.Sequential(*blocks), out_channels
|
||||
|
||||
def _shared_roi_transform(self, features, boxes):
|
||||
x = self.pooler(features, boxes)
|
||||
if self.resnet_version == 2:
|
||||
out = self.res5[0].conv1(x)
|
||||
out = self.res5[0].conv2(out)
|
||||
out = self.res5[0].conv3(out)
|
||||
if self.res5[0].shortcut is not None:
|
||||
shortcut = self.res5[0].shortcut(x)
|
||||
else:
|
||||
shortcut = x
|
||||
out += shortcut
|
||||
out = self.res5[1:](out)
|
||||
return F.relu_(self.res5_bn(out))
|
||||
return self.res5(x)
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`ROIHeads.forward`.
|
||||
"""
|
||||
image_scales = images.image_scales
|
||||
del images
|
||||
|
||||
if self.training:
|
||||
proposals = self.label_and_sample_proposals(proposals, targets)
|
||||
del targets
|
||||
|
||||
proposal_boxes = [x.proposal_boxes for x in proposals]
|
||||
box_features = self._shared_roi_transform(
|
||||
[features[f] for f in self.in_features], proposal_boxes
|
||||
)
|
||||
feature_pooled = box_features.mean(dim=[2, 3]) # pooled to 1x1
|
||||
if self.attr_on:
|
||||
pred_class_logits, pred_proposal_deltas, attr_scores = self.box_predictor(feature_pooled, proposals)
|
||||
else:
|
||||
pred_class_logits, pred_proposal_deltas = self.box_predictor(feature_pooled, proposals)
|
||||
if not self.extract_on:
|
||||
del feature_pooled
|
||||
|
||||
outputs = BUACaffeFastRCNNOutputs(
|
||||
self.box2box_transform,
|
||||
pred_class_logits,
|
||||
pred_proposal_deltas,
|
||||
proposals,
|
||||
self.smooth_l1_beta,
|
||||
image_scales
|
||||
)
|
||||
|
||||
if self.training:
|
||||
del features
|
||||
losses = outputs.losses()
|
||||
return [], losses
|
||||
else:
|
||||
if self.extract_on:
|
||||
num_preds_per_image = [len(p) for p in proposals]
|
||||
if self.extractor_mode == 1 or self.extractor_mode == 3:
|
||||
if self.attr_on:
|
||||
return proposal_boxes, outputs.predict_probs(), feature_pooled.split(num_preds_per_image, dim=0), attr_scores.split(num_preds_per_image, dim=0)
|
||||
else:
|
||||
return proposal_boxes, outputs.predict_probs(), feature_pooled.split(num_preds_per_image, dim=0)
|
||||
elif self.extractor_mode == 2:
|
||||
return outputs.predict_boxes(), outputs.predict_probs()
|
||||
else:
|
||||
                    raise ValueError('BUA.EXTRACTOR.MODE ERROR')
|
||||
pred_instances, _ = outputs.inference(
|
||||
self.test_score_thresh, self.test_nms_thresh, self.test_detections_per_img
|
||||
)
|
||||
return pred_instances, {}
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class BUADetectron2Res5ROIHeads(ROIHeads):
|
||||
"""
|
||||
The ROIHeads in a typical "C4" R-CNN model, where
|
||||
the box and mask head share the cropping and
|
||||
the per-region feature computation by a Res5 block.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
# super().__init__(cfg, input_shape)
|
||||
super().__init__(cfg)
|
||||
|
||||
self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
|
||||
self.feature_strides = {k: v.stride for k, v in input_shape.items()}
|
||||
self.cls_agnostic_bbox_reg = cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
|
||||
self.smooth_l1_beta = cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA
|
||||
self.positive_sample_fraction = cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
|
||||
assert len(self.in_features) == 1
|
||||
|
||||
# fmt: off
|
||||
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
|
||||
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
|
||||
pooler_scales = (1.0 / self.feature_strides[self.in_features[0]], )
|
||||
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
||||
self.resnet_version = cfg.MODEL.BUA.RESNET_VERSION
|
||||
self.attr_on = cfg.MODEL.BUA.ATTRIBUTE_ON
|
||||
self.extract_on = cfg.MODEL.BUA.EXTRACT_FEATS
|
||||
self.num_attr_classes = cfg.MODEL.BUA.ATTRIBUTE.NUM_CLASSES
|
||||
self.extractor_mode = cfg.MODEL.BUA.EXTRACTOR.MODE
|
||||
|
||||
self.test_score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
|
||||
self.test_nms_thresh = cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST
|
||||
self.test_detections_per_img = cfg.TEST.DETECTIONS_PER_IMAGE
|
||||
|
||||
self.pooler = ROIPooler(
|
||||
output_size=pooler_resolution,
|
||||
scales=pooler_scales,
|
||||
sampling_ratio=sampling_ratio,
|
||||
pooler_type=pooler_type,
|
||||
)
|
||||
|
||||
self.box2box_transform = BUABox2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)
|
||||
|
||||
self.res5, out_channels = self._build_res5_block(cfg)
|
||||
if self.resnet_version == 2:
|
||||
self.res5_bn = BatchNorm2d(out_channels, eps=2e-5)
|
||||
self.box_predictor = BUADetectron2FastRCNNOutputLayers(
|
||||
out_channels, self.num_classes, self.cls_agnostic_bbox_reg, \
|
||||
attr_on=self.attr_on, num_attr_classes=self.num_attr_classes
|
||||
)
|
||||
|
||||
def _sample_proposals(self, matched_idxs, matched_labels, gt_classes, gt_attributes):
|
||||
"""
|
||||
Based on the matching between N proposals and M groundtruth,
|
||||
sample the proposals and set their classification labels.
|
||||
|
||||
Args:
|
||||
matched_idxs (Tensor): a vector of length N, each is the best-matched
|
||||
gt index in [0, M) for each proposal.
|
||||
matched_labels (Tensor): a vector of length N, the matcher's label
|
||||
(one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
|
||||
gt_classes (Tensor): a vector of length M.
|
||||
|
||||
Returns:
|
||||
Tensor: a vector of indices of sampled proposals. Each is in [0, N).
|
||||
Tensor: a vector of the same length, the classification label for
|
||||
each sampled proposal. Each sample is labeled as either a category in
|
||||
[0, num_classes) or the background (num_classes).
|
||||
"""
|
||||
has_gt = gt_classes.numel() > 0
|
||||
# Get the corresponding GT for each proposal
|
||||
if has_gt:
|
||||
gt_classes = gt_classes[matched_idxs]
|
||||
gt_attributes = gt_attributes[matched_idxs, :]
|
||||
# Label unmatched proposals (0 label from matcher) as background (label=num_classes)
|
||||
gt_classes[matched_labels == 0] = self.num_classes
|
||||
# Label ignore proposals (-1 label)
|
||||
gt_classes[matched_labels == -1] = -1
|
||||
else:
|
||||
gt_classes = torch.zeros_like(matched_idxs) + self.num_classes
|
||||
            gt_attributes = -torch.ones((len(matched_idxs), 16), dtype=torch.int64).cuda()  # placeholder attributes when there are no gt boxes
|
||||
|
||||
sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
|
||||
gt_classes, self.batch_size_per_image, self.positive_sample_fraction, self.num_classes
|
||||
)
|
||||
|
||||
sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
|
||||
return sampled_idxs, gt_classes[sampled_idxs], gt_attributes[sampled_idxs]
|
||||
|
||||
def _build_res5_block(self, cfg):
|
||||
# fmt: off
|
||||
stage_channel_factor = 2 ** 3 # res5 is 8x res2
|
||||
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
||||
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
||||
bottleneck_channels = num_groups * width_per_group * stage_channel_factor
|
||||
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
|
||||
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
||||
norm = cfg.MODEL.RESNETS.NORM
|
||||
dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
||||
assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \
|
||||
"Deformable conv is not yet supported in res5 head."
|
||||
# fmt: on
|
||||
|
||||
blocks = make_stage(
|
||||
BottleneckBlock if self.resnet_version == 1 else BottleneckBlockv2,
|
||||
3,
|
||||
first_stride=2,
|
||||
in_channels=out_channels // 2,
|
||||
bottleneck_channels=bottleneck_channels,
|
||||
out_channels=out_channels,
|
||||
num_groups=num_groups,
|
||||
norm=norm,
|
||||
stride_in_1x1=stride_in_1x1,
|
||||
dilation=dilation,
|
||||
)
|
||||
return nn.Sequential(*blocks), out_channels
|
||||
|
||||
def _shared_roi_transform(self, features, boxes):
|
||||
x = self.pooler(features, boxes)
|
||||
if self.resnet_version == 2:
|
||||
out = self.res5[0].conv1(x)
|
||||
out = self.res5[0].conv2(out)
|
||||
out = self.res5[0].conv3(out)
|
||||
if self.res5[0].shortcut is not None:
|
||||
shortcut = self.res5[0].shortcut(x)
|
||||
else:
|
||||
shortcut = x
|
||||
out += shortcut
|
||||
out = self.res5[1:](out)
|
||||
return F.relu_(self.res5_bn(out))
|
||||
return self.res5(x)
|
||||
|
||||
@torch.no_grad()
|
||||
def label_and_sample_proposals(self, proposals, targets):
|
||||
"""
|
||||
Prepare some proposals to be used to train the ROI heads.
|
||||
It performs box matching between `proposals` and `targets`, and assigns
|
||||
training labels to the proposals.
|
||||
It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
|
||||
boxes, with a fraction of positives that is no larger than
|
||||
``self.positive_sample_fraction``.
|
||||
|
||||
Args:
|
||||
See :meth:`ROIHeads.forward`
|
||||
|
||||
Returns:
|
||||
list[Instances]:
|
||||
length `N` list of `Instances`s containing the proposals
|
||||
sampled for training. Each `Instances` has the following fields:
|
||||
|
||||
- proposal_boxes: the proposal boxes
|
||||
- gt_boxes: the ground-truth box that the proposal is assigned to
|
||||
(this is only meaningful if the proposal has a label > 0; if label = 0
|
||||
then the ground-truth box is random)
|
||||
|
||||
Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
|
||||
"""
|
||||
gt_boxes = [x.gt_boxes for x in targets]
|
||||
# Augment proposals with ground-truth boxes.
|
||||
# In the case of learned proposals (e.g., RPN), when training starts
|
||||
# the proposals will be low quality due to random initialization.
|
||||
# It's possible that none of these initial
|
||||
# proposals have high enough overlap with the gt objects to be used
|
||||
# as positive examples for the second stage components (box head,
|
||||
# cls head, mask head). Adding the gt boxes to the set of proposals
|
||||
# ensures that the second stage components will have some positive
|
||||
# examples from the start of training. For RPN, this augmentation improves
|
||||
# convergence and empirically improves box AP on COCO by about 0.5
|
||||
# points (under one tested configuration).
|
||||
if self.proposal_append_gt:
|
||||
proposals = add_ground_truth_to_proposals(gt_boxes, proposals)
|
||||
|
||||
proposals_with_gt = []
|
||||
|
||||
num_fg_samples = []
|
||||
num_bg_samples = []
|
||||
for proposals_per_image, targets_per_image in zip(proposals, targets):
|
||||
has_gt = len(targets_per_image) > 0
|
||||
match_quality_matrix = pairwise_iou(
|
||||
targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
|
||||
)
|
||||
matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
|
||||
sampled_idxs, gt_classes, gt_attributes = self._sample_proposals(
|
||||
matched_idxs, matched_labels, targets_per_image.gt_classes, targets_per_image.gt_attributes
|
||||
)
|
||||
|
||||
# Set target attributes of the sampled proposals:
|
||||
proposals_per_image = proposals_per_image[sampled_idxs]
|
||||
proposals_per_image.gt_classes = gt_classes
|
||||
proposals_per_image.gt_attributes = gt_attributes
|
||||
|
||||
# We index all the attributes of targets that start with "gt_"
|
||||
# and have not been added to proposals yet (="gt_classes").
|
||||
if has_gt:
|
||||
sampled_targets = matched_idxs[sampled_idxs]
|
||||
# NOTE: here the indexing waste some compute, because heads
|
||||
# like masks, keypoints, etc, will filter the proposals again,
|
||||
# (by foreground/background, or number of keypoints in the image, etc)
|
||||
# so we essentially index the data twice.
|
||||
for (trg_name, trg_value) in targets_per_image.get_fields().items():
|
||||
if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name):
|
||||
proposals_per_image.set(trg_name, trg_value[sampled_targets])
|
||||
else:
|
||||
gt_boxes = Boxes(
|
||||
targets_per_image.gt_boxes.tensor.new_zeros((len(sampled_idxs), 4))
|
||||
)
|
||||
proposals_per_image.gt_boxes = gt_boxes
|
||||
|
||||
num_bg_samples.append((gt_classes == self.num_classes).sum().item())
|
||||
num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
|
||||
proposals_with_gt.append(proposals_per_image)
|
||||
|
||||
# Log the number of fg/bg samples that are selected for training ROI heads
|
||||
storage = get_event_storage()
|
||||
storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
|
||||
storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
|
||||
|
||||
return proposals_with_gt
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`ROIHeads.forward`.
|
||||
"""
|
||||
# image_scales = images.image_scales
|
||||
del images
|
||||
|
||||
if self.training:
|
||||
proposals = self.label_and_sample_proposals(proposals, targets)
|
||||
del targets
|
||||
|
||||
proposal_boxes = [x.proposal_boxes for x in proposals]
|
||||
box_features = self._shared_roi_transform(
|
||||
[features[f] for f in self.in_features], proposal_boxes
|
||||
)
|
||||
feature_pooled = box_features.mean(dim=[2, 3]) # pooled to 1x1
|
||||
if self.attr_on:
|
||||
pred_class_logits, pred_proposal_deltas, pred_attribute_logits, gt_attributes = self.box_predictor(feature_pooled, proposals)
|
||||
else:
|
||||
pred_class_logits, pred_proposal_deltas = self.box_predictor(feature_pooled, proposals)
|
||||
if not self.extract_on:
|
||||
del feature_pooled
|
||||
|
||||
if self.attr_on:
|
||||
outputs = BUADetection2FastRCNNOutputs(
|
||||
self.box2box_transform,
|
||||
pred_class_logits,
|
||||
pred_proposal_deltas,
|
||||
proposals,
|
||||
self.smooth_l1_beta,
|
||||
self.attr_on,
|
||||
pred_attribute_logits=pred_attribute_logits,
|
||||
num_attr_classes=self.num_attr_classes,
|
||||
gt_attributes=gt_attributes,
|
||||
)
|
||||
else:
|
||||
outputs = BUADetection2FastRCNNOutputs(
|
||||
self.box2box_transform,
|
||||
pred_class_logits,
|
||||
pred_proposal_deltas,
|
||||
proposals,
|
||||
self.smooth_l1_beta,
|
||||
self.attr_on,
|
||||
)
|
||||
|
||||
if self.training:
|
||||
del features
|
||||
losses = outputs.losses()
|
||||
return [], losses
|
||||
else:
|
||||
if self.extract_on:
|
||||
num_preds_per_image = [len(p) for p in proposals]
|
||||
if self.extractor_mode == 1 or self.extractor_mode == 3:
|
||||
if self.attr_on:
|
||||
return proposal_boxes, outputs.predict_probs(), feature_pooled.split(num_preds_per_image, dim=0), F.softmax(pred_attribute_logits, dim=-1).split(num_preds_per_image, dim=0)
|
||||
else:
|
||||
return proposal_boxes, outputs.predict_probs(), feature_pooled.split(num_preds_per_image, dim=0)
|
||||
elif self.extractor_mode == 2:
|
||||
return outputs.predict_boxes(), outputs.predict_probs()
|
||||
else:
|
||||
                    raise ValueError('BUA.EXTRACTOR.MODE ERROR')
|
||||
pred_instances, _ = outputs.inference(
|
||||
self.test_score_thresh, self.test_nms_thresh, self.test_detections_per_img
|
||||
)
|
||||
return pred_instances, {}
|
|
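`label_and_sample_proposals` above relies on subsampling matched proposals so that at most a fixed fraction are positives. A plain-torch sketch of that sampling idea (names and the default values are illustrative, not the repo's `subsample_labels`):

```python
# Illustrative fg/bg proposal sampling with a capped positive fraction.
import torch

def sample_fg_bg(gt_classes, num_samples=512, positive_fraction=0.25, bg_label=80):
    # foreground: a real class label; background: bg_label; -1 means ignore
    fg = torch.nonzero((gt_classes != -1) & (gt_classes != bg_label)).squeeze(1)
    bg = torch.nonzero(gt_classes == bg_label).squeeze(1)
    num_fg = min(int(num_samples * positive_fraction), fg.numel())
    num_bg = min(num_samples - num_fg, bg.numel())
    fg = fg[torch.randperm(fg.numel())[:num_fg]]
    bg = bg[torch.randperm(bg.numel())[:num_bg]]
    return torch.cat([fg, bg])          # indices of the sampled proposals
```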
@ -0,0 +1,179 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from detectron2.modeling import RPN_HEAD_REGISTRY
|
||||
from detectron2.layers import ShapeSpec
|
||||
|
||||
from detectron2.modeling.proposal_generator import build_rpn_head
|
||||
from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY
|
||||
from detectron2.modeling.anchor_generator import build_anchor_generator
|
||||
from .box_regression import BUABox2BoxTransform
|
||||
from detectron2.modeling.matcher import Matcher
|
||||
from .rpn_outputs import BUARPNOutputs, find_top_bua_rpn_proposals
|
||||
|
||||
import copy
|
||||
|
||||
@RPN_HEAD_REGISTRY.register()
|
||||
class StandardBUARPNHead(nn.Module):
|
||||
"""
|
||||
RPN classification and regression heads. Uses a 3x3 conv to produce a shared
|
||||
hidden state from which one 1x1 conv predicts objectness logits for each anchor
|
||||
and a second 1x1 conv predicts bounding-box deltas specifying how to deform
|
||||
each anchor into an object proposal.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape: List[ShapeSpec]):
|
||||
super().__init__()
|
||||
|
||||
# Standard RPN is shared across levels:
|
||||
out_channels = cfg.MODEL.BUA.RPN.CONV_OUT_CHANNELS
|
||||
|
||||
in_channels = [s.channels for s in input_shape]
|
||||
assert len(set(in_channels)) == 1, "Each level must have the same channel!"
|
||||
in_channels = in_channels[0]
|
||||
|
||||
# RPNHead should take the same input as anchor generator
|
||||
# NOTE: it assumes that creating an anchor generator does not have unwanted side effect.
|
||||
anchor_generator = build_anchor_generator(cfg, input_shape)
|
||||
num_cell_anchors = anchor_generator.num_cell_anchors
|
||||
box_dim = anchor_generator.box_dim
|
||||
assert (
|
||||
len(set(num_cell_anchors)) == 1
|
||||
), "Each level must have the same number of cell anchors"
|
||||
num_cell_anchors = num_cell_anchors[0]
|
||||
|
||||
# 3x3 conv for the hidden representation
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
# 1x1 conv for predicting objectness logits
|
||||
self.objectness_logits = nn.Conv2d(out_channels, num_cell_anchors * 2, kernel_size=1, stride=1)
|
||||
# 1x1 conv for predicting box2box transform deltas
|
||||
self.anchor_deltas = nn.Conv2d(
|
||||
out_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1
|
||||
)
|
||||
|
||||
for l in [self.conv, self.objectness_logits, self.anchor_deltas]:
|
||||
nn.init.normal_(l.weight, std=0.01)
|
||||
nn.init.constant_(l.bias, 0)
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Args:
|
||||
features (list[Tensor]): list of feature maps
|
||||
"""
|
||||
pred_objectness_logits = []
|
||||
pred_anchor_deltas = []
|
||||
for x in features:
|
||||
t = F.relu(self.conv(x))
|
||||
pred_objectness_logits.append(self.objectness_logits(t))
|
||||
pred_anchor_deltas.append(self.anchor_deltas(t))
|
||||
return pred_objectness_logits, pred_anchor_deltas
|
||||
|
||||
@PROPOSAL_GENERATOR_REGISTRY.register()
|
||||
class BUARPN(nn.Module):
|
||||
"""
|
||||
Region Proposal Network, introduced by the Faster R-CNN paper.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
|
||||
super().__init__()
|
||||
|
||||
# fmt: off
|
||||
self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
|
||||
self.in_features = cfg.MODEL.RPN.IN_FEATURES
|
||||
self.nms_thresh = cfg.MODEL.RPN.NMS_THRESH
|
||||
self.batch_size_per_image = cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE
|
||||
self.positive_fraction = cfg.MODEL.RPN.POSITIVE_FRACTION
|
||||
self.smooth_l1_beta = cfg.MODEL.RPN.SMOOTH_L1_BETA
|
||||
self.loss_weight = cfg.MODEL.RPN.LOSS_WEIGHT
|
||||
# fmt: on
|
||||
|
||||
# Map from self.training state to train/test settings
|
||||
self.pre_nms_topk = {
|
||||
True: cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN,
|
||||
False: cfg.MODEL.RPN.PRE_NMS_TOPK_TEST,
|
||||
}
|
||||
self.post_nms_topk = {
|
||||
True: cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN,
|
||||
False: cfg.MODEL.RPN.POST_NMS_TOPK_TEST,
|
||||
}
|
||||
self.boundary_threshold = cfg.MODEL.RPN.BOUNDARY_THRESH
|
||||
|
||||
self.anchor_generator = build_anchor_generator(
|
||||
cfg, [input_shape[f] for f in self.in_features]
|
||||
)
|
||||
self.box2box_transform = BUABox2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
|
||||
self.anchor_matcher = Matcher(
|
||||
cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True
|
||||
)
|
||||
self.rpn_head = build_rpn_head(cfg, [input_shape[f] for f in self.in_features])
|
||||
|
||||
def forward(self, images, features, gt_instances=None):
|
||||
"""
|
||||
Args:
|
||||
images (ImageList): input images of length `N`
|
||||
features (dict[str: Tensor]): input data as a mapping from feature
|
||||
map name to tensor. Axis 0 represents the number of images `N` in
|
||||
the input data; axes 1-3 are channels, height, and width, which may
|
||||
vary between feature maps (e.g., if a feature pyramid is used).
|
||||
gt_instances (list[Instances], optional): a length `N` list of `Instances`s.
|
||||
Each `Instances` stores ground-truth instances for the corresponding image.
|
||||
|
||||
Returns:
|
||||
proposals: list[Instances] or None
|
||||
loss: dict[Tensor]
|
||||
"""
|
||||
gt_boxes = [x.gt_boxes for x in gt_instances] if gt_instances is not None else None
|
||||
del gt_instances
|
||||
features = [features[f] for f in self.in_features]
|
||||
pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
|
||||
anchors_in_image = self.anchor_generator(features)
|
||||
anchors = [copy.deepcopy(anchors_in_image) for _ in range(len(features[0]))]
|
||||
# TODO: The anchors only depend on the feature map shape; there's probably
|
||||
# an opportunity for some optimizations (e.g., caching anchors).
|
||||
outputs = BUARPNOutputs(
|
||||
self.box2box_transform,
|
||||
self.anchor_matcher,
|
||||
self.batch_size_per_image,
|
||||
self.positive_fraction,
|
||||
images,
|
||||
pred_objectness_logits,
|
||||
pred_anchor_deltas,
|
||||
anchors,
|
||||
self.boundary_threshold,
|
||||
gt_boxes,
|
||||
self.smooth_l1_beta,
|
||||
)
|
||||
|
||||
if self.training:
|
||||
losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
|
||||
else:
|
||||
losses = {}
|
||||
|
||||
with torch.no_grad():
|
||||
# Find the top proposals by applying NMS and removing boxes that
|
||||
# are too small. The proposals are treated as fixed for approximate
|
||||
# joint training with roi heads. This approach ignores the derivative
|
||||
# w.r.t. the proposal boxes’ coordinates that are also network
|
||||
# responses, so is approximate.
|
||||
proposals = find_top_bua_rpn_proposals(
|
||||
outputs.predict_proposals(),
|
||||
outputs.predict_objectness_logits(),
|
||||
images,
|
||||
self.nms_thresh,
|
||||
self.pre_nms_topk[self.training],
|
||||
self.post_nms_topk[self.training],
|
||||
self.min_box_side_len,
|
||||
self.training,
|
||||
)
|
||||
# For RPN-only models, the proposals are the final output and we return them in
|
||||
# high-to-low confidence order.
|
||||
# For end-to-end models, the RPN proposals are an intermediate state
|
||||
# and this sorting is actually not needed. But the cost is negligible.
|
||||
# inds = [p.objectness_logits.sort(descending=True)[1] for p in proposals]
|
||||
# proposals = [p[ind] for p, ind in zip(proposals, inds)]
|
||||
|
||||
return proposals, losses
|
|
@ -0,0 +1,404 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import itertools
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fvcore.nn import smooth_l1_loss
|
||||
|
||||
from detectron2.layers import cat
|
||||
from detectron2.structures import Instances, pairwise_iou
|
||||
from detectron2.utils.events import get_event_storage
|
||||
|
||||
from detectron2.modeling.sampling import subsample_labels
|
||||
|
||||
from .box_regression import BUABoxes
|
||||
from .layers.nms import batched_nms
|
||||
|
||||
def find_top_bua_rpn_proposals(
|
||||
proposals,
|
||||
pred_objectness_logits,
|
||||
images,
|
||||
nms_thresh,
|
||||
pre_nms_topk,
|
||||
post_nms_topk,
|
||||
min_box_side_len,
|
||||
training,
|
||||
):
|
||||
"""
|
||||
For each feature map, select the `pre_nms_topk` highest scoring proposals,
|
||||
apply NMS, clip proposals, and remove small boxes. Return the `post_nms_topk`
|
||||
highest scoring proposals among all the feature maps if `training` is True,
|
||||
otherwise, returns the highest `post_nms_topk` scoring proposals for each
|
||||
feature map.
|
||||
|
||||
Args:
|
||||
proposals (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A, 4).
|
||||
All proposal predictions on the feature maps.
|
||||
pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A).
|
||||
images (ImageList): Input images as an :class:`ImageList`.
|
||||
nms_thresh (float): IoU threshold to use for NMS
|
||||
pre_nms_topk (int): number of top k scoring proposals to keep before applying NMS.
|
||||
When RPN is run on multiple feature maps (as in FPN) this number is per
|
||||
feature map.
|
||||
post_nms_topk (int): number of top k scoring proposals to keep after applying NMS.
|
||||
When RPN is run on multiple feature maps (as in FPN) this number is total,
|
||||
over all feature maps.
|
||||
min_box_side_len (float): minimum proposal box side length in pixels (absolute units
|
||||
wrt input images).
|
||||
training (bool): True if proposals are to be used in training, otherwise False.
|
||||
This arg exists only to support a legacy bug; look for the "NB: Legacy bug ..."
|
||||
comment.
|
||||
|
||||
Returns:
|
||||
proposals (list[Instances]): list of N Instances. The i-th Instances
|
||||
stores post_nms_topk object proposals for image i.
|
||||
"""
|
||||
image_sizes = images.image_sizes # in (h, w) order
|
||||
image_scales = images.image_scales
|
||||
device = proposals[0].device
|
||||
|
||||
# 1. Concat all levels together
|
||||
all_scores = []
|
||||
all_proposals = []
|
||||
level_ids = []
|
||||
for level_id, proposals_i, logits_i in zip(
|
||||
itertools.count(), proposals, pred_objectness_logits
|
||||
):
|
||||
Hi_Wi_A = logits_i.shape[1]
|
||||
all_proposals.append(proposals_i)
|
||||
all_scores.append(logits_i)
|
||||
level_ids.append(torch.full((Hi_Wi_A,), level_id, dtype=torch.int64, device=device))
|
||||
|
||||
all_scores = cat(all_scores, dim=1)
|
||||
all_proposals = cat(all_proposals, dim=1)
|
||||
level_ids = cat(level_ids, dim=0)
|
||||
|
||||
    # 2. For each image, keep the pre_nms_topk proposals, run per-level NMS, and keep the post_nms_topk results.
|
||||
results = []
|
||||
for n, image_size in enumerate(image_sizes):
|
||||
boxes = BUABoxes(all_proposals[n])
|
||||
scores_per_img = all_scores[n]
|
||||
boxes.clip(image_size)
|
||||
keep = boxes.filter_boxes()
|
||||
boxes = boxes[keep]
|
||||
scores_per_img = scores_per_img[keep]
|
||||
lvl = level_ids[keep]
|
||||
|
||||
# filter empty boxes
|
||||
keep = boxes.nonempty(threshold=min_box_side_len*image_scales[n])
|
||||
if keep.sum().item() != len(boxes):
|
||||
boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep]
|
||||
|
||||
# choose pre_nms_topk proposal
|
||||
Hi_Wi_A = scores_per_img.shape[0]
|
||||
num_proposals_i = min(pre_nms_topk, Hi_Wi_A)
|
||||
|
||||
scores_per_img, idx = scores_per_img.sort(descending=True, dim=0)
|
||||
topk_scores_i = scores_per_img[:num_proposals_i]
|
||||
topk_idx = idx[:num_proposals_i]
|
||||
topk_boxes_i = boxes[topk_idx, :]
|
||||
lvl_i = lvl[topk_idx]
|
||||
|
||||
keep = batched_nms(topk_boxes_i.tensor, topk_scores_i, lvl_i, nms_thresh)
|
||||
# In Detectron1, there was different behavior during training vs. testing.
|
||||
# (https://github.com/facebookresearch/Detectron/issues/459)
|
||||
# During training, topk is over the proposals from *all* images in the training batch.
|
||||
# During testing, it is over the proposals for each image separately.
|
||||
# As a result, the training behavior becomes batch-dependent,
|
||||
        # and the configuration "POST_NMS_TOPK_TRAIN" ends up relying on the batch size.
|
||||
# This bug is addressed in Detectron2 to make the behavior independent of batch size.
|
||||
keep = keep[:post_nms_topk]
|
||||
|
||||
res = Instances(image_size)
|
||||
res.proposal_boxes = topk_boxes_i[keep]
|
||||
res.objectness_logits = topk_scores_i[keep]
|
||||
results.append(res)
|
||||
return results
|
||||
|
||||
class BUARPNOutputs(object):
|
||||
def __init__(
|
||||
self,
|
||||
box2box_transform,
|
||||
anchor_matcher,
|
||||
batch_size_per_image,
|
||||
positive_fraction,
|
||||
images,
|
||||
pred_objectness_logits,
|
||||
pred_anchor_deltas,
|
||||
anchors,
|
||||
boundary_threshold=0,
|
||||
gt_boxes=None,
|
||||
smooth_l1_beta=0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
box2box_transform (Box2BoxTransform): :class:`Box2BoxTransform` instance for
|
||||
anchor-proposal transformations.
|
||||
anchor_matcher (Matcher): :class:`Matcher` instance for matching anchors to
|
||||
ground-truth boxes; used to determine training labels.
|
||||
batch_size_per_image (int): number of proposals to sample when training
|
||||
positive_fraction (float): target fraction of sampled proposals that should be positive
|
||||
images (ImageList): :class:`ImageList` instance representing N input images
|
||||
pred_objectness_logits (list[Tensor]): A list of L elements.
|
||||
Element i is a tensor of shape (N, A, Hi, Wi) representing
|
||||
the predicted objectness logits for anchors.
|
||||
pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape
|
||||
(N, A*4, Hi, Wi) representing the predicted "deltas" used to transform anchors
|
||||
to proposals.
|
||||
anchors (list[list[Boxes]]): A list of N elements. Each element is a list of L
|
||||
Boxes. The Boxes at (n, l) stores the entire anchor array for feature map l in image
|
||||
n (i.e. the cell anchors repeated over all locations in feature map (n, l)).
|
||||
boundary_threshold (int): if >= 0, then anchors that extend beyond the image
|
||||
boundary by more than boundary_thresh are not used in training. Set to a very large
|
||||
number or < 0 to disable this behavior. Only needed in training.
|
||||
gt_boxes (list[Boxes], optional): A list of N elements. Element i a Boxes storing
|
||||
the ground-truth ("gt") boxes for image i.
|
||||
smooth_l1_beta (float): The transition point between L1 and L2 loss in
|
||||
the smooth L1 loss function. When set to 0, the loss becomes L1. When
|
||||
set to +inf, the loss becomes constant 0.
|
||||
"""
|
||||
self.box2box_transform = box2box_transform
|
||||
self.anchor_matcher = anchor_matcher
|
||||
self.batch_size_per_image = batch_size_per_image
|
||||
self.positive_fraction = positive_fraction
|
||||
self.pred_objectness_logits = pred_objectness_logits
|
||||
self.pred_anchor_deltas = pred_anchor_deltas
|
||||
|
||||
self.anchors = anchors
|
||||
self.gt_boxes = gt_boxes
|
||||
self.num_feature_maps = len(pred_objectness_logits)
|
||||
self.num_images = len(images)
|
||||
self.image_sizes = images.image_sizes
|
||||
self.boundary_threshold = boundary_threshold
|
||||
self.smooth_l1_beta = smooth_l1_beta
|
||||
|
||||

    def _get_ground_truth(self):
        """
        Returns:
            gt_objectness_logits: list of N tensors. Tensor i is a vector whose length is the
                total number of anchors in image i (i.e., len(anchors[i])). Label values are
                in {-1, 0, 1}, with meanings: -1 = ignore; 0 = negative class; 1 = positive class.
            gt_anchor_deltas: list of N tensors. Tensor i has shape (len(anchors[i]), 4).
        """
        gt_objectness_logits = []
        gt_anchor_deltas = []
        # Concatenate anchors from all feature maps into a single Boxes per image
        anchors = [BUABoxes.cat(anchors_i) for anchors_i in self.anchors]
        for image_size_i, anchors_i, gt_boxes_i in zip(self.image_sizes, anchors, self.gt_boxes):
            """
            image_size_i: (h, w) for the i-th image
            anchors_i: anchors for i-th image
            gt_boxes_i: ground-truth boxes for i-th image
            """
            match_quality_matrix = pairwise_iou(gt_boxes_i, anchors_i)
            matched_idxs, gt_objectness_logits_i = self.anchor_matcher(match_quality_matrix)

            if self.boundary_threshold >= 0:
                # Discard anchors that go out of the boundaries of the image
                # NOTE: This is legacy functionality that is turned off by default in Detectron2
                anchors_inside_image = anchors_i.inside_box(image_size_i, self.boundary_threshold)
                gt_objectness_logits_i[~anchors_inside_image] = -1

            if len(gt_boxes_i) == 0:
                # These values won't be used anyway since the anchor is labeled as background
                gt_anchor_deltas_i = torch.zeros_like(anchors_i.tensor)
            else:
                # TODO wasted computation for ignored boxes
                matched_gt_boxes = gt_boxes_i[matched_idxs]
                gt_anchor_deltas_i = self.box2box_transform.get_deltas(
                    anchors_i.tensor, matched_gt_boxes.tensor
                )

            gt_objectness_logits.append(gt_objectness_logits_i)
            gt_anchor_deltas.append(gt_anchor_deltas_i)

        return gt_objectness_logits, gt_anchor_deltas
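
    # Note (illustrative, not in the original file): for an image with, say, five anchors and
    # one ground-truth box, the matcher might produce labels [1, 0, 0, -1, 0]: anchor 0 is the
    # positive match, anchors 1, 2 and 4 are background, and anchor 3 falls in the in-between
    # IoU band and is ignored, so it contributes to neither loss term.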

    def losses(self):
        """
        Return the losses from a set of RPN predictions and their associated ground-truth.

        Returns:
            dict[loss name -> loss value]: A dict mapping from loss name to loss value.
                Loss names are: `loss_rpn_cls` for objectness classification and
                `loss_rpn_loc` for proposal localization.
        """

        def resample(label):
            """
            Randomly sample a subset of positive and negative examples by overwriting
            the label vector to the ignore value (-1) for all elements that are not
            included in the sample.
            """
            pos_idx, neg_idx = subsample_labels(
                label, self.batch_size_per_image, self.positive_fraction, 0
            )
            # Fill with the ignore label (-1), then set positive and negative labels
            label.fill_(-1)
            label.scatter_(0, pos_idx, 1)
            label.scatter_(0, neg_idx, 0)
            return label

        gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
        """
        gt_objectness_logits: list of N tensors. Tensor i is a vector whose length is the
            total number of anchors in image i (i.e., len(anchors[i]))
        gt_anchor_deltas: list of N tensors. Tensor i has shape (len(anchors[i]), B),
            where B is the box dimension
        """
        # Collect all objectness labels and delta targets over feature maps and images
        # The final ordering is L, N, H, W, A from slowest to fastest axis.
        num_anchors_per_map = [int(np.prod(x.shape[1:]) / 2) for x in self.pred_objectness_logits]
        num_anchors_per_image = sum(num_anchors_per_map)

        # Stack to: (N, num_anchors_per_image)
        gt_objectness_logits = torch.stack(
            [resample(label) for label in gt_objectness_logits], dim=0
        )

        # Log the number of positive/negative anchors per-image that's used in training
        num_pos_anchors = (gt_objectness_logits == 1).sum().item()
        num_neg_anchors = (gt_objectness_logits == 0).sum().item()
        storage = get_event_storage()
        storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / self.num_images)
        storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / self.num_images)

        assert gt_objectness_logits.shape[1] == num_anchors_per_image
        # Split to tuple of L tensors, each with shape (N, num_anchors_per_map)
        gt_objectness_logits = torch.split(gt_objectness_logits, num_anchors_per_map, dim=1)
        # Concat from all feature maps
        gt_objectness_logits = cat([x.flatten() for x in gt_objectness_logits], dim=0)

        # Stack to: (N, num_anchors_per_image, B)
        gt_anchor_deltas = torch.stack(gt_anchor_deltas, dim=0)
        assert gt_anchor_deltas.shape[1] == num_anchors_per_image
        B = gt_anchor_deltas.shape[2]  # box dimension (4 or 5)

        # Split to tuple of L tensors, each with shape (N, num_anchors_per_map, B)
        gt_anchor_deltas = torch.split(gt_anchor_deltas, num_anchors_per_map, dim=1)
        # Concat from all feature maps
        gt_anchor_deltas = cat([x.reshape(-1, B) for x in gt_anchor_deltas], dim=0)

        # Collect all objectness logits and delta predictions over feature maps
        # and images to arrive at the same shape as the labels and targets
        # The final ordering is L, N, H, W, 2A from slowest to fastest axis.
        pred_objectness_logits = cat(
            [
                # Reshape: (N, 2A, Hi, Wi) -> (N, Hi, Wi, 2A) -> (N*Hi*Wi*A, 2)
                x.permute(0, 2, 3, 1).reshape(-1, 2)
                for x in self.pred_objectness_logits
            ],
            dim=0,
        )
        pred_anchor_deltas = cat(
            [
                # Reshape: (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B)
                # -> (N*Hi*Wi*A, B)
                x.view(x.shape[0], -1, B, x.shape[-2], x.shape[-1])
                .permute(0, 3, 4, 1, 2)
                .reshape(-1, B)
                for x in self.pred_anchor_deltas
            ],
            dim=0,
        )

        objectness_loss, localization_loss = bua_rpn_losses(
            gt_objectness_logits,
            gt_anchor_deltas,
            pred_objectness_logits,
            pred_anchor_deltas,
            self.smooth_l1_beta,
        )
        normalizer = 1.0 / (self.batch_size_per_image * self.num_images)
        loss_cls = objectness_loss * normalizer  # cls: classification loss
        loss_loc = localization_loss * normalizer  # loc: localization loss
        losses = {"loss_rpn_cls": loss_cls, "loss_rpn_loc": loss_loc}

        return losses
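
    # Note (illustrative, not in the original file): because each objectness map carries 2A
    # channels (two logits per anchor), num_anchors_per_map divides the per-level element count
    # by 2. For example, a single level with A = 3 anchors on a 2 x 2 grid gives a (N, 6, 2, 2)
    # logit map and 12 anchors per image, flattened in (L, N, H, W, A) order before the loss.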

    def predict_proposals(self):
        """
        Transform anchors into proposals by applying the predicted anchor deltas.

        Returns:
            proposals (list[Tensor]): A list of L tensors. Tensor i has shape
                (N, Hi*Wi*A, B), where B is box dimension (4 or 5).
        """
        proposals = []
        # Transpose anchors from images-by-feature-maps (N, L) to feature-maps-by-images (L, N)
        anchors = list(zip(*self.anchors))
        # For each feature map
        for anchors_i, pred_anchor_deltas_i in zip(anchors, self.pred_anchor_deltas):
            B = anchors_i[0].tensor.size(1)
            N, _, Hi, Wi = pred_anchor_deltas_i.shape
            # Reshape: (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N*Hi*Wi*A, B)
            pred_anchor_deltas_i = (
                pred_anchor_deltas_i.view(N, -1, B, Hi, Wi).permute(0, 3, 4, 1, 2).reshape(-1, B)
            )
            # Concatenate all anchors to shape (N*Hi*Wi*A, B)
            # type(anchors_i[0]) is Boxes (B = 4) or RotatedBoxes (B = 5)
            anchors_i = type(anchors_i[0]).cat(anchors_i)
            proposals_i = self.box2box_transform.apply_deltas(
                pred_anchor_deltas_i, anchors_i.tensor
            )
            # Append feature map proposals with shape (N, Hi*Wi*A, B)
            proposals.append(proposals_i.view(N, -1, B))
        return proposals
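
    # Note (illustrative, not in the original file): for axis-aligned boxes the box2box
    # transform typically decodes a delta as x_c = dx * w_a + x_a, y_c = dy * h_a + y_a,
    # w = w_a * exp(dw), h = h_a * exp(dh), where (x_a, y_a, w_a, h_a) describe the anchor;
    # apply_deltas performs this decoding for every (anchor, delta) pair above.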

    def predict_objectness_logits(self):
        """
        Return objectness logits in the same format as the proposals returned by
        :meth:`predict_proposals`.

        Returns:
            pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape
                (N, Hi*Wi*A).
        """
        pred_objectness_logits = [
            # Reshape: (N, 2A, Hi, Wi) -> (N, 2, A, Hi, Wi), softmax over the 2 classes,
            # keep the foreground channel -> (N, Hi, Wi, 1, A) -> (N, Hi*Wi*A)
            F.softmax(
                score.view(score.shape[0], 2, score.shape[1] // 2, score.shape[2], score.shape[3]),
                dim=1,
            )[:, 1:, :, :, :]
            .permute(0, 3, 4, 1, 2)
            .reshape(self.num_images, -1)
            for score in self.pred_objectness_logits
        ]
        return pred_objectness_logits
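
    # Note (illustrative, not in the original file): the 2A channels hold a
    # (background, foreground) logit pair per anchor, so the softmax over dim=1 followed by
    # slicing channel 1 yields per-anchor foreground probabilities, which are the scores used
    # to rank the corresponding proposals downstream.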


def bua_rpn_losses(
    gt_objectness_logits,
    gt_anchor_deltas,
    pred_objectness_logits,
    pred_anchor_deltas,
    smooth_l1_beta,
):
    """
    Args:
        gt_objectness_logits (Tensor): shape (N,), each element in {-1, 0, 1} representing
            ground-truth objectness labels with: -1 = ignore; 0 = not object; 1 = object.
        gt_anchor_deltas (Tensor): shape (N, box_dim), row i represents ground-truth
            box2box transform targets (dx, dy, dw, dh) or (dx, dy, dw, dh, da) that map anchor i to
            its matched ground-truth box.
        pred_objectness_logits (Tensor): shape (N, 2), each row holds the predicted
            (background, foreground) objectness logits for one anchor.
        pred_anchor_deltas (Tensor): shape (N, box_dim), each row is a predicted box2box
            transform (dx, dy, dw, dh) or (dx, dy, dw, dh, da).
        smooth_l1_beta (float): The transition point between L1 and L2 loss in
            the smooth L1 loss function. When set to 0, the loss becomes L1. When
            set to +inf, the loss becomes constant 0.

    Returns:
        objectness_loss, localization_loss, both unnormalized (summed over samples).
    """
    pos_masks = gt_objectness_logits == 1
    localization_loss = smooth_l1_loss(
        pred_anchor_deltas[pos_masks], gt_anchor_deltas[pos_masks], smooth_l1_beta, reduction="sum"
    )

    valid_masks = gt_objectness_logits >= 0
    objectness_loss = F.cross_entropy(
        pred_objectness_logits[valid_masks],
        gt_objectness_logits[valid_masks].to(torch.long),
        reduction="sum",
    )
    return objectness_loss, localization_loss
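
To make the masking convention concrete, here is a small self-contained sketch (not the repository's code) that reproduces the same pattern with plain PyTorch ops: the smooth-L1 localization term is computed only over anchors labeled 1, while the two-class cross-entropy term also includes the negatives and skips the ignored (-1) entries. It assumes a recent PyTorch (>= 1.7) where `F.smooth_l1_loss` accepts a `beta` argument; `bua_rpn_losses` above uses its own `smooth_l1_loss` helper instead.

```python
import torch
import torch.nn.functional as F

# Six dummy anchors: labels in {-1, 0, 1} (-1 = ignore), 4-dim box deltas, and a
# (background, foreground) logit pair per anchor.
gt_labels = torch.tensor([1, 0, -1, 1, 0, -1])
gt_deltas = torch.randn(6, 4)
pred_logits = torch.randn(6, 2)
pred_deltas = torch.randn(6, 4)

# Localization: smooth L1 summed over the positive anchors only.
pos = gt_labels == 1
loc_loss = F.smooth_l1_loss(pred_deltas[pos], gt_deltas[pos], beta=1.0, reduction="sum")

# Objectness: cross-entropy over positives and negatives, skipping the -1 entries.
valid = gt_labels >= 0
cls_loss = F.cross_entropy(pred_logits[valid], gt_labels[valid], reduction="sum")

print(float(cls_loss), float(loc_loss))
```

Both terms are returned as un-normalized sums, which `losses()` above then scales by `1 / (batch_size_per_image * num_images)`.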
@ -0,0 +1,70 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import glob
import os
from setuptools import find_packages, setup
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "models", "bua", "layers", "csrc")

    main_source = os.path.join(extensions_dir, "vision.cpp")
    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
        os.path.join(extensions_dir, "*.cu")
    )

    sources = [main_source] + sources

    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]

        # It's better if pytorch can do this by default ..
        CC = os.environ.get("CC", None)
        if CC is not None:
            extra_compile_args["nvcc"].append("-ccbin={}".format(CC))

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "models.bua._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


setup(
    name="bottom-up-attention.pytorch",
    packages=find_packages(exclude=("configs", "tests")),
    python_requires=">=3.6",
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
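
As a hedged usage note (not in the original file): with a standard setuptools flow such as `pip install -e .` or `python setup.py develop`, the CUDA/C++ extension declared above is compiled and should become importable under the name given to `extension(...)`. A minimal sanity check might look like:

```python
# Illustrative post-build check (not part of the repository): import the compiled extension
# by the name declared in setup() above and show where it was loaded from. Assumes the build
# succeeded and the project root is on PYTHONPATH.
import importlib

ext = importlib.import_module("models.bua._C")
print("loaded extension from:", ext.__file__)
```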
@ -0,0 +1,23 @@
# Build a "shuffled" caption file for Flickr30K: each image's five reference captions are
# expanded into concatenations of 1-4 consecutive captions (wrapping around within the group).
captions = []
with open('../SCAN-master/data/f30k_precomp/dev_caps.txt', 'rb') as f:
    for line in f:
        captions.append(line.strip().decode('utf-8'))
print(len(captions))

result_captions = []
for index in range(0, len(captions), 5):
    # The five captions of one image, duplicated so slices can wrap around the group.
    tmp = captions[index:index + 5]
    tmp = tmp + tmp
    for count in range(1, 5):
        for pos in range(5):
            # Concatenate `count` consecutive captions starting at position `pos`.
            record = ""
            for c in range(count):
                record += tmp[pos + c] + " "
            result_captions.append(record)
print(len(result_captions))

with open("../SCAN-master/data/f30k_precomp/shuffle_dev_caps.txt", 'w') as f:
    for line in result_captions:
        f.write(line + "\n")
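
To make the expansion concrete, the following toy snippet (not part of the repository; it joins with single spaces rather than the trailing space the original writes) shows what one group of five captions turns into: five singles, then five wrap-around pairs, triples, and quadruples, i.e. 20 records per original group.

```python
# Toy illustration of the caption-shuffling loops above.
group = ["c0", "c1", "c2", "c3", "c4"]
tmp = group + group  # duplicate so slices can wrap around the end of the group
records = []
for count in range(1, 5):
    for pos in range(5):
        records.append(" ".join(tmp[pos:pos + count]))
print(len(records))   # 20
print(records[5:10])  # ['c0 c1', 'c1 c2', 'c2 c3', 'c3 c4', 'c4 c0']
```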
@ -0,0 +1,22 @@
# Run retrieval evaluation for trained models on the precomputed SCAN data.
from evaluation import (
    evalrank,
    evalrank2,
    evalrank3,
    evalrank_vse,
    evalrank_maxpool,
    evalrank_f_c,
    evalrank_avgpool,
    evalrank_f_c2,
)
import numpy as np
from transformers import BertTokenizer
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Experiments on the f30k and coco datasets
evalrank_f_c("../SCAN-master/data/f30k_precomp", "f30k")
# evalrank_f_c("../SCAN-master/data/coco_precomp", "coco")

# evalrank_f_c2("../SCAN-master/data/coco_precomp", "f30k")
# evalrank_f_c2("../SCAN-master/data/f30k_precomp", "coco")

# Ablation experiments
# evalrank_vse("./runs/bert_adam_bcan_gpo_vseinfty_vse_data_3045_lr2_vse_gpo/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank("./runs/bert_adam_bcan_gpo_vseinfty_union_2035/model_best.pth.tar", "../SCAN-master/data/", "test")

# Comparison of pooling strategies
# evalrank_avgpool("./runs/bert_adam_bcan_gpo_vseinfty_bcan/model_best.pth.tar", "../SCAN-master/data/", "test")
# evalrank_maxpool("./runs/bert_adam_bcan_gpo_vseinfty_bcan/model_best.pth.tar", "../SCAN-master/data/", "test")