preject completement

This commit is contained in:
Rubbit 2024-09-19 16:43:50 +08:00
parent 2bb2afa0e6
commit 0d6ee20b05
4 changed files with 109 additions and 37 deletions

BIN
mnist_model_2.pth Normal file

Binary file not shown.

View File

@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template
import numpy as np
import io
import base64
from PIL import Image
from PIL import Image, ImageOps
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms
@ -10,22 +12,22 @@ from torchvision import transforms
app = Flask(__name__)
# Load the trained model
class SimpleNN(nn.Module):
class MNISTNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = x.view(x.shape[0], -1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = SimpleNN()
model.load_state_dict(torch.load('mnist_model.pth'))
model = MNISTNN()
model.load_state_dict(torch.load('mnist_model_2.pth'))
model.eval()
@app.route('/')
@ -39,13 +41,15 @@ def predict():
# Decode the image
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
image = image.convert('L') # Convert to grayscale
_, _, _, image = image.split()
image = image.resize((28, 28)) # Resize to 28x28
image = transforms.ToTensor()(image)
image = image.unsqueeze(0) # Add batch dimension
#transform = transforms.Compose([transforms.PILToTensor()])
#image_tensor = transform(image)
image_tensor = transforms.ToTensor()(image)
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
# Predict using the model
with torch.no_grad():
output = model(image)
output = model(image_tensor)
_, predicted = torch.max(output, 1)
return jsonify(prediction=predicted.item())

64
templates/index.html Normal file
View File

@ -0,0 +1,64 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Handwritten Digit Input</title>
<style>
canvas {
border: 1px solid black;
}
</style>
</head>
<body>
<h1>Draw a Handwritten Digit</h1>
<canvas id="canvas" width="200" height="200"></canvas>
<br>
<button id="clear">Clear</button>
<button id="submit">Submit</button>
<script>
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
let drawing = false;
canvas.addEventListener('mousedown', () => {
drawing = true;
});
canvas.addEventListener('mouseup', () => {
drawing = false;
ctx.beginPath();
});
canvas.addEventListener('mousemove', (event) => {
if (!drawing) return;
ctx.lineWidth = 15;
ctx.lineCap = 'round';
ctx.strokeStyle = 'black';
ctx.lineTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
ctx.stroke();
ctx.beginPath();
ctx.moveTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
});
document.getElementById('clear').addEventListener('click', () => {
ctx.clearRect(0, 0, canvas.width, canvas.height);
});
document.getElementById('submit').addEventListener('click', () => {
const dataURL = canvas.toDataURL('image/png');
fetch('/predict', {
method: 'POST',
body: JSON.stringify({ image: dataURL }),
headers: { 'Content-Type': 'application/json' }
})
.then(response => response.json())
.then(data => {
alert('Predicted digit: ' + data.prediction);
});
});
</script>
</body>
</html>

View File

@ -4,13 +4,18 @@ import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
lr = 0.8
bs = 64
epochs = 50
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Define transformations
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
#transforms.Normalize((0.5,), (0.5,))
])
# Load MNIST dataset
@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# Define the neural network model
class SimpleNN(nn.Module):
class MNISTNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = x.view(x.shape[0], -1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# Instantiate the model
model = SimpleNN().to(device)
model = MNISTNN().to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)
# Check if the model already exists
model_path = 'mnist_model.pth'
model_path = 'mnist_model_2.pth'
start_epoch = 0
if os.path.isfile(model_path):
model.load_state_dict(torch.load(model_path))
@ -49,14 +54,13 @@ if os.path.isfile(model_path):
# start_epoch = <load_saved_epoch>
# Train the model
epochs = 20
for epoch in range(start_epoch, start_epoch + epochs):
running_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
preds = model(images)
loss = loss_fn(preds, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()