73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
# coding: utf-8
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
|
|
from prettytable import PrettyTable
|
|
from sklearn.metrics import roc_curve, auc
|
|
|
|
# Root directory of the IJB-C release; pair labels are read from its meta/ dir.
image_path = "/data/anxiang/IJB_release/IJBC"
# Per-method score files (.npy); the parent directory of each file is used as
# the method's name when plotting.
files = ["./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"]
|
|
|
|
|
|
def read_template_pair_list(path):
    """Read a template-pair label file.

    The file at *path* is parsed as space-separated values with no header.
    The first three columns are taken as the first template id, the second
    template id and the pair label. The label is presumably 1 = same identity
    and 0 = different; confirm this against the dataset's meta documentation.

    Args:
        path: Path to the ``*_template_pair_label.txt`` file.

    Returns:
        Tuple ``(t1, t2, label)`` of 1-D integer NumPy arrays, one entry per
        row of the file.
    """
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # ``np.int`` was merely an alias of the builtin ``int``. It was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError, so
    # the builtin is used directly.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
|
|
|
|
|
|
# Load the IJB-C pair list: p1/p2 are the paired template ids (columns 0 and 1)
# and `label` is the per-pair ground truth fed to roc_curve.
pair_label_file = os.path.join('%s/meta' % image_path,
                               '%s_template_pair_label.txt' % 'ijbc')
p1, p2, label = read_template_pair_list(pair_label_file)
|
|
|
|
# Each method is named after the directory holding its score file, e.g.
# "./<run>/<method>/ijbc.npy" -> "<method>". Those names key `scores` and
# `colours`.
methods = np.array([score_path.split('/')[-2] for score_path in files])
scores = dict(zip(methods, [np.load(score_path) for score_path in files]))
# One colour per method, sampled from the 'Set2' colourmap.
palette = sample_colours_from_colourmap(len(methods), 'Set2')
colours = dict(zip(methods, palette))
|
|
# FPR operating points at which TPR is tabulated; also used as x-axis ticks.
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    # roc_curve returns fpr/tpr in non-decreasing order. Among points with the
    # same fpr, the later one therefore has the largest tpr. Reversing both
    # arrays moves that point first, and np.argmin below returns the first
    # minimum on ties. Net effect: select the largest tpr at the same fpr.
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = ["%s-%s" % (method, "IJBC")]
    for target_fpr in x_labels:
        # Find the ROC point whose FPR is closest to the target. This matches
        # min() over (distance, index) tuples (ties -> lowest index), but is
        # vectorized instead of building a Python list of len(fpr) tuples for
        # every target.
        min_index = np.argmin(np.abs(fpr - target_fpr))
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)

plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
# NOTE(review): the figure is never saved (fig.savefig) or shown (plt.show).
# Outside a notebook/interactive backend the plot is therefore discarded;
# confirm whether an output file is wanted.
print(tpr_fpr_table)
|