diff --git a/mnist_model_2.pth b/mnist_model_2.pth new file mode 100644 index 0000000..b211b9b Binary files /dev/null and b/mnist_model_2.pth differ diff --git a/predict.py b/predict.py index 7a2d999..b53017b 100644 --- a/predict.py +++ b/predict.py @@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template import numpy as np import io import base64 -from PIL import Image +from PIL import Image, ImageOps +import pandas as pd +import matplotlib.pyplot as plt import torch from torch import nn from torchvision import transforms @@ -10,22 +12,22 @@ from torchvision import transforms app = Flask(__name__) # Load the trained model -class SimpleNN(nn.Module): +class MNISTNN(nn.Module): def __init__(self): - super(SimpleNN, self).__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - + super().__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) def forward(self, x): - x = x.view(x.shape[0], -1) - x = torch.relu(self.fc1(x)) - x = torch.relu(self.fc2(x)) - x = self.fc3(x) - return x + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits -model = SimpleNN() -model.load_state_dict(torch.load('mnist_model.pth')) +model = MNISTNN() +model.load_state_dict(torch.load('mnist_model_2.pth')) model.eval() @app.route('/') @@ -39,13 +41,15 @@ def predict(): # Decode the image image = Image.open(io.BytesIO(base64.b64decode(image_data))) - image = image.convert('L') # Convert to grayscale + _, _, _, image = image.split() image = image.resize((28, 28)) # Resize to 28x28 - image = transforms.ToTensor()(image) - image = image.unsqueeze(0) # Add batch dimension + #transform = transforms.Compose([transforms.PILToTensor()]) + #image_tensor = transform(image) + image_tensor = transforms.ToTensor()(image) + image_tensor = image_tensor.unsqueeze(0) # Add batch dimension # Predict using the model with torch.no_grad(): - output = model(image) + output = model(image_tensor) _, predicted = torch.max(output, 1) return jsonify(prediction=predicted.item()) diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000..1aa5316 --- /dev/null +++ b/templates/index.html @@ -0,0 +1,64 @@ + + + + + + Handwritten Digit Input + + + +

Draw a Handwritten Digit

+ +
+ + + + + + diff --git a/train.py b/train.py index 64e1d52..037f166 100644 --- a/train.py +++ b/train.py @@ -4,13 +4,18 @@ import torchvision from torchvision import datasets, transforms from torch import nn, optim +lr = 0.8 +bs = 64 +epochs = 50 + + # Set device device = 'cuda' if torch.cuda.is_available() else 'cpu' # Define transformations transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) + #transforms.Normalize((0.5,), (0.5,)) ]) # Load MNIST dataset @@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) # Define the neural network model -class SimpleNN(nn.Module): +class MNISTNN(nn.Module): def __init__(self): - super(SimpleNN, self).__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - + super().__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) def forward(self, x): - x = x.view(x.shape[0], -1) - x = torch.relu(self.fc1(x)) - x = torch.relu(self.fc2(x)) - x = self.fc3(x) - return x + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits # Instantiate the model -model = SimpleNN().to(device) +model = MNISTNN().to(device) # Define loss function and optimizer -criterion = nn.CrossEntropyLoss() -optimizer = optim.SGD(model.parameters(), lr=0.01) +loss_fn = nn.CrossEntropyLoss() +optimizer = optim.SGD(model.parameters(), lr=lr) # Check if the model already exists -model_path = 'mnist_model.pth' +model_path = 'mnist_model_2.pth' start_epoch = 0 if os.path.isfile(model_path): model.load_state_dict(torch.load(model_path)) @@ -49,14 +54,13 @@ if os.path.isfile(model_path): # start_epoch = # Train the model -epochs = 20 for epoch in range(start_epoch, start_epoch + epochs): running_loss = 0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) + preds = model(images) + loss = loss_fn(preds, labels) loss.backward() optimizer.step() running_loss += loss.item()