151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
import os.path as osp
|
|
import numpy as np
|
|
import cv2
|
|
import torchvision
|
|
import utils.transforms as tf
|
|
from .base_dataset import BaseDataset
|
|
from .registry import DATASETS
|
|
|
|
|
|
@DATASETS.register_module
class TuSimple(BaseDataset):
    """TuSimple lane dataset: list-file loading, training transforms, and
    decoding of per-lane probability maps into (x, y) lane points.

    The list sub-directory ``'seg_label/list'`` is handed to ``BaseDataset``;
    how it becomes ``self.list_path`` is decided by the base class.
    """

    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    def transform_train(self):
        """Build the joint image/label augmentation pipeline for training.

        Rotation and horizontal flip are followed by a resize to
        ``(cfg.img_width, cfg.img_height)`` and normalization. The second
        element of each mean/std pair ((0,) / (1,)) is for the label, i.e.
        the label is left un-normalized.
        """
        return torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(
                mean=(self.cfg.img_norm['mean'], (0, )),
                std=(self.cfg.img_norm['std'], (1, ))),
        ])

    def init(self):
        """Populate the per-sample lists from ``<list_path>/<data_list>``.

        Each non-blank line holds whitespace-separated fields:
        ``<img> <label> e1 e2 e3 e4 e5 e6``. Only ``<img>`` is read when not
        training. In training mode the six ``e*`` integers are stored as a
        numpy array in ``self.exist_list``.

        Raises:
        ----------
        ValueError: a training line has fewer than 8 fields.
        """
        list_file = osp.join(self.list_path, self.data_list)
        with open(list_file) as f:
            for lineno, line in enumerate(f, 1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                # Plain concatenation, not osp.join: entries presumably start
                # with '/', and osp.join would then discard img_path.
                self.img_name_list.append(fields[0])
                self.full_img_path_list.append(self.img_path + fields[0])
                if not self.is_training:
                    continue
                if len(fields) < 8:
                    raise ValueError(
                        "{}:{}: expected at least 8 fields for a training "
                        "sample, got {}".format(list_file, lineno, len(fields)))
                self.label_list.append(self.img_path + fields[1])
                self.exist_list.append(
                    np.array([int(v) for v in fields[2:8]]))

    def fix_gap(self, coordinate):
        """Fill interior gaps of a lane by linear interpolation, in place.

        A sample counts as present when it is > 0. Anything else (-1, or the
        0 that ``get_lane`` uses for rejected lanes) is treated as missing.
        Only samples strictly between the first and last present sample are
        filled; leading and trailing missing samples are left untouched.

        Arguments:
        ----------
        coordinate: mutable sequence of x values (list or 1-D np array);
            modified in place

        Return:
        ----------
        the same ``coordinate`` object
        """
        present = [i for i, x in enumerate(coordinate) if x > 0]
        if len(present) < 2:
            return coordinate
        # Interpolate between each pair of neighbouring present samples.
        # Pairing neighbours directly keeps gap starts/ends aligned even when
        # a non-negative "missing" value (0) sits inside the lane.
        for left, right in zip(present[:-1], present[1:]):
            if right - left < 2:
                continue
            gap_width = float(right - left)
            for idx in range(left + 1, right):
                coordinate[idx] = int(
                    (idx - left) / gap_width * coordinate[right]
                    + (right - idx) / gap_width * coordinate[left])
        # Interpolating between very small positive values can truncate to 0.
        if not all(x > 0 for x in coordinate[present[0]:present[-1] + 1]):
            print("Gaps still exist!")
        return coordinate

    def is_short(self, lane):
        """Return True when ``lane`` has no present (> 0) sample at all."""
        return not any(x > 0 for x in lane)

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Sample one lane's x positions from its probability map.

        Rows are sampled bottom-up, starting 10 px above the bottom of the
        target image and stepping ``y_px_gap`` px. The map's height is mapped
        onto the target height minus ``cfg.cut_height``. Presumably the
        network input was cropped by that many rows at the top; confirm
        against the data pipeline.

        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        y_px_gap: vertical step between samples, in target pixels
        pts: number of samples
        thresh: a row's maximum probability must exceed this to count
        resize_shape: target size (H, W); defaults to prob_map.shape

        Return:
        ----------
        coords: np array of length ``pts`` with x positions in the target
            width, ordered bottom-up. Entries are -1 where no point passed
            ``thresh`` or the row lies above the map. If fewer than two
            points were found, the whole array is 0 (lane rejected).
            Interior gaps are filled by ``fix_gap``.

        Raises:
        ----------
        ValueError: target height minus cfg.cut_height is not positive.
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height
        if H <= 0:
            raise ValueError(
                "target height minus cfg.cut_height must be positive, "
                "got {}".format(H))

        coords = np.full(pts, -1.0)
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break  # this and all further samples lie above the map
            row = prob_map[y, :]
            col = np.argmax(row)
            if row[col] > thresh:
                coords[i] = int(col / w * W)
        if (coords > 0).sum() < 2:
            # Too few points to form a lane: mark it as entirely absent.
            coords = np.zeros(pts)
        self.fix_gap(coords)
        return coords

    @staticmethod
    def _to_points(coords, H, y_px_gap, pts):
        """Pair each sampled x with its image row; non-present x becomes -1."""
        return [[coords[j] if coords[j] > 0 else -1, H - 10 - j * y_px_gap]
                for j in range(pts)]

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Decode per-lane probability maps into lists of (x, y) points.

        Arguments:
        ----------
        seg_pred: array of shape (C, h, w). Channel 0 is skipped (presumably
            background); channels 1 .. cfg.num_classes - 1 are decoded as lanes.
        exist: list of existence flags, e.g. [0, 1, 1, 0]. Accepted for
            interface compatibility but NOT consulted here; lanes without any
            detected point are dropped instead.
        resize_shape: reshape size target, (H, W); None means seg_pred's (h, w)
        smooth: whether to box-blur (9x9) each probability map first
        y_px_gap: y pixel gap for sampling
        pts: how many points for one lane
        thresh: probability threshold

        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
            x is -1 where the lane has no point at that row. If no lane is
            found, a single all-missing lane is returned so the result is
            never empty.

        Raises:
        ----------
        ValueError: seg_pred is not 3-dimensional.
        """
        if len(seg_pred.shape) != 3:
            raise ValueError(
                "seg_pred must have shape (C, h, w), got {}".format(
                    tuple(seg_pred.shape)))
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]
        H, _ = resize_shape

        coordinates = []
        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(self._to_points(coords, H, y_px_gap, pts))

        if not coordinates:
            coordinates.append(self._to_points(np.zeros(pts), H, y_px_gap, pts))
        return coordinates
|