博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Iris Classification on PyTorch
阅读量:4350 次
发布时间:2019-06-07

本文共 1593 字,大约阅读时间需要 5 分钟。

Iris Classification on PyTorch

code

# -*- coding:utf8 -*-from sklearn.datasets import load_irisfrom sklearn.utils import shufflefrom sklearn.model_selection import train_test_splitimport torchimport torch.optim as optimimport torch.nn as nnimport torch.nn.functional as Fhl = 10lr = 0.005num_epoch = 50000class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.fc1 = nn.Linear(4, hl)        self.fc2 = nn.Linear(hl, 3)        self.softmax = torch.nn.Softmax(dim=1)    def forward(self, x):        out = self.fc1(x)        out = F.relu(out)        out = self.fc2(out)        out = self.softmax(out)        return outif __name__ == '__main__':    iris = load_iris()    x, y = shuffle(iris.data,iris.target)    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=33)    net = Net()    criterion = nn.CrossEntropyLoss()    optimizer = optim.Adam(net.parameters(), lr=lr)    # train    for epoch in range(num_epoch):        x = torch.Tensor(x_train).float()        y = torch.Tensor(y_train).long()        optimizer.zero_grad()        y_pred = net(x)        loss = criterion(y_pred, y)        loss.backward()        optimizer.step()        if epoch % 50 is 0:            print(loss) # cross entropy    # test    x = torch.Tensor(x_test).float()    y = torch.Tensor(y_test).long()    y_pred = net(x)    _, predicted = torch.max(y_pred, 1)    acc = torch.sum(y == predicted).numpy() / len(x_test)    print(acc)

Result

实验了 4 次,准确率分别为 97.37%、92.11% 、97.37% 和 94.74%,平均准确率为 95.40%。

转载于:https://www.cnblogs.com/fengyubo/p/9141130.html

你可能感兴趣的文章
echart.js的使用
查看>>
自己动手写一个单链表
查看>>
常用正则表达式
查看>>
PHP 重置数组为连续数字索引的几种方式
查看>>
160809308周子济第六次作业
查看>>
大型Web应用运行时 PHP负载均衡指南
查看>>
为phpStorm 配置PHP_CodeSniffer自动检查代码
查看>>
测试工具网址大全(转)
查看>>
ServiceStack DotNet Core前期准备
查看>>
webpack中‘vant’全局引入和按需引入【vue-cli】
查看>>
Date、String和Timestamp类型转换
查看>>
计算机的组成
查看>>
关于render函数的总结
查看>>
JavaScript 小刮号
查看>>
Android为TV端助力 Linux命令查看包名类名
查看>>
[简单到爆]eclipse-jee-neon的下载和安装
查看>>
vector
查看>>
Redis学习之set类型总结
查看>>
栈和队列
查看>>
CSS2-3常见的demo列子总结一
查看>>