diff --git a/mnist_model_2.pth b/mnist_model_2.pth
new file mode 100644
index 0000000..b211b9b
Binary files /dev/null and b/mnist_model_2.pth differ
diff --git a/predict.py b/predict.py
index 7a2d999..b53017b 100644
--- a/predict.py
+++ b/predict.py
@@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template
import numpy as np
import io
import base64
-from PIL import Image
+from PIL import Image, ImageOps
+import pandas as pd
+import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms
@@ -10,22 +12,22 @@ from torchvision import transforms
app = Flask(__name__)
# Load the trained model
-class SimpleNN(nn.Module):
+class MNISTNN(nn.Module):
def __init__(self):
- super(SimpleNN, self).__init__()
- self.fc1 = nn.Linear(28 * 28, 128)
- self.fc2 = nn.Linear(128, 64)
- self.fc3 = nn.Linear(64, 10)
-
+ super().__init__()
+ self.flatten = nn.Flatten()
+ self.linear_relu_stack = nn.Sequential(
+ nn.Linear(28 * 28, 512),
+ nn.ReLU(),
+ nn.Linear(512, 10)
+ )
def forward(self, x):
- x = x.view(x.shape[0], -1)
- x = torch.relu(self.fc1(x))
- x = torch.relu(self.fc2(x))
- x = self.fc3(x)
- return x
+ x = self.flatten(x)
+ logits = self.linear_relu_stack(x)
+ return logits
-model = SimpleNN()
-model.load_state_dict(torch.load('mnist_model.pth'))
+model = MNISTNN()
+model.load_state_dict(torch.load('mnist_model_2.pth', map_location='cpu'))  # inference here runs on CPU; remap in case the checkpoint holds CUDA tensors
model.eval()
@app.route('/')
@@ -39,13 +41,15 @@ def predict():
# Decode the image
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
- image = image.convert('L') # Convert to grayscale
+ image = image.getchannel('A') if 'A' in image.getbands() else image.convert('L')  # use the alpha band if there is one; fall back to grayscale so non-RGBA uploads don't raise on unpacking
image = image.resize((28, 28)) # Resize to 28x28
- image = transforms.ToTensor()(image)
- image = image.unsqueeze(0) # Add batch dimension
+ #transform = transforms.Compose([transforms.PILToTensor()])
+ #image_tensor = transform(image)
+ image_tensor = transforms.ToTensor()(image)
+ image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
# Predict using the model
with torch.no_grad():
- output = model(image)
+ output = model(image_tensor)
_, predicted = torch.max(output, 1)
return jsonify(prediction=predicted.item())
diff --git a/templates/index.html b/templates/index.html
new file mode 100644
index 0000000..1aa5316
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,64 @@
+
+
+
+
+
+ Handwritten Digit Input
+
+
+
+ Draw a Handwritten Digit
+
+
+
+
+
+
+
+
diff --git a/train.py b/train.py
index 64e1d52..037f166 100644
--- a/train.py
+++ b/train.py
@@ -4,13 +4,18 @@ import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
+lr = 0.8
+bs = 64
+epochs = 50
+
+
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Define transformations
transform = transforms.Compose([
transforms.ToTensor(),
- transforms.Normalize((0.5,), (0.5,))
+ #transforms.Normalize((0.5,), (0.5,))
])
# Load MNIST dataset
@@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform
-train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)  # bs is the batch-size constant defined next to lr/epochs at the top of the file
# Define the neural network model
-class SimpleNN(nn.Module):
+class MNISTNN(nn.Module):
def __init__(self):
- super(SimpleNN, self).__init__()
- self.fc1 = nn.Linear(28 * 28, 128)
- self.fc2 = nn.Linear(128, 64)
- self.fc3 = nn.Linear(64, 10)
-
+ super().__init__()
+ self.flatten = nn.Flatten()
+ self.linear_relu_stack = nn.Sequential(
+ nn.Linear(28 * 28, 512),
+ nn.ReLU(),
+ nn.Linear(512, 10)
+ )
def forward(self, x):
- x = x.view(x.shape[0], -1)
- x = torch.relu(self.fc1(x))
- x = torch.relu(self.fc2(x))
- x = self.fc3(x)
- return x
+ x = self.flatten(x)
+ logits = self.linear_relu_stack(x)
+ return logits
# Instantiate the model
-model = SimpleNN().to(device)
+model = MNISTNN().to(device)
# Define loss function and optimizer
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.SGD(model.parameters(), lr=0.01)
+loss_fn = nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=lr)
# Check if the model already exists
-model_path = 'mnist_model.pth'
+model_path = 'mnist_model_2.pth'
start_epoch = 0
if os.path.isfile(model_path):
model.load_state_dict(torch.load(model_path))
@@ -49,14 +54,13 @@ if os.path.isfile(model_path):
# start_epoch =
# Train the model
-epochs = 20
for epoch in range(start_epoch, start_epoch + epochs):
running_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
- output = model(images)
- loss = criterion(output, labels)
+ preds = model(images)
+ loss = loss_fn(preds, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()