From 0386b973c6c7bfe3b51ab79abd33dfad02908624 Mon Sep 17 00:00:00 2001 From: rzzn <2386089024@qq.com> Date: Mon, 24 Jun 2024 18:15:10 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20WZM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- WZM/README.md | 166 ++++++++++++++++++++++++++++++++++++++ WZM/data.py | 145 ++++++++++++++++++++++++++++++++++ WZM/engine.py | 184 +++++++++++++++++++++++++++++++++++++++++++ WZM/find_best_pth.py | 112 ++++++++++++++++++++++++++ WZM/mytools.py | 184 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 791 insertions(+) create mode 100644 WZM/README.md create mode 100644 WZM/data.py create mode 100644 WZM/engine.py create mode 100644 WZM/find_best_pth.py create mode 100644 WZM/mytools.py diff --git a/WZM/README.md b/WZM/README.md new file mode 100644 index 0000000..a27f0dd --- /dev/null +++ b/WZM/README.md @@ -0,0 +1,166 @@ +## The offical PyTorch code for paper ["Remote Sensing Cross-Modal Text-Image Retrieval Based on Global and Local Information", TGRS 2022.](https://doi.org/10.1109/TGRS.2022.3163706) + +# GAC + +##### Author: Zhiqiang Yuan + + +![Supported Python versions](https://img.shields.io/badge/python-3.7-blue.svg) +![Supported OS](https://img.shields.io/badge/Supported%20OS-Linux-yellow.svg) +![npm License](https://img.shields.io/npm/l/mithril.svg) + + +### ------------------------------------------------------------------------------------- + +### Welcome :+1:_`Fork and Star`_:+1:, then we'll let you know when we update + +```bash +#### News: +#### 2021.9.26: ---->Under update ...<---- +``` + +### ------------------------------------------------------------------------------------- + +## INTRODUCTION + +This is GAC, a cross-modal retrieval method for remote sensing images. +We use the MIDF module to fuse multi-level RS image features, and add the DREA mechanism to improve the performance of local features. +In addition, a multivariate rerank algorithm is designed to make full use of the information in the similarity matrix during the testing. +Our method has achieved the state-of-the-art performance (2021.10) in RS cross-modal retrieval task on multiple RS image-text datasets. + +### Network Architecture + +![arch image](./figure/GAC.jpg) +The proposed RSCTIR framework based on global and local information. Compared with the retrieval models constructed using only global features, GAC incorporates optimized local features in the visual encoding considering the target redundancy of RS. The multi-level information dynamic fusion module is designed to fuse the two types of information, using the global information to supplement the local information and utilizing the latter to correct the former. The suggested multivariate rerank algorithm as a post-processing method further improves the retrieval accuracy without extra training. + +### DREA + +To alleviate the pressure on the model from redundant target relations and increase the model’s focus on salient instances, we come up with a denoised representation matrix and a enhanced adjacency matrix to assist the GCN in producing better local representations. +DREA filters the redundant features with high similarity and enhances the features of salient targets, which enables GAC to obtain more transcendent visual representation. + +### MIDF + +MIDF +The proposed multi-level information dynamic fusion module. The method falls into two stages of feature retransformation and dynamic fusion. MIDF first uses SA and GA modules to retransform features, then uses global information to supplement local information and leverages the latter to correct the former. Further dynamic fusion of multi-level features is accomplished through the fabricated dynamic fusion module. + +### Multivariate Rerank + +similarity +The proposed multivariate rerank algorithm. In order to make full use of the similarity matrix, we use k candidates for reverse search and to optimize the similarity results by considering multiple ranking factors. The figure shows an illustration of multivariate rerank when k = 3, using image i for retrieval. + +### Performance + +![performance](./figure/performance.jpg) +Comparisons of Retrieval Performance on RSICD and RSITMD Testset. + +### ------------------------------------------------------------------------------------- + +## IMPLEMENTATION + +```bash +Installation + +We recommended the following dependencies: +Python 3 +PyTorch > 0.3 +Numpy +h5py +nltk +yaml +``` + +```bash +file structure: +-- checkpoint # savepath of ckpt and logs + +-- data # soorted anns of four datesets + -- rsicd_precomp + -- train_caps.txt # train anns + -- train_filename.txt # corresponding imgs + -- test_caps.txt # test anns + -- test_filename.txt # corresponding imgs + -- images # rsicd images here + -- rsitmd_precomp + ... + +-- exec # .sh file + +-- layers # models define + +-- logs # tensorboard save file + +-- option # different config for different datasets and models + +-- util # some script for data processing + +-- vocab # vocabs for different datasets + +-- seq2vec # some files about seq2vec + -- bi_skip.npz + -- bi_skip.npz.pkl + -- btable.npy + -- dictionary.txt + -- uni_skip.npz + -- uni_skip.npz.pkl + -- utable.npy + +-- postprocessing # multivariate rerank + -- rerank.py + -- file + +-- data.py # load data +-- engine.py # details about train and val +-- test.py # test k-fold answers +-- test_single.py # test one model +-- train.py # main file +-- utils.py # some tools +-- vocab.py # generate vocab + +Note: +1. In order to facilitate reproduction, we have provided processed annotations. +2. We prepare some used file:: + (1)[seq2vec (Password:NIST)](https://pan.baidu.com/s/1jz61ZYs8NZflhU_Mm4PbaQ) + (2)[RSICD images (Password:NIST)](https://pan.baidu.com/s/1lH5m047P9m2IvoZMPsoDsQ) +``` + +```bash +Run: (We take the dataset RSITMD as an example) +Step1: + Put the images of different datasets in ./data/{dataset}_precomp/images/ + + --data + --rsitmd_precomp + -- train_caps.txt # train anns + -- train_filename.txt # corresponding imgs + -- test_caps.txt # test anns + -- test_filename.txt # corresponding imgs + -- images # images here + --img1.jpg + --img2.jpg + ... + +Step2: + Modify the corresponding yaml in ./options. + + Regard RSITMD_AMFMN.yaml as opt, which you need to change is: + opt['dataset']['data_path'] # change to precomp path + opt['dataset']['image_path'] # change to image path + opt['model']['seq2vec']['dir_st'] # some files about seq2vec + +Step3: + Bash the ./sh in ./exec. + Note the GPU define in specific .sh file. + + cd exec/RSICD + bash run_GAC_rsicd.sh + +Note: We use k-fold verity to do a fair compare. Other details please see the code itself. +``` + +## Citation + +If you feel this code helpful or use this code or dataset, please cite it as + +``` +Z. Yuan et al., "Remote Sensing Cross-Modal Text-Image Retrieval Based on Global and Local Information," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2022.3163706. +``` diff --git a/WZM/data.py b/WZM/data.py new file mode 100644 index 0000000..4d5309f --- /dev/null +++ b/WZM/data.py @@ -0,0 +1,145 @@ +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +import os +import nltk +import numpy as np +import yaml +import argparse +import utils +from vocab import deserialize_vocab +from PIL import Image + +class PrecompDataset(data.Dataset): + """ + Load precomputed captions and image features + """ + + def __init__(self, data_split, vocab, opt): + self.vocab = vocab + self.loc = opt['dataset']['data_path'] + self.img_path = opt['dataset']['image_path'] + + # Captions + self.images = [] + self.captions = [] + self.maxlength = 0 + + if data_split != 'test': + with open(self.loc+'%s_caps_verify.txt' % data_split, 'rb') as f: + for line in f: + self.captions.append(line.strip()) + + with open(self.loc + '%s_filename_verify.txt' % data_split, 'rb') as f: + for line in f: + self.images.append(line.strip()) + else: + with open(self.loc + '%s_caps.txt' % data_split, 'rb') as f: + for line in f: + self.captions.append(line.strip()) + + with open(self.loc + '%s_filename.txt' % data_split, 'rb') as f: + for line in f: + self.images.append(line.strip()) + + self.length = len(self.captions) + # rkiros data has redundancy in images, we divide by 5, 10crop doesn't + if len(self.images) != self.length: + self.im_div = 5 + else: + self.im_div = 1 + + if data_split == "train": + self.transform = transforms.Compose([ + # transforms.Resize((278, 278)), + transforms.Resize((256, 256)), + transforms.RandomRotation(degrees=(0, 90)), + # transforms.RandomCrop(256), + transforms.RandomCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + else: + self.transform = transforms.Compose([ + # transforms.Resize((256, 256)), + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + + def __getitem__(self, index): + # handle the image redundancy + img_id = index//self.im_div + caption = self.captions[index] + + vocab = self.vocab + + # Convert caption (string) to word ids. + tokens = nltk.tokenize.word_tokenize( + caption.lower().decode('utf-8')) + punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%'] + tokens = [k for k in tokens if k not in punctuations] + tokens_UNK = [k if k in vocab.word2idx.keys() else '' for k in tokens] + + + caption = [] + caption.append(vocab('')) + caption.extend([vocab(token) for token in tokens_UNK]) + caption.append(vocab('')) + target = torch.LongTensor(caption) + + image = Image.open(self.img_path + str(self.images[img_id])[2:-1]).convert('RGB') + image = self.transform(image) # torch.Size([3, 256, 256]) + + return image, target, index, img_id + + def __len__(self): + return self.length + + +def collate_fn(data): + + # Sort a data list by caption length + data.sort(key=lambda x: len(x[1]), reverse=True) + images, captions, ids, img_ids = zip(*data) + + # Merge images (convert tuple of 3D tensor to 4D tensor) + images = torch.stack(images, 0) + + # Merget captions (convert tuple of 1D tensor to 2D tensor) + lengths = [len(cap) for cap in captions] + targets = torch.zeros(len(captions), max(lengths)).long() + for i, cap in enumerate(captions): + end = lengths[i] + targets[i, :end] = cap[:end] + + lengths = [l if l !=0 else 1 for l in lengths] + + return images, targets, lengths, ids + + +def get_precomp_loader(data_split, vocab, batch_size=100, + shuffle=True, num_workers=0, opt={}): + """Returns torch.utils.data.DataLoader for custom coco dataset.""" + dset = PrecompDataset(data_split, vocab, opt) + + data_loader = torch.utils.data.DataLoader(dataset=dset, + batch_size=batch_size, + shuffle=shuffle, + pin_memory=True, + collate_fn=collate_fn, + num_workers=num_workers) + return data_loader + +def get_loaders(vocab, opt): + train_loader = get_precomp_loader( 'train', vocab, + opt['dataset']['batch_size'], True, opt['dataset']['workers'], opt=opt) + val_loader = get_precomp_loader( 'val', vocab, + opt['dataset']['batch_size_val'], False, opt['dataset']['workers'], opt=opt) + return train_loader, val_loader + + +def get_test_loader(vocab, opt): + test_loader = get_precomp_loader( 'test', vocab, + opt['dataset']['batch_size_val'], False, opt['dataset']['workers'], opt=opt) + return test_loader diff --git a/WZM/engine.py b/WZM/engine.py new file mode 100644 index 0000000..b833907 --- /dev/null +++ b/WZM/engine.py @@ -0,0 +1,184 @@ +import time +import torch +import numpy as np +import sys +from torch.autograd import Variable +import tensorboard_logger as tb_logger +import logging +from torch.nn.utils.clip_grad import clip_grad_norm + +from model.utils import cosine_sim, cosine_similarity +import utils + +def train(train_loader, model, optimizer, epoch, opt={}): + + # extract value + grad_clip = opt['optim']['grad_clip'] + max_violation = opt['optim']['max_violation'] + margin = opt['optim']['margin'] + loss_name = opt['model']['name'] + "_" + opt['dataset']['datatype'] + print_freq = opt['logs']['print_freq'] + + # switch to train mode + model.train() + batch_time = utils.AverageMeter() + data_time = utils.AverageMeter() + train_logger = utils.LogCollector() + + end = time.time() + params = list(model.parameters()) + for i, train_data in enumerate(train_loader): + images, captions, lengths, ids= train_data + + batch_size = images.size(0) + # print("batch_size : ", batch_size) + margin = float(margin) + # measure data loading time + data_time.update(time.time() - end) + model.logger = train_logger + + input_visual = Variable(images) + input_text = Variable(captions) + + if torch.cuda.is_available(): + input_visual = input_visual.cuda() + input_text = input_text.cuda() + + # visual_feature, text_feature = model(input_visual, input_local_rep, input_local_adj, input_text, lengths) + # scores = cosine_sim(visual_feature, text_feature) + # print("visual_feature shape : ", visual_feature.shape) + scores = model(input_visual, input_text, lengths) + # print("scores shape : ", scores.shape) + torch.cuda.synchronize() + loss = utils.calcul_loss(scores, input_visual.size(0), margin, max_violation=max_violation, ) + + if grad_clip > 0: + clip_grad_norm(params, grad_clip) + + train_logger.update('L', loss.cpu().data.numpy()) + + + optimizer.zero_grad() + loss.backward() + torch.cuda.synchronize() + optimizer.step() + torch.cuda.synchronize() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % print_freq == 0: + logging.info( + 'Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f}\t' + '{elog}\t' + .format(epoch, i, len(train_loader), + batch_time=batch_time, + elog=str(train_logger))) + + utils.log_to_txt( + 'Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f}\t' + '{elog}\t' + .format(epoch, i, len(train_loader), + batch_time=batch_time, + elog=str(train_logger)), + opt['logs']['ckpt_save_path']+ opt['model']['name'] + "_" + opt['dataset']['datatype'] +".txt" + ) + tb_logger.log_value('epoch', epoch) + tb_logger.log_value('step', i) + tb_logger.log_value('batch_time', batch_time.val) + train_logger.tb_log(tb_logger) + +def validate(val_loader, model): + + model.eval() + val_logger = utils.LogCollector() + model.logger = val_logger + + start = time.time() + # input_visual = np.zeros((len(val_loader.dataset), 3, 256, 256)) + input_visual = np.zeros((len(val_loader.dataset), 3, 224, 224)) + input_text = np.zeros((len(val_loader.dataset), 47), dtype=np.int64) + input_text_lengeth = [0]*len(val_loader.dataset) + for i, val_data in enumerate(val_loader): + + images, captions, lengths, ids = val_data + + for (id, img, cap, l) in zip(ids, (images.numpy().copy()), (captions.numpy().copy()), lengths): + input_visual[id] = img + input_text[id, :captions.size(1)] = cap + input_text_lengeth[id] = l + + + input_visual = np.array([input_visual[i] for i in range(0, len(input_visual), 5)]) + + d = utils.shard_dis_GAC(input_visual, input_text, model, lengths=input_text_lengeth) + + end = time.time() + print("calculate similarity time:", end - start) + + (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(d) + logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % + (r1i, r5i, r10i, medri, meanri)) + (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(d) + logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % + (r1t, r5t, r10t, medrt, meanrt)) + currscore = (r1t + r5t + r10t + r1i + r5i + r10i)/6.0 + + all_score = "r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n sum:{}\n ------\n".format( + r1i, r5i, r10i, medri, meanri, r1t, r5t, r10t, medrt, meanrt, currscore + ) + + tb_logger.log_value('r1i', r1i) + tb_logger.log_value('r5i', r5i) + tb_logger.log_value('r10i', r10i) + tb_logger.log_value('medri', medri) + tb_logger.log_value('meanri', meanri) + tb_logger.log_value('r1t', r1t) + tb_logger.log_value('r5t', r5t) + tb_logger.log_value('r10t', r10t) + tb_logger.log_value('medrt', medrt) + tb_logger.log_value('meanrt', meanrt) + tb_logger.log_value('rsum', currscore) + + return currscore, all_score + + +def validate_test(val_loader, model): + model.eval() + val_logger = utils.LogCollector() + model.logger = val_logger + + start = time.time() + # input_visual = np.zeros((len(val_loader.dataset), 3, 256, 256)) + input_visual = np.zeros((len(val_loader.dataset), 3, 224, 224)) + input_text = np.zeros((len(val_loader.dataset), 47), dtype=np.int64) + input_text_lengeth = [0] * len(val_loader.dataset) + + embed_start = time.time() + for i, val_data in enumerate(val_loader): + + images, captions, lengths, ids = val_data + + + for (id, img, cap, l) in zip(ids, (images.numpy().copy()), (captions.numpy().copy()), lengths): + input_visual[id] = img + + + input_text[id, :captions.size(1)] = cap + input_text_lengeth[id] = l + + input_visual = np.array([input_visual[i] for i in range(0, len(input_visual), 5)]) + embed_end = time.time() + print("embedding time: {}".format(embed_end-embed_start)) + + d = utils.shard_dis_GAC(input_visual, input_text, model, lengths=input_text_lengeth) + + end = time.time() + print("calculate similarity time:", end - start) + + return d + + diff --git a/WZM/find_best_pth.py b/WZM/find_best_pth.py new file mode 100644 index 0000000..820b38e --- /dev/null +++ b/WZM/find_best_pth.py @@ -0,0 +1,112 @@ +import os, random, copy +import numpy as np +import torch +import argparse +import yaml +import logging + +import utils +import data +import engine + +from vocab import deserialize_vocab +import mytools +from model import GAC as models + +def parser_options(): + # Hyper Parameters setting + parser = argparse.ArgumentParser() + parser.add_argument('--path_opt', default='option/RSITMD_mca/RSITMD_GAC.yaml', type=str, + help='path to a yaml options file') + opt = parser.parse_args() + + # load model options + with open(opt.path_opt, 'r') as handle: + options = yaml.safe_load(handle) + + return options + +def main(options, vocab): + # Create dataset, model, criterion and optimizer + test_loader = data.get_test_loader(vocab, options) + + model = models.factory(options['model'], + vocab, + cuda=True, + data_parallel=False) + + print('Model has {} parameters'.format(utils.params_count(model))) + + # optionally resume from a checkpoint + if os.path.isfile(options['optim']['resume']): + print("=> loading checkpoint '{}'".format(options['optim']['resume'])) + checkpoint = torch.load(options['optim']['resume']) + start_epoch = checkpoint['epoch'] + best_rsum = checkpoint['best_rsum'] + model.load_state_dict(checkpoint['model']) + else: + print("=> no checkpoint found at '{}'".format(options['optim']['resume'])) + + # evaluate on test set + sims = engine.validate_test(test_loader, model) + + # get indicators + (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(sims) + logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % + (r1i, r5i, r10i, medri, meanri)) + (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(sims) + logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % + (r1t, r5t, r10t, medrt, meanrt)) + currscore = (r1t + r5t + r10t + r1i + r5i + r10i)/6.0 + + all_score = "r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n sum:{}\n ------\n".format( + r1i, r5i, r10i, medri, meanri, r1t, r5t, r10t, medrt, meanrt, currscore + ) + # 记录到输出文件中 + outputfile_path = "RSICD_GAC_decay0.5_m0.2_without_m4m5.txt" + with open(outputfile_path, 'a') as file: + file.writelines(options['optim']['resume']) + file.write(all_score) + + print(all_score) + + return [r1i, r5i, r10i, r1t, r5t, r10t, currscore] + +def get_allpth_score(options, k, vocab): + updated_options = copy.deepcopy(options) + scores = [] + directory = options['logs']['ckpt_save_path'] + options['k_fold']['experiment_name'] + "/" + str(k) + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith('.tar'): + file_path = os.path.join(root, file) + updated_options['optim']['resume'] = file_path + # run experiment + one_score = main(updated_options, vocab) + scores.append(one_score) + return scores + +if __name__ == '__main__': + options = parser_options() + # make vocab + vocab = deserialize_vocab(options['dataset']['vocab_path']) + vocab_word = sorted(vocab.word2idx.items(), key=lambda x: x[1], reverse=False) + vocab_word = [tup[0] for tup in vocab_word] + + # calc ave k results + last_score = [] + for k in range(options['k_fold']['nums']): + print("=========================================") + print("Start evaluate {}th fold".format(k)) + + scores = get_allpth_score(options, k, vocab) + last_score.extend(scores) + + print("Complete evaluate {}th fold".format(k)) + + # average + print("===================== Ave Score ({}-fold verify) =================".format(options['k_fold']['nums'])) + last_score = np.average(last_score, axis=0) + names = ['r1i', 'r5i', 'r10i', 'r1t', 'r5t', 'r10t', 'mr'] + for name,score in zip(names, last_score): + print("{}:{}".format(name, score)) diff --git a/WZM/mytools.py b/WZM/mytools.py new file mode 100644 index 0000000..4f9bd72 --- /dev/null +++ b/WZM/mytools.py @@ -0,0 +1,184 @@ +# coding:utf-8 +"""导入一些包""" +import os +import time, random +import json +import numpy as np +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +""" 打印一些东西 """ +"""----------------------------------------------------------------------""" + + +# 打印列表按照竖行的形式 +def print_list(list): + print("++++++++++++++++++++++++++++++++++++++++++++") + for l in list: + print(l) + print("++++++++++++++++++++++++++++++++++++++++++++") + + +# 打印字典按照竖行的形式 +def print_dict(dict): + print("++++++++++++++++++++++++++++++++++++++++++++") + for k, v in dict.items(): + print("key:", k, " value:", v) + print("++++++++++++++++++++++++++++++++++++++++++++") + + +# 打印一些东西,加入标识符 +def print_with_log(info): + print("++++++++++++++++++++++++++++++++++++++++++++") + print(info) + print("++++++++++++++++++++++++++++++++++++++++++++") + + +# 打印标识符 +def print_log(): + print("++++++++++++++++++++++++++++++++++++++++++++") + + +""" 文件存储 """ +"""----------------------------------------------------------------------""" + + +# 保存结果到json文件 +def save_to_json(info, filename, encoding='UTF-8'): + with open(filename, "w", encoding=encoding) as f: + json.dump(info, f, indent=2, separators=(',', ':')) + + +# 从json文件中读取 +def load_from_json(filename): + with open(filename, encoding='utf-8') as f: + info = json.load(f) + return info + + +# 储存为npy文件 +def save_to_npy(info, filename): + np.save(filename, info, allow_pickle=True) + + +# 从npy中读取 +def load_from_npy(filename): + info = np.load(filename, allow_pickle=True) + return info + + +# 保存结果到txt文件 +def log_to_txt(contexts=None, filename="save.txt", mark=False, encoding='UTF-8', add_n=False): + f = open(filename, "a", encoding=encoding) + if mark: + sig = "------------------------------------------------\n" + f.write(sig) + elif isinstance(contexts, dict): + tmp = "" + for c in contexts.keys(): + tmp += str(c) + " | " + str(contexts[c]) + "\n" + contexts = tmp + f.write(contexts) + else: + if isinstance(contexts, list): + tmp = "" + for c in contexts: + if add_n: + tmp += str(c) + "\n" + else: + tmp += str(c) + contexts = tmp + else: + contexts = contexts + "\n" + f.write(contexts) + + f.close() + + +# 从txt中读取行 +def load_from_txt(filename, encoding="utf-8"): + f = open(filename, 'r', encoding=encoding) + contexts = f.readlines() + return contexts + + +""" 字典变换 """ +"""----------------------------------------------------------------------""" + + +# 键值互换 +def dict_k_v_exchange(dict): + tmp = {} + for key, value in dict.items(): + tmp[value] = key + return tmp + + +# 2维数组转字典 +def d2array_to_dict(d2array): + # Input: N x 2 list + # Output: dict + dict = {} + for item in d2array: + if item[0] not in dict.keys(): + dict[item[0]] = [item[1]] + else: + dict[item[0]].append(item[1]) + return dict + + +""" 绘图 """ +"""----------------------------------------------------------------------""" + + +# 绘制3D图像 +def visual_3d_points(list, color=True): + """ + :param list: N x (dim +1) + N 为点的数量 + dim 为 输入数据的维度 + 1 为类别, 即可视化的颜色 当且仅当color为True时 + """ + list = np.array(list) + if color: + data = list[:, :4] + label = list[:, -1] + else: + data = list + label = None + + # PCA降维 + pca = PCA(n_components=3, whiten=True).fit(data) + data = pca.transform(data) + + # 定义坐标轴 + fig = plt.figure() + ax1 = plt.axes(projection='3d') + if label is not None: + color = label + else: + color = "blue" + ax1.scatter3D(np.transpose(data)[0], np.transpose(data)[1], np.transpose(data)[2], c=color) # 绘制散点图 + + plt.show() + + +""" 实用工具 """ +"""----------------------------------------------------------------------""" + + +# 计算数组中元素出现的个数 +def count_list(lens): + dict = {} + for key in lens: + dict[key] = dict.get(key, 0) + 1 + dict = sorted(dict.items(), key=lambda x: x[1], reverse=True) + + print_list(dict) + return dict + + +# list 加法 w1、w2为权重 +def list_add(list1, list2, w1=1, w2=1): + return [l1 * w1 + l2 * w2 for (l1, l2) in zip(list1, list2)]