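# MNIST training script: trains MNISTNN with SGD, loading weights from
# mnist_model_2.pth when that file exists and saving the weights back to it.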
import os
|
|
import torch
|
|
import torchvision
|
|
from torchvision import datasets, transforms
|
|
from torch import nn, optim
|
|
|
|
# Training hyperparameters: SGD learning rate, mini-batch size, and number of
# passes over the training set.
# NOTE(review): 0.8 is a high learning rate for plain SGD; confirm it is intended.
lr, bs, epochs = 0.8, 64, 50

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Per-image preprocessing. ToTensor turns each image into a float tensor (per
# the torchvision docs, pixel values are scaled to [0, 1]); no Normalize step
# is applied.
transform = transforms.Compose(
    [transforms.ToTensor()]
)

# Load the MNIST training split from ./data (download=True fetches it there if
# it is not already present) and apply `transform` to every sample.
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)

# Batch with the `bs` hyperparameter instead of a second hard-coded 64, so that
# changing `bs` actually changes the batch size. Shuffle every epoch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Define the neural network model
class MNISTNN(nn.Module):
    """Two-layer fully connected classifier: Flatten -> Linear -> ReLU -> Linear.

    The default sizes match MNIST: 28x28 single-channel images, a 512-unit
    hidden layer and 10 output classes. The submodule names (``flatten``,
    ``linear_relu_stack``) and the layer order are kept fixed, so previously
    saved ``state_dict`` files (keys such as ``linear_relu_stack.0.weight``)
    continue to load.

    Args:
        in_features: Number of values per sample after flattening.
        hidden_features: Width of the hidden ReLU layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # nn.Flatten keeps dimension 0 (the batch) and merges all later dims.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(batch, num_classes)``.

        ``x`` is flattened from dimension 1 onward, so each sample must hold
        exactly ``in_features`` values. No softmax is applied; the scores are
        meant to be fed to a loss such as ``nn.CrossEntropyLoss``.
        """
        return self.linear_relu_stack(self.flatten(x))


# Objective for multi-class classification on raw logits.
loss_fn = nn.CrossEntropyLoss()

# Build the network, move its parameters to the selected device, then attach
# a plain SGD optimizer using the `lr` hyperparameter.
model = MNISTNN()
model = model.to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr)

# Reuse weights from a previous run if a saved state_dict exists.
model_path = 'mnist_model_2.pth'
start_epoch = 0

if os.path.isfile(model_path):
    # map_location remaps stored tensors onto the current device, so a file
    # written on a CUDA machine still loads on a CPU-only one instead of
    # failing during deserialization.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded existing model.")
    # Only the weights are saved, so start_epoch stays 0 and epoch numbering
    # restarts. TODO: save a dict like {'epoch': ..., 'model': ...} and
    # restore start_epoch from it to continue the numbering.

# Run `epochs` passes over the training data, numbering them from start_epoch.
for epoch in range(start_epoch, start_epoch + epochs):
    running_loss = 0

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Drop gradients accumulated by the previous step before computing new ones.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Report the mean per-batch loss for the epoch just finished.
    print(f"Epoch {epoch + 1}/{start_epoch + epochs}, Loss: {running_loss/len(train_loader)}")

# Write only the learned parameters (the state_dict) to model_path, so the next
# run can pick them up via the loading step above the training loop.
torch.save(obj=model.state_dict(), f=model_path)
print("Model saved!")