From 2bb2afa0e626c81447289d06e7e577b69e2d39bc Mon Sep 17 00:00:00 2001
From: Rubbit
Date: Thu, 19 Sep 2024 12:57:16 +0800
Subject: [PATCH] first commit

---
 predict.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 train.py   | 67 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 152 insertions(+)
 create mode 100644 predict.py
 create mode 100644 train.py

diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..7a2d999
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,85 @@
+"""Flask app that serves digit predictions from the trained MNIST model."""
+from flask import Flask, request, jsonify, render_template
+import numpy as np
+import io
+import base64
+from PIL import Image
+import torch
+from torch import nn
+from torchvision import transforms
+
+app = Flask(__name__)
+
+
+class SimpleNN(nn.Module):
+    """Fully connected classifier; must match the SimpleNN defined in train.py."""
+
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(28 * 28, 128)
+        self.fc2 = nn.Linear(128, 64)
+        self.fc3 = nn.Linear(64, 10)
+
+    def forward(self, x):
+        """Flatten each sample, apply two ReLU layers, return raw logits."""
+        x = x.view(x.shape[0], -1)
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+# Load the trained weights. map_location lets a checkpoint written by a
+# CUDA training run load on a CPU-only host.
+model = SimpleNN()
+model.load_state_dict(torch.load('mnist_model.pth', map_location='cpu'))
+model.eval()
+
+# Same preprocessing as train.py; without Normalize the model would see
+# inputs in [0, 1] instead of the [-1, 1] range it was trained on.
+preprocess = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5,), (0.5,))
+])
+
+
+@app.route('/')
+def index():
+    """Render the index.html template."""
+    return render_template('index.html')
+
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    """Classify the image posted as JSON {"image": "data:...;base64,..."}.
+
+    Returns JSON {"prediction": class_index} on success, or
+    {"error": message} with status 400 when the payload is unusable.
+    """
+    data = request.get_json(silent=True)
+    image_field = data.get('image') if isinstance(data, dict) else None
+    if not isinstance(image_field, str) or ',' not in image_field:
+        return jsonify(error="expected JSON with an 'image' data URL"), 400
+
+    # Keep only the base64 payload that follows the "data:...;base64," prefix
+    image_data = image_field.split(",")[1]
+    try:
+        image = Image.open(io.BytesIO(base64.b64decode(image_data)))
+        image = image.convert('L')  # Convert to grayscale
+    except (ValueError, OSError):
+        return jsonify(error="image could not be decoded"), 400
+
+    # NOTE(review): MNIST digits are light strokes on a dark background;
+    # confirm the client sends images with the same polarity.
+    image = image.resize((28, 28))
+    batch = preprocess(image).unsqueeze(0)  # Add batch dimension
+
+    with torch.no_grad():
+        output = model(batch)
+    predicted = output.argmax(dim=1)
+
+    return jsonify(prediction=predicted.item())
+
+
+if __name__ == '__main__':
+    # NOTE(review): debug=True enables the Werkzeug debugger; development only.
+    app.run(debug=True)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..64e1d52
--- /dev/null
+++ b/train.py
@@ -0,0 +1,67 @@
+"""Train (or resume training) the SimpleNN MNIST classifier; saves mnist_model.pth."""
+import os
+import torch
+import torchvision
+from torchvision import datasets, transforms
+from torch import nn, optim
+
+# Run on the GPU when one is available, otherwise on the CPU
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1]
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5,), (0.5,))
+])
+
+# MNIST training split, downloaded into ./data when not already present
+train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
+
+
+class SimpleNN(nn.Module):
+    """Fully connected classifier: 784 -> 128 -> 64 -> 10 logits."""
+
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(28 * 28, 128)
+        self.fc2 = nn.Linear(128, 64)
+        self.fc3 = nn.Linear(64, 10)
+
+    def forward(self, x):
+        """Flatten each sample, apply two ReLU layers, return raw logits."""
+        x = x.view(x.shape[0], -1)
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+model = SimpleNN().to(device)
+
+# CrossEntropyLoss takes raw logits; plain SGD updates every model parameter
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+# Resume from a previous checkpoint when one exists
+model_path = 'mnist_model.pth'
+start_epoch = 0  # the checkpoint stores only weights, so numbering restarts at 0
+if os.path.isfile(model_path):
+    # map_location lets a checkpoint saved from a CUDA model load on a CPU-only host
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    print("Loaded existing model.")
+
+epochs = 20
+model.train()
+for epoch in range(start_epoch, start_epoch + epochs):
+    running_loss = 0.0
+    for images, labels in train_loader:
+        images, labels = images.to(device), labels.to(device)
+        optimizer.zero_grad()
+        loss = criterion(model(images), labels)
+        loss.backward()
+        optimizer.step()
+        running_loss += loss.item()
+    print(f"Epoch {epoch + 1}/{start_epoch + epochs}, Loss: {running_loss/len(train_loader)}")
+
+torch.save(model.state_dict(), model_path)
+print("Model saved!")