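PyTorch Learning Notes: Logistic Regression
1. Import the required libraries & set the hyperparameters
The usual routine.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

# Hyperparameters
input_size = 784      # 28 x 28 pixels, flattened
num_classes = 10      # digits 0-9
num_epochs = 5
batch_size = 100
learning_rate = 0.01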
2. Get and load the data
# MNIST training set (downloaded to ./data if not already present)
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

# MNIST test set
test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# The test loader must wrap test_dataset, not train_dataset
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
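3. Build the logistic regression model
One question comes up here: why is there no softmax after the Linear layer?
The answer lies in the loss function. The loss used here is CrossEntropyLoss(), the cross-entropy loss for multi-class classification. It applies log-softmax to the raw outputs internally, so no Softmax layer is needed in front of it.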
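To make this concrete, the minimal sketch below checks that nn.CrossEntropyLoss applied to raw outputs gives the same value as an explicit log-softmax followed by the negative log-likelihood loss. The random logits and targets are made up for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 10 classes (raw, unnormalized scores).
logits = torch.randn(4, 10)
targets = torch.tensor([3, 0, 9, 1])  # class indices, dtype int64

# CrossEntropyLoss takes raw logits and applies log-softmax internally.
ce = nn.CrossEntropyLoss()(logits, targets)

# The same computation spelled out in two steps.
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(ce.item(), manual.item())
```

The two printed values should agree up to floating-point error, which is why adding a Softmax layer before this loss would apply the normalization twice.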
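I also wrote my own Model class, but using it with MSELoss (or similar losses) raises an error. MSELoss expects a float target with the same shape as the output, (batch_size, num_classes), whereas the MNIST labels are integer class indices of shape (batch_size,).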
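class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y_pred = self.sigmoid(self.linear(x))
        return y_pred

# Custom model with MSELoss: fails in the training loop, because the labels are
# class indices of shape (batch_size,), not float targets of shape
# (batch_size, num_classes).
# model = Model()
# criterion = nn.MSELoss()

# Working setup: a single linear layer whose raw outputs go to CrossEntropyLoss.
model = nn.Linear(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)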
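The custom Model can be paired with MSELoss if the integer labels are first converted to one-hot float targets. The sketch below shows that conversion with torch.nn.functional.one_hot. The random outputs tensor is only a stand-in for the model's sigmoid outputs, and this is one possible workaround, not something the original notes tried.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 10

# Stand-in for the sigmoid outputs of the custom Model: values in (0, 1).
outputs = torch.sigmoid(torch.randn(5, num_classes))
labels = torch.tensor([7, 2, 1, 0, 4])  # class indices, shape (5,)

# One-hot encode and cast to float so the target matches
# the output in both shape and dtype.
targets = F.one_hot(labels, num_classes=num_classes).float()

loss = nn.MSELoss()(outputs, targets)
print(targets.shape, loss.item())
```

In the training loop this would mean building targets from labels on every batch before calling criterion(outputs, targets).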
4. Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-dimensional vector
        images = images.reshape(-1, 28 * 28)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
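The reshape call in the loop is what turns an image batch into rows that a Linear layer with 784 inputs can accept. The short sketch below uses an all-zero tensor as a hypothetical stand-in for one MNIST batch to show the shapes involved.

```python
import torch

# Fake batch with the MNIST layout: (batch, channels, height, width).
images = torch.zeros(100, 1, 28, 28)

# -1 lets PyTorch infer the batch dimension; each image becomes a 784-vector.
flat = images.reshape(-1, 28 * 28)
print(flat.shape)  # torch.Size([100, 784])
```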
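5. Test the model and save it
# Gradients are not needed for evaluation
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        # The index of the largest output is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        # .item() converts the count to a Python int so the accuracy prints as a number
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save only the learned parameters
torch.save(model.state_dict(), 'LogisticModel.ckpt')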
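Because only the state_dict is saved, loading the checkpoint later requires rebuilding the same architecture first. The sketch below illustrates the round trip, assuming the plain nn.Linear model. It writes to a temporary directory rather than the working directory, and the random input x is made up for the check.

```python
import os
import tempfile

import torch
import torch.nn as nn

input_size, num_classes = 784, 10  # same values as the hyperparameters above

model = nn.Linear(input_size, num_classes)

# Save only the parameters (state_dict); the path here is a temp file.
path = os.path.join(tempfile.mkdtemp(), 'LogisticModel.ckpt')
torch.save(model.state_dict(), path)

# Rebuild the same architecture, then load the saved parameters into it.
restored = nn.Linear(input_size, num_classes)
restored.load_state_dict(torch.load(path))
restored.eval()

# Both models should now produce identical outputs for the same input.
x = torch.randn(2, input_size)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```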