对于分类问题, 我们通常会对交叉熵的损失函数进行优化, 统计学上还有一个极大似然的概念, 实际发生的事件的概率要最大.
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
data = load_iris()
X = data.data
y = data.target
xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2)
xtrain = torch.Tensor(xtrain)
xtest = torch.Tensor(xtest)
ytrain = torch.Tensor(ytrain).long()
ytest = torch.Tensor(ytest).long()
ytrain = nn.functional.one_hot(ytrain)
ytest = nn.functional.one_hot(ytest)
# 搭建网络
class Network(nn.Module):
def __init__(self, input_size, class_num):
super(Network, self).__init__()
self.fc1 = nn.Linear(input_size, 256)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(0.1)
self.fc2 = nn.Linear(256, 128)
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(0.1)
self.fc3 = nn.Linear(128, class_num)
self.out = nn.Softmax(1)
def forward(self, data):
out = self.fc1(data)
out = self.relu1(out)
out = self.dropout1(out)
out = self.fc2(out)
out = self.relu2(out)
out = self.dropout2(out)
out = self.fc3(out)
out = self.out(out)
return out
model = Network(xtrain.shape[1], ytrain.shape[1])
def negative_log_likelihood(ypre, ytrue):
"""
负对数似然损失
:params ypre: 预测的概率分布
:params ytrue: 实际的概率分布(一般标签是确定的, 所以传入的是独热码形式)
"""
pre_p = ypre*ytrue
pre_p = pre_p[pre_p!=0]
return -torch.sum(torch.log(pre_p))/len(ypre)
# 初始化优化器
sgd = optim.SGD(model.parameters(), lr=0.0001)
epochs = 500
train_loss = []
test_loss = []
for epoch in range(epochs):
ypre = model(xtrain)
loss = negative_log_likelihood(ypre=ypre, ytrue=ytrain)
loss.backward()
sgd.step()
if epoch%10 == 0:
test_ypre = model(xtest)
te_loss = negative_log_likelihood(test_ypre, ytest)
print(f'Epoch[{epoch}]/Epoch[{epochs}] | train_loss:{round(loss.item(), 4)} | test_loss:{round(te_loss.item(), 4)}')
train_loss.append(round(loss.item(), 4))
test_loss.append(round(te_loss.item(), 4))
Epoch[0]/Epoch[500] | train_loss:1.2091 | test_loss:1.2715 Epoch[10]/Epoch[500] | train_loss:1.1581 | test_loss:1.1928 Epoch[20]/Epoch[500] | train_loss:1.0941 | test_loss:1.1049 Epoch[30]/Epoch[500] | train_loss:1.0943 | test_loss:1.1198 Epoch[40]/Epoch[500] | train_loss:1.057 | test_loss:1.1139 Epoch[50]/Epoch[500] | train_loss:0.9887 | test_loss:0.9825 Epoch[60]/Epoch[500] | train_loss:0.906 | test_loss:0.9176 Epoch[70]/Epoch[500] | train_loss:0.8696 | test_loss:0.8687 Epoch[80]/Epoch[500] | train_loss:0.8538 | test_loss:0.8458 Epoch[90]/Epoch[500] | train_loss:0.8035 | test_loss:0.7657 Epoch[100]/Epoch[500] | train_loss:0.7158 | test_loss:0.6432 Epoch[110]/Epoch[500] | train_loss:0.6182 | test_loss:0.5816 Epoch[120]/Epoch[500] | train_loss:0.6271 | test_loss:0.5277 Epoch[130]/Epoch[500] | train_loss:0.5899 | test_loss:0.4906 Epoch[140]/Epoch[500] | train_loss:0.5438 | test_loss:0.4902 Epoch[150]/Epoch[500] | train_loss:0.4554 | test_loss:0.422 Epoch[160]/Epoch[500] | train_loss:0.4543 | test_loss:0.4194 Epoch[170]/Epoch[500] | train_loss:0.5196 | test_loss:0.4473 Epoch[180]/Epoch[500] | train_loss:0.4555 | test_loss:0.385 Epoch[190]/Epoch[500] | train_loss:0.376 | test_loss:0.3174 Epoch[200]/Epoch[500] | train_loss:0.408 | test_loss:0.374 Epoch[210]/Epoch[500] | train_loss:0.446 | test_loss:0.4177 Epoch[220]/Epoch[500] | train_loss:0.3581 | test_loss:0.3264 Epoch[230]/Epoch[500] | train_loss:0.337 | test_loss:0.3147 Epoch[240]/Epoch[500] | train_loss:0.4227 | test_loss:0.3006 Epoch[250]/Epoch[500] | train_loss:0.307 | test_loss:0.2373 Epoch[260]/Epoch[500] | train_loss:0.2739 | test_loss:0.2042 Epoch[270]/Epoch[500] | train_loss:0.3512 | test_loss:0.3569 Epoch[280]/Epoch[500] | train_loss:0.2658 | test_loss:0.3619 Epoch[290]/Epoch[500] | train_loss:0.2523 | test_loss:0.1989 Epoch[300]/Epoch[500] | train_loss:0.3566 | test_loss:0.2156 Epoch[310]/Epoch[500] | train_loss:0.2587 | test_loss:0.3668 Epoch[320]/Epoch[500] | train_loss:0.2604 | test_loss:0.2214 Epoch[330]/Epoch[500] | train_loss:0.3063 | test_loss:0.3488 Epoch[340]/Epoch[500] | train_loss:0.1876 | test_loss:0.2121 Epoch[350]/Epoch[500] | train_loss:0.2998 | test_loss:0.1762 Epoch[360]/Epoch[500] | train_loss:0.2308 | test_loss:0.2481 Epoch[370]/Epoch[500] | train_loss:0.1659 | test_loss:0.1174 Epoch[380]/Epoch[500] | train_loss:0.1789 | test_loss:0.2015 Epoch[390]/Epoch[500] | train_loss:0.2157 | test_loss:0.3624 Epoch[400]/Epoch[500] | train_loss:0.2419 | test_loss:0.1678 Epoch[410]/Epoch[500] | train_loss:0.1222 | test_loss:0.3144 Epoch[420]/Epoch[500] | train_loss:0.1811 | test_loss:0.2228 Epoch[430]/Epoch[500] | train_loss:0.2376 | test_loss:0.2872 Epoch[440]/Epoch[500] | train_loss:0.1405 | test_loss:0.1619 Epoch[450]/Epoch[500] | train_loss:0.1995 | test_loss:0.1871 Epoch[460]/Epoch[500] | train_loss:0.2417 | test_loss:0.1267 Epoch[470]/Epoch[500] | train_loss:0.1451 | test_loss:0.0633 Epoch[480]/Epoch[500] | train_loss:0.1377 | test_loss:0.1711 Epoch[490]/Epoch[500] | train_loss:0.2641 | test_loss:0.1647
fig = plt.figure()
plt.plot(range(0, epochs, 10), train_loss, label='train_loss', color='red')
plt.plot(range(0, epochs, 10), test_loss, label='test_loss', color='blue')
plt.legend()
plt.show()
def acc(ypre, ytrue):
"""
评估模型分类的正确性
:params ypre: 预测标签(概率分布型数据)
:params ytrue: 实际标签(独热码数据)
"""
pre_label = torch.argmax(ypre, axis=1)
true_label = torch.argmax(ytrue, axis=1)
accur = torch.sum(pre_label==true_label)/len(true_label)
return accur
# 训练集准确率
acc(model(xtrain), ytrain)
# 测试集准确率
acc(model(xtest), ytest)