metodi/app/labs/lab7.py

144 lines
7.0 KiB
Python
Raw Permalink Normal View History

import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import matplotlib
class lab7:
def __init__(self):
pass
def main(self):
matplotlib.use('TkAgg')
# Загружаем данные ирисов
data = load_iris()
X = data.data # Признаки ирисов
Y = data.target # Сорта ирисов
label = Y # Используем Y в качестве меток классов для классификации
Y_str = data.target_names # Названия сортов ирисов
# setosa
setosa_inds = Y == np.where(Y_str == "setosa")[0][0]
setosa_data = X[setosa_inds, :]
# versicolor
versicolor_inds = Y == np.where(Y_str == "versicolor")[0][0]
versicolor_data = X[versicolor_inds, :]
# virginica
virginica_inds = Y == np.where(Y_str == "virginica")[0][0]
virginica_data = X[virginica_inds, :]
# Определяем функцию классификации
def classification(data, label, k):
X_train, X_test, y_train, y_test = train_test_split(data, label, random_state=10) #Разделение данных
knn = KNeighborsClassifier(n_neighbors=k, metric='euclidean') # Создание KNN-классификатора
knn.fit(X_train, y_train)
predicted_data = knn.predict(X_test)
report = sklearn.metrics.classification_report(y_test, predicted_data)
# Пример вызова функции классификации
setosa_data = sklearn.preprocessing.normalize(setosa_data)
versicolor_data = sklearn.preprocessing.normalize(versicolor_data)
virginica_data = sklearn.preprocessing.normalize(virginica_data)
data_all = np.vstack((setosa_data,versicolor_data,virginica_data))
fig, axs = plt.subplots(2, 3, figsize=(12,8),dpi=160)
axs[0, 0].scatter(data_all[:50, 0], data_all[:50, 1], marker='*')
axs[0, 0].scatter(data_all[50:100, 0], data_all[50:100, 1], marker='^')
axs[0, 0].scatter(data_all[100:, 0], data_all[100:, 1], marker='+')
axs[0, 0].title.set_text('features 1 and 2')
axs[0, 1].scatter(data_all[:50, 0], data_all[:50, 2], marker='*')
axs[0, 1].scatter(data_all[50:100, 0], data_all[50:100, 2], marker='^')
axs[0, 1].scatter(data_all[100:, 0], data_all[100:, 2], marker='+')
axs[0, 1].title.set_text('features 1 and 3')
axs[0, 2].scatter(data_all[:50, 0], data_all[:50, 3], marker='*')
axs[0, 2].scatter(data_all[50:100, 0], data_all[50:100, 3], marker='^')
axs[0, 2].scatter(data_all[100:, 0], data_all[100:, 3], marker='+')
axs[0, 2].title.set_text('features 1 and 4')
axs[1, 0].scatter(data_all[:50, 1], data_all[:50, 2], marker='*',)
axs[1, 0].scatter(data_all[50:100, 1], data_all[50:100, 2], marker='^')
axs[1, 0].scatter(data_all[100:, 1], data_all[100:, 2], marker='+')
axs[1, 0].title.set_text('features 2 and 3')
axs[1, 1].scatter(data_all[:50, 1], data_all[:50, 3], marker='*')
axs[1, 1].scatter(data_all[50:100, 1], data_all[50:100, 3], marker='^')
axs[1, 1].scatter(data_all[100:, 1], data_all[100:, 3], marker='+')
axs[1, 1].title.set_text('features 2 and 4')
axs[1, 2].scatter(data_all[:50, 2], data_all[:50, 3], marker='*')
axs[1, 2].scatter(data_all[50:100, 2], data_all[50:100, 3], marker='^')
axs[1, 2].scatter(data_all[100:, 2], data_all[100:, 3], marker='+')
axs[1, 2].title.set_text('features 3 and 4')
plt.show()
data = np.hstack((data_all[:,0].reshape(150,1),data_all[:,1].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
print('///////////////////////////////////////////////////////////////////')
data = np.hstack((data_all[:,0].reshape(150,1),data_all[:,2].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
print('///////////////////////////////////////////////////////////////////')
data = np.hstack((data_all[:,0].reshape(150,1),data_all[:,3].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
print('///////////////////////////////////////////////////////////////////')
data = np.hstack((data_all[:,1].reshape(150,1),data_all[:,2].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
print('///////////////////////////////////////////////////////////////////')
data = np.hstack((data_all[:,1].reshape(150,1),data_all[:,3].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
print('///////////////////////////////////////////////////////////////////')
data = np.hstack((data_all[:,2].reshape(150,1),data_all[:,3].reshape(150,1)))
print('Количество соседей = 2:')
classification(data,label,2)
print('Количество соседей = 5:')
classification(data,label,5)
print('Количество соседей = 20:')
classification(data,label,20)
print('Количество соседей = 70:')
classification(data,label,70)
if __name__ == "__main__":
l7 = lab7()
l7.main()