#!/usr/bin/env python
# coding: utf-8
# In[1]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# In[2]:
# Load the MNIST training set from CSV and drop any rows with missing values.
# NOTE(review): read_csv treats the first line as a header row by default;
# confirm mnist_train.csv really has one, otherwise the first sample is lost.
train_data = pd.read_csv('mnist_train.csv')
train_data.dropna(inplace=True)
mid_data = np.array(train_data)
# Column 0 is used as the label and the remaining columns as the features.
# A single slice replaces the per-row Python loop; list() keeps train_feature
# a list of 1-D row arrays so existing list-style uses continue to work.
train_feature = list(mid_data[:, 1:])
train_label = mid_data[:, 0]
print(len(train_feature), len(train_label))
print(train_feature[0].shape)
# In[3]:
# Load the MNIST test set from CSV and drop any rows with missing values.
# NOTE(review): read_csv treats the first line as a header row by default;
# confirm mnist_test.csv really has one, otherwise the first sample is lost.
test_data = pd.read_csv('mnist_test.csv')
test_data.dropna(inplace=True)
mid_data = np.array(test_data)
# Column 0 is used as the label and the remaining columns as the features.
# A single slice replaces the per-row Python loop; list() keeps test_feature
# a list of 1-D row arrays so existing list-style uses continue to work.
test_feature = list(mid_data[:, 1:])
test_label = mid_data[:, 0]
print(len(test_feature), len(test_label))
print(test_feature[0].shape)
# In[5]:
# Preview one test sample (index 8) by reshaping its flat feature vector
# into a 28x28 image and rendering it with the grayscale colormap.
digit_image = np.reshape(test_feature[8], (28, 28))
plt.imshow(digit_image, cmap=plt.cm.gray)
plt.show()
# In[8]:
from sklearn.neighbors import KNeighborsClassifier

# Create a k-nearest-neighbours classifier that votes among 45 neighbours.
knn = KNeighborsClassifier(n_neighbors=45)
# Train it: fit() receives the training samples and their labels.
knn.fit(train_feature, train_label)
# Only the first 500 test samples are classified.
evaluation_subset = test_feature[:500]
y_predict = knn.predict(evaluation_subset)
# In[9]:
# Evaluation: count how many predictions match the ground-truth labels.
# y_predict may cover only a prefix of the test set, so it is compared with
# the equally long prefix of test_label. The element-wise comparison replaces
# the manual counting loop and its no-op else branch.
fenzi = int(np.count_nonzero(
    np.asarray(y_predict) == np.asarray(test_label[:len(y_predict)])
))
accuracy = float(fenzi) / len(y_predict)
print('测试正确率为:', accuracy)  # accuracy as a fraction: 1 means 100%
print('错误率:', 1 - accuracy)  # error rate = 1 - accuracy
# knn算法实现mnist分类 (KNN algorithm for MNIST classification)
# 2021/12/19 12:05:13