preject completement
This commit is contained in:
parent
2bb2afa0e6
commit
0d6ee20b05
Binary file not shown.
40
predict.py
40
predict.py
|
@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template
|
|||
import numpy as np
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
@ -10,22 +12,22 @@ from torchvision import transforms
|
|||
app = Flask(__name__)
|
||||
|
||||
# Load the trained model
|
||||
class SimpleNN(nn.Module):
|
||||
class MNISTNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(28 * 28, 128)
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
super().__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(28 * 28, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 10)
|
||||
)
|
||||
def forward(self, x):
|
||||
x = x.view(x.shape[0], -1)
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
x = self.flatten(x)
|
||||
logits = self.linear_relu_stack(x)
|
||||
return logits
|
||||
|
||||
model = SimpleNN()
|
||||
model.load_state_dict(torch.load('mnist_model.pth'))
|
||||
model = MNISTNN()
|
||||
model.load_state_dict(torch.load('mnist_model_2.pth'))
|
||||
model.eval()
|
||||
|
||||
@app.route('/')
|
||||
|
@ -39,13 +41,15 @@ def predict():
|
|||
|
||||
# Decode the image
|
||||
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
|
||||
image = image.convert('L') # Convert to grayscale
|
||||
_, _, _, image = image.split()
|
||||
image = image.resize((28, 28)) # Resize to 28x28
|
||||
image = transforms.ToTensor()(image)
|
||||
image = image.unsqueeze(0) # Add batch dimension
|
||||
#transform = transforms.Compose([transforms.PILToTensor()])
|
||||
#image_tensor = transform(image)
|
||||
image_tensor = transforms.ToTensor()(image)
|
||||
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
|
||||
# Predict using the model
|
||||
with torch.no_grad():
|
||||
output = model(image)
|
||||
output = model(image_tensor)
|
||||
_, predicted = torch.max(output, 1)
|
||||
|
||||
return jsonify(prediction=predicted.item())
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Handwritten Digit Input</title>
|
||||
<style>
|
||||
canvas {
|
||||
border: 1px solid black;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Draw a Handwritten Digit</h1>
|
||||
<canvas id="canvas" width="200" height="200"></canvas>
|
||||
<br>
|
||||
<button id="clear">Clear</button>
|
||||
<button id="submit">Submit</button>
|
||||
|
||||
<script>
|
||||
const canvas = document.getElementById('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
let drawing = false;
|
||||
|
||||
canvas.addEventListener('mousedown', () => {
|
||||
drawing = true;
|
||||
});
|
||||
|
||||
canvas.addEventListener('mouseup', () => {
|
||||
drawing = false;
|
||||
ctx.beginPath();
|
||||
});
|
||||
|
||||
canvas.addEventListener('mousemove', (event) => {
|
||||
if (!drawing) return;
|
||||
ctx.lineWidth = 15;
|
||||
ctx.lineCap = 'round';
|
||||
ctx.strokeStyle = 'black';
|
||||
|
||||
ctx.lineTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
|
||||
ctx.stroke();
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
|
||||
});
|
||||
|
||||
document.getElementById('clear').addEventListener('click', () => {
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
});
|
||||
|
||||
document.getElementById('submit').addEventListener('click', () => {
|
||||
const dataURL = canvas.toDataURL('image/png');
|
||||
fetch('/predict', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ image: dataURL }),
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
alert('Predicted digit: ' + data.prediction);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
42
train.py
42
train.py
|
@ -4,13 +4,18 @@ import torchvision
|
|||
from torchvision import datasets, transforms
|
||||
from torch import nn, optim
|
||||
|
||||
lr = 0.8
|
||||
bs = 64
|
||||
epochs = 50
|
||||
|
||||
|
||||
# Set device
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# Define transformations
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,))
|
||||
#transforms.Normalize((0.5,), (0.5,))
|
||||
])
|
||||
|
||||
# Load MNIST dataset
|
||||
|
@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform
|
|||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
|
||||
|
||||
# Define the neural network model
|
||||
class SimpleNN(nn.Module):
|
||||
class MNISTNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(28 * 28, 128)
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
super().__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(28 * 28, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 10)
|
||||
)
|
||||
def forward(self, x):
|
||||
x = x.view(x.shape[0], -1)
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
x = self.flatten(x)
|
||||
logits = self.linear_relu_stack(x)
|
||||
return logits
|
||||
|
||||
# Instantiate the model
|
||||
model = SimpleNN().to(device)
|
||||
model = MNISTNN().to(device)
|
||||
|
||||
# Define loss function and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=lr)
|
||||
|
||||
# Check if the model already exists
|
||||
model_path = 'mnist_model.pth'
|
||||
model_path = 'mnist_model_2.pth'
|
||||
start_epoch = 0
|
||||
if os.path.isfile(model_path):
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
@ -49,14 +54,13 @@ if os.path.isfile(model_path):
|
|||
# start_epoch = <load_saved_epoch>
|
||||
|
||||
# Train the model
|
||||
epochs = 20
|
||||
for epoch in range(start_epoch, start_epoch + epochs):
|
||||
running_loss = 0
|
||||
for images, labels in train_loader:
|
||||
images, labels = images.to(device), labels.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(images)
|
||||
loss = criterion(output, labels)
|
||||
preds = model(images)
|
||||
loss = loss_fn(preds, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss += loss.item()
|
||||
|
|
Loading…
Reference in New Issue