Graduation_Project/WZM/mytools.py

185 lines
4.8 KiB
Python

# coding:utf-8
"""导入一些包"""
import os
import time, random
import json
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
""" 打印一些东西 """
"""----------------------------------------------------------------------"""
# Print a list vertically, one item per line
def print_list(list):
    """Print every item of ``list`` on its own line between two marker lines.

    :param list: iterable whose items are printed with ``print`` semantics
    """
    border = "++++++++++++++++++++++++++++++++++++++++++++"
    # One print call: the border, each item, then the border, newline-separated.
    print(border, *list, border, sep="\n")
# Print a dict vertically, one key/value pair per line
def print_dict(dict):
    """Print every entry of ``dict`` on its own line between two marker lines.

    Each entry is rendered the way ``print("key:", k, " value:", v)`` would
    render it.

    :param dict: mapping whose ``items()`` are printed
    """
    border = "++++++++++++++++++++++++++++++++++++++++++++"
    # Joining with a single space mirrors print's default separator.
    rows = [
        " ".join(("key:", str(key), " value:", str(value)))
        for key, value in dict.items()
    ]
    print(border, *rows, border, sep="\n")
# Print something, framed by marker lines
def print_with_log(info):
    """Print ``info`` with a marker line before and after it.

    :param info: object to print
    """
    border = "++++++++++++++++++++++++++++++++++++++++++++"
    print(border, info, border, sep="\n")
# Print the marker line
def print_log():
    """Print a single marker line made of ``+`` characters."""
    border = "++++++++++++++++++++++++++++++++++++++++++++"
    print(border)
""" 文件存储 """
"""----------------------------------------------------------------------"""
# Save a result to a JSON file
def save_to_json(info, filename, encoding='UTF-8'):
    """Serialize ``info`` as JSON and write it to ``filename``.

    The text uses a two-space indent and compact ``,`` / ``:`` separators.

    :param info: JSON-serializable object
    :param filename: path of the file to create or overwrite
    :param encoding: text encoding used when writing the file
    :raises TypeError: if ``info`` contains a value that json cannot encode
    """
    # Serialize before opening: mode "w" truncates the file, so a failure
    # halfway through encoding must not destroy what was stored there before.
    text = json.dumps(info, indent=2, separators=(',', ':'))
    with open(filename, "w", encoding=encoding) as f:
        f.write(text)
# Load from a JSON file
def load_from_json(filename, encoding='utf-8'):
    """Read ``filename`` and return the decoded JSON value.

    :param filename: path of the JSON file to read
    :param encoding: text encoding of the file; defaults to UTF-8, and can be
        set to whatever encoding the file was written with
    :return: the object produced by ``json.load``
    """
    with open(filename, encoding=encoding) as f:
        return json.load(f)
# Store as an npy file
def save_to_npy(info, filename):
    """Write ``info`` to ``filename`` with ``np.save``.

    Pickling is enabled, so object arrays can be stored as well.

    :param info: array-like data handed to ``np.save``
    :param filename: destination path
    """
    np.save(file=filename, arr=info, allow_pickle=True)
# Load from an npy file
def load_from_npy(filename):
    """Return the data stored in ``filename`` as read by ``np.load``.

    :param filename: path of the file to read
    :return: whatever ``np.load`` returns for the file
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code -- only use this on files from a trusted source.
    return np.load(filename, allow_pickle=True)
# Append a result to a txt file
def log_to_txt(contexts=None, filename="save.txt", mark=False, encoding='UTF-8', add_n=False):
    """Append ``contexts`` (or a separator line) to a text file.

    What gets written depends on the arguments:

    * ``mark`` is true: a line of dashes; ``contexts`` is ignored.
    * ``contexts`` is a dict: one ``key | value`` line per entry.
    * ``contexts`` is a list: ``str()`` of every element, each followed by a
      newline only when ``add_n`` is true.
    * anything else: ``str(contexts)`` followed by a newline.

    :param contexts: data to log
    :param filename: file to append to (created if missing)
    :param mark: write only the dashed separator line
    :param encoding: text encoding used for the file
    :param add_n: for list input, terminate every element with a newline
    :raises TypeError: if ``contexts`` is None and ``mark`` is false
    """
    # Build the complete text first so that bad input fails before the file
    # is opened or created.
    if mark:
        text = "------------------------------------------------\n"
    elif contexts is None:
        raise TypeError("contexts must not be None unless mark=True")
    elif isinstance(contexts, dict):
        text = "".join(
            str(key) + " | " + str(value) + "\n"
            for key, value in contexts.items()
        )
    elif isinstance(contexts, list):
        end = "\n" if add_n else ""
        text = "".join(str(item) + end for item in contexts)
    else:
        # str() leaves strings untouched and lets numbers etc. be logged too.
        text = str(contexts) + "\n"
    # The context manager closes the handle even if the write fails.
    with open(filename, "a", encoding=encoding) as f:
        f.write(text)
# Read the lines of a txt file
def load_from_txt(filename, encoding="utf-8"):
    """Return all lines of ``filename`` as a list of strings.

    Line endings are kept, exactly as ``readlines`` returns them.

    :param filename: path of the text file to read
    :param encoding: text encoding of the file
    :return: list with one string per line
    """
    # The context manager guarantees the handle is closed after reading.
    with open(filename, 'r', encoding=encoding) as f:
        return f.readlines()
""" 字典变换 """
"""----------------------------------------------------------------------"""
# Swap keys and values
def dict_k_v_exchange(dict):
    """Return a new dict that maps every value of ``dict`` back to its key.

    When several keys share one value, the key visited last wins.

    :param dict: mapping with hashable values
    :return: dict of ``value -> key``
    """
    return {value: key for key, value in dict.items()}
# Convert a 2-D array to a dict
def d2array_to_dict(d2array):
    """Group the second element of every row under the row's first element.

    :param d2array: N x 2 list of ``[key, value]`` rows
    :return: dict mapping each key to the list of its values, in row order
    """
    grouped = {}
    for row in d2array:
        # setdefault creates the list the first time a key is seen.
        grouped.setdefault(row[0], []).append(row[1])
    return grouped
""" 绘图 """
"""----------------------------------------------------------------------"""
# Draw a 3D scatter plot
def visual_3d_points(list, color=True):
    """Reduce points to three dimensions with PCA and show a 3D scatter plot.

    :param list: N x (dim + 1) when ``color`` is true, otherwise N x dim
        N is the number of points
        dim is the dimensionality of the input data
        the extra last column is the class of each point, used as its color
    :param color: whether the last column of ``list`` holds the class
    """
    points = np.array(list)
    if color:
        # The class is the last column and everything before it is feature
        # data, so slice relative to the end: this works for any dim, not
        # only for dim == 4.
        data = points[:, :-1]
        label = points[:, -1]
    else:
        data = points
        label = None
    # Reduce the dimensionality with PCA
    pca = PCA(n_components=3, whiten=True).fit(data)
    data = pca.transform(data)
    # Set up the 3D axes on a new figure
    plt.figure()
    ax1 = plt.axes(projection='3d')
    # Color by class when one was given, otherwise use a single color.
    point_color = label if label is not None else "blue"
    ax1.scatter3D(data[:, 0], data[:, 1], data[:, 2], c=point_color)  # scatter plot
    plt.show()
""" 实用工具 """
"""----------------------------------------------------------------------"""
# Count how often each element occurs
def count_list(lens):
    """Count the occurrences of every element, print and return the ranking.

    :param lens: iterable of hashable elements
    :return: list of ``(element, count)`` tuples sorted by count, highest
        first; the same list is printed with ``print_list``
    """
    counts = {}
    for item in lens:
        if item in counts:
            counts[item] += 1
        else:
            counts[item] = 1
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    print_list(ranked)
    return ranked
# Weighted element-wise list addition; w1 and w2 are the weights
def list_add(list1, list2, w1=1, w2=1):
    """Return the element-wise weighted sum ``list1 * w1 + list2 * w2``.

    Elements are paired with ``zip``, so the result is as long as the shorter
    input.

    :param list1: first sequence
    :param list2: second sequence
    :param w1: weight applied to the elements of ``list1``
    :param w2: weight applied to the elements of ``list2``
    :return: list of weighted sums
    """
    total = []
    for left, right in zip(list1, list2):
        total.append(left * w1 + right * w2)
    return total