159 lines
8.5 KiB
Python
159 lines
8.5 KiB
Python
|
import os
import random
import xml.etree.ElementTree as ET

import numpy as np

from utils.utils import get_classes

# --------------------------------------------------------------------------------------------------------------------------------#
# annotation_mode selects what this script computes when run:
#   0 - the full label pipeline: the txt files under VOCdevkit/VOC2007/ImageSets
#       plus the 2007_train.txt / 2007_val.txt used for training
#   1 - only the txt files under VOCdevkit/VOC2007/ImageSets
#   2 - only the 2007_train.txt / 2007_val.txt used for training
# --------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0

# -------------------------------------------------------------------#
# Must be edited: class names used to build 2007_train.txt / 2007_val.txt.
# Keep it identical to the classes_path used for training and prediction.
# If the generated 2007_train.txt contains no object info, the classes
# were not set correctly.  Only used when annotation_mode is 0 or 2.
# -------------------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'  # names as written in the XML files; their order defines the one-hot order at training time

# --------------------------------------------------------------------------------------------------------------------------------#
# trainval_percent: ratio of (train+val) to test; default (train+val):test = 9:1
# train_percent:    ratio of train to val inside (train+val); default train:val = 9:1
# Only used when annotation_mode is 0 or 1.
# --------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9

# -------------------------------------------------------#
# Folder that holds the VOC dataset.
# Defaults to VOCdevkit under the project root.
# -------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'

VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
classes, _ = get_classes(classes_path)

# -------------------------------------------------------#
# Object-count statistics filled in while generating.
# -------------------------------------------------------#
photo_nums = np.zeros(len(VOCdevkit_sets))  # number of train images, number of val images
nums = np.zeros(len(classes))               # per-class object counts
|
|||
|
|
|||
|
|
|||
|
def convert_annotation(year, image_id, list_file):
    """Append the ground-truth boxes of one image to *list_file*.

    Parses VOC<year>/Annotations/<image_id>.xml and, for every object whose
    class is listed in the module-level ``classes`` and that is not marked
    difficult, writes " xmin,ymin,xmax,ymax,cls_id" to *list_file* (the image
    path itself has already been written by the caller, so each image ends up
    on a single line).  Also increments the per-class counter ``nums``.

    Parameters
    ----------
    year : str
        VOC year, e.g. '2007'.
    image_id : str
        Image id without extension, e.g. '000001'.
    list_file : file object
        Open text file accumulating one training line per image.
    """
    xml_path = os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id))
    # 'with' guarantees the annotation file is closed even if parsing raises
    # (the original opened the handle and never closed it).
    with open(xml_path, encoding='utf-8') as in_file:
        tree = ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0
        # 'is not None', not '!= None': an ElementTree Element with no children
        # is falsy, so identity comparison is the only safe presence test.
        if obj.find('difficult') is not None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        # Skip classes we do not train on and boxes flagged as difficult.
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)  # index within the classes file is the class id
        xmlbox = obj.find('bndbox')
        # Coordinates go through float() first because some VOC annotations
        # store them as e.g. '24.0'.
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
             int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        # Line format: <image path> then " xmin,ymin,xmax,ymax,cls_id" per object.
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        nums[cls_id] = nums[cls_id] + 1  # per-class object count (reuse cls_id instead of a second index())
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Fixed seed so the train/val/test split is reproducible across runs.
    random.seed(0)
    # Spaces in the dataset path (or image names) break training downstream,
    # so abort early with a clear message.
    if " " in os.path.abspath(VOCdevkit_path):
        raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")

    # ---- Step 1: split the annotated images into train/val/test id lists ----
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        temp_xml = os.listdir(xmlfilepath)
        total_xml = [xml for xml in temp_xml if xml.endswith(".xml")]

        num = len(total_xml)                   # total number of annotated images to split
        indices = range(num)                   # renamed from 'list' to avoid shadowing the builtin
        tv = int(num * trainval_percent)       # size of train+val
        tr = int(tv * train_percent)           # size of train inside train+val
        trainval = random.sample(indices, tv)  # sample train+val indices out of all images
        train = random.sample(trainval, tr)    # sample train indices out of train+val

        # Sets give O(1) membership tests in the loop below (the original
        # tested 'i in <list>', which is O(n) per image).
        trainval_set = set(trainval)
        train_set = set(train)

        print("train and val size", tv)
        print("train size", tr)
        # 'with' guarantees the four split files are closed even on error.
        with open(os.path.join(saveBasePath, 'trainval.txt'), 'w') as ftrainval, \
             open(os.path.join(saveBasePath, 'test.txt'), 'w') as ftest, \
             open(os.path.join(saveBasePath, 'train.txt'), 'w') as ftrain, \
             open(os.path.join(saveBasePath, 'val.txt'), 'w') as fval:
            for i in indices:
                name = total_xml[i][:-4] + '\n'  # file name without the '.xml' suffix
                if i in trainval_set:
                    ftrainval.write(name)
                    if i in train_set:
                        ftrain.write(name)
                    else:
                        fval.write(name)
                else:
                    ftest.write(name)
        print("Generate txt in ImageSets done.")

    # ---- Step 2: build 2007_train.txt / 2007_val.txt for training ----
    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        type_index = 0
        for year, image_set in VOCdevkit_sets:
            # e.g. 'VOCdevkit/VOC2007/ImageSets/Main/train.txt'
            # 'with' fixes a leak: the original opened this handle and never closed it.
            with open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set)),
                      encoding='utf-8') as f:
                image_ids = f.read().strip().split()
            # Output file, e.g. '2007_train.txt'.
            with open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8') as list_file:
                for image_id in image_ids:
                    # Absolute image path, e.g. '<abs>/VOCdevkit/VOC2007/JPEGImages/000001.jpg'
                    list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (os.path.abspath(VOCdevkit_path), year, image_id))
                    # Appends " xmin,ymin,xmax,ymax,cls_id" for every kept object.
                    convert_annotation(year, image_id, list_file)
                    list_file.write('\n')
            photo_nums[type_index] = len(image_ids)  # record train/val image counts
            type_index += 1                          # 0 = train set, 1 = val set
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

        def printTable(List1, List2):
            """Print List1's columns side by side, right-justified to the widths in List2."""
            for i, _ in enumerate(List1[0]):
                print("|", end=' ')
                for j in range(len(List1)):  # len(List1) is 2: the names column and the counts column
                    print(List1[j][i].rjust(int(List2[j])), end=' ')
                    print("|", end=' ')
                print()

        str_nums = [str(int(x)) for x in nums]  # per-class object counts as strings
        tableData = [
            classes, str_nums  # class names paired with their counts
        ]
        colWidths = [0] * len(tableData)  # widest cell per column (2 columns here)
        # (dropped the original's unused 'len1 = 0')
        for i in range(len(tableData)):
            for j in range(len(tableData[i])):
                if len(tableData[i][j]) > colWidths[i]:
                    colWidths[i] = len(tableData[i][j])
        printTable(tableData, colWidths)

        if photo_nums[0] <= 500:
            print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")

        if np.sum(nums) == 0:
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("(重要的事情说三遍)。")
|