project completion

This commit is contained in:
Rubbit 2024-09-19 16:43:50 +08:00
parent 2bb2afa0e6
commit 0d6ee20b05
4 changed files with 109 additions and 37 deletions

BIN
mnist_model_2.pth Normal file

Binary file not shown.

View File

@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template
import numpy as np import numpy as np
import io import io
import base64 import base64
from PIL import Image from PIL import Image, ImageOps
import pandas as pd
import matplotlib.pyplot as plt
import torch import torch
from torch import nn from torch import nn
from torchvision import transforms from torchvision import transforms
@ -10,22 +12,22 @@ from torchvision import transforms
app = Flask(__name__) app = Flask(__name__)
# Load the trained model # Load the trained model
class SimpleNN(nn.Module): class MNISTNN(nn.Module):
def __init__(self): def __init__(self):
super(SimpleNN, self).__init__() super().__init__()
self.fc1 = nn.Linear(28 * 28, 128) self.flatten = nn.Flatten()
self.fc2 = nn.Linear(128, 64) self.linear_relu_stack = nn.Sequential(
self.fc3 = nn.Linear(64, 10) nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x): def forward(self, x):
x = x.view(x.shape[0], -1) x = self.flatten(x)
x = torch.relu(self.fc1(x)) logits = self.linear_relu_stack(x)
x = torch.relu(self.fc2(x)) return logits
x = self.fc3(x)
return x
model = SimpleNN() model = MNISTNN()
model.load_state_dict(torch.load('mnist_model.pth')) model.load_state_dict(torch.load('mnist_model_2.pth'))
model.eval() model.eval()
@app.route('/') @app.route('/')
@ -39,13 +41,15 @@ def predict():
# Decode the image # Decode the image
image = Image.open(io.BytesIO(base64.b64decode(image_data))) image = Image.open(io.BytesIO(base64.b64decode(image_data)))
image = image.convert('L') # Convert to grayscale _, _, _, image = image.split()
image = image.resize((28, 28)) # Resize to 28x28 image = image.resize((28, 28)) # Resize to 28x28
image = transforms.ToTensor()(image) #transform = transforms.Compose([transforms.PILToTensor()])
image = image.unsqueeze(0) # Add batch dimension #image_tensor = transform(image)
image_tensor = transforms.ToTensor()(image)
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
# Predict using the model # Predict using the model
with torch.no_grad(): with torch.no_grad():
output = model(image) output = model(image_tensor)
_, predicted = torch.max(output, 1) _, predicted = torch.max(output, 1)
return jsonify(prediction=predicted.item()) return jsonify(prediction=predicted.item())

64
templates/index.html Normal file
View File

@ -0,0 +1,64 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Handwritten Digit Input</title>
<style>
canvas {
border: 1px solid black;
}
</style>
</head>
<body>
<h1>Draw a Handwritten Digit</h1>
<canvas id="canvas" width="200" height="200"></canvas>
<br>
<button id="clear">Clear</button>
<button id="submit">Submit</button>
<script>
// Drawing pad for a handwritten digit. The user draws with the mouse on the
// canvas; "Clear" wipes it and "Submit" POSTs the canvas as a PNG data URL
// (JSON body: { image: <dataURL> }) to /predict and shows the returned digit.
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
let drawing = false;

// Stroke appearance never changes, so configure it once instead of on every
// mousemove event.
ctx.lineWidth = 15;
ctx.lineCap = 'round';
ctx.strokeStyle = 'black';

// Ends the current stroke and drops the path so the next stroke does not
// connect to this one.
function stopDrawing() {
    drawing = false;
    ctx.beginPath();
}

canvas.addEventListener('mousedown', (event) => {
    drawing = true;
    // Anchor the path at the press position so the first segment is drawn.
    // offsetX/offsetY are relative to the canvas's padding edge, so they stay
    // correct when the page is scrolled or the canvas has offset ancestors
    // (clientX - offsetLeft mixes two different coordinate systems).
    ctx.beginPath();
    ctx.moveTo(event.offsetX, event.offsetY);
});

canvas.addEventListener('mouseup', stopDrawing);
// Without this, releasing the button outside the canvas leaves `drawing`
// stuck at true and the pad keeps painting on the next hover.
canvas.addEventListener('mouseleave', stopDrawing);

canvas.addEventListener('mousemove', (event) => {
    if (!drawing) return;
    const x = event.offsetX;
    const y = event.offsetY;
    ctx.lineTo(x, y);
    ctx.stroke();
    // Restart the path at the current point so each stroke() call only
    // paints the newest segment.
    ctx.beginPath();
    ctx.moveTo(x, y);
});

document.getElementById('clear').addEventListener('click', () => {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
});

document.getElementById('submit').addEventListener('click', () => {
    const dataURL = canvas.toDataURL('image/png');
    fetch('/predict', {
        method: 'POST',
        body: JSON.stringify({ image: dataURL }),
        headers: { 'Content-Type': 'application/json' }
    })
        .then(response => {
            // fetch() only rejects on network failure; HTTP errors must be
            // detected explicitly before trying to parse the body.
            if (!response.ok) {
                throw new Error('Server responded with status ' + response.status);
            }
            return response.json();
        })
        .then(data => {
            alert('Predicted digit: ' + data.prediction);
        })
        .catch(error => {
            alert('Prediction failed: ' + error.message);
        });
});
</script>
</body>
</html>

View File

@ -4,13 +4,18 @@ import torchvision
from torchvision import datasets, transforms from torchvision import datasets, transforms
from torch import nn, optim from torch import nn, optim
lr = 0.8
bs = 64
epochs = 50
# Set device # Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Define transformations # Define transformations
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) #transforms.Normalize((0.5,), (0.5,))
]) ])
# Load MNIST dataset # Load MNIST dataset
@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# Define the neural network model # Define the neural network model
class SimpleNN(nn.Module): class MNISTNN(nn.Module):
def __init__(self): def __init__(self):
super(SimpleNN, self).__init__() super().__init__()
self.fc1 = nn.Linear(28 * 28, 128) self.flatten = nn.Flatten()
self.fc2 = nn.Linear(128, 64) self.linear_relu_stack = nn.Sequential(
self.fc3 = nn.Linear(64, 10) nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x): def forward(self, x):
x = x.view(x.shape[0], -1) x = self.flatten(x)
x = torch.relu(self.fc1(x)) logits = self.linear_relu_stack(x)
x = torch.relu(self.fc2(x)) return logits
x = self.fc3(x)
return x
# Instantiate the model # Instantiate the model
model = SimpleNN().to(device) model = MNISTNN().to(device)
# Define loss function and optimizer # Define loss function and optimizer
criterion = nn.CrossEntropyLoss() loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01) optimizer = optim.SGD(model.parameters(), lr=lr)
# Check if the model already exists # Check if the model already exists
model_path = 'mnist_model.pth' model_path = 'mnist_model_2.pth'
start_epoch = 0 start_epoch = 0
if os.path.isfile(model_path): if os.path.isfile(model_path):
model.load_state_dict(torch.load(model_path)) model.load_state_dict(torch.load(model_path))
@ -49,14 +54,13 @@ if os.path.isfile(model_path):
# start_epoch = <load_saved_epoch> # start_epoch = <load_saved_epoch>
# Train the model # Train the model
epochs = 20
for epoch in range(start_epoch, start_epoch + epochs): for epoch in range(start_epoch, start_epoch + epochs):
running_loss = 0 running_loss = 0
for images, labels in train_loader: for images, labels in train_loader:
images, labels = images.to(device), labels.to(device) images, labels = images.to(device), labels.to(device)
optimizer.zero_grad() optimizer.zero_grad()
output = model(images) preds = model(images)
loss = criterion(output, labels) loss = loss_fn(preds, labels)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
running_loss += loss.item() running_loss += loss.item()