project completion
parent 2bb2afa0e6
commit 0d6ee20b05
Binary file not shown.
predict.py (40 changed lines)
@@ -2,7 +2,9 @@ from flask import Flask, request, jsonify, render_template
 import numpy as np
 import io
 import base64
-from PIL import Image
+from PIL import Image, ImageOps
+import pandas as pd
+import matplotlib.pyplot as plt
 import torch
 from torch import nn
 from torchvision import transforms
@@ -10,22 +12,22 @@ from torchvision import transforms
 app = Flask(__name__)

 # Load the trained model
-class SimpleNN(nn.Module):
+class MNISTNN(nn.Module):
     def __init__(self):
-        super(SimpleNN, self).__init__()
-        self.fc1 = nn.Linear(28 * 28, 128)
-        self.fc2 = nn.Linear(128, 64)
-        self.fc3 = nn.Linear(64, 10)
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28 * 28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10)
+        )
     def forward(self, x):
-        x = x.view(x.shape[0], -1)
-        x = torch.relu(self.fc1(x))
-        x = torch.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
+        x = self.flatten(x)
+        logits = self.linear_relu_stack(x)
+        return logits

-model = SimpleNN()
-model.load_state_dict(torch.load('mnist_model.pth'))
+model = MNISTNN()
+model.load_state_dict(torch.load('mnist_model_2.pth'))
 model.eval()

 @app.route('/')
@@ -39,13 +41,15 @@ def predict():

     # Decode the image
     image = Image.open(io.BytesIO(base64.b64decode(image_data)))
-    image = image.convert('L') # Convert to grayscale
+    _, _, _, image = image.split()
     image = image.resize((28, 28)) # Resize to 28x28
-    image = transforms.ToTensor()(image)
-    image = image.unsqueeze(0) # Add batch dimension
+    #transform = transforms.Compose([transforms.PILToTensor()])
+    #image_tensor = transform(image)
+    image_tensor = transforms.ToTensor()(image)
+    image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
     # Predict using the model
     with torch.no_grad():
-        output = model(image)
+        output = model(image_tensor)
         _, predicted = torch.max(output, 1)

     return jsonify(prediction=predicted.item())
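Taken together, the preprocessing change swaps the old grayscale-convert path for an alpha-channel split: a PNG exported by canvas.toDataURL('image/png') is RGBA with opaque strokes on a transparent background, so the alpha band carries the digit. A minimal standalone sketch of that path, not part of the commit (the preprocess helper and the blank test image are illustrative, and the split requires RGBA input):

# Standalone sketch of the new preprocessing path in predict.py (not part of the commit).
# Assumes the incoming PNG is RGBA, as produced by canvas.toDataURL('image/png').
import io
import base64

import torch
from PIL import Image
from torchvision import transforms


def preprocess(image_data: str) -> torch.Tensor:
    """Decode a base64 PNG and return a [1, 1, 28, 28] float tensor."""
    image = Image.open(io.BytesIO(base64.b64decode(image_data)))
    _, _, _, image = image.split()               # keep only the alpha channel (requires RGBA)
    image = image.resize((28, 28))               # match the MNIST input size
    image_tensor = transforms.ToTensor()(image)  # [1, 28, 28], values in [0, 1]
    return image_tensor.unsqueeze(0)             # add the batch dimension -> [1, 1, 28, 28]


if __name__ == "__main__":
    # Build a blank RGBA test image so the sketch runs without a browser.
    buf = io.BytesIO()
    Image.new("RGBA", (200, 200), (0, 0, 0, 0)).save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode()
    print(preprocess(encoded).shape)  # torch.Size([1, 1, 28, 28])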
@@ -0,0 +1,64 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Handwritten Digit Input</title>
+    <style>
+        canvas {
+            border: 1px solid black;
+        }
+    </style>
+</head>
+<body>
+    <h1>Draw a Handwritten Digit</h1>
+    <canvas id="canvas" width="200" height="200"></canvas>
+    <br>
+    <button id="clear">Clear</button>
+    <button id="submit">Submit</button>
+
+    <script>
+        const canvas = document.getElementById('canvas');
+        const ctx = canvas.getContext('2d');
+        let drawing = false;
+
+        canvas.addEventListener('mousedown', () => {
+            drawing = true;
+        });
+
+        canvas.addEventListener('mouseup', () => {
+            drawing = false;
+            ctx.beginPath();
+        });
+
+        canvas.addEventListener('mousemove', (event) => {
+            if (!drawing) return;
+            ctx.lineWidth = 15;
+            ctx.lineCap = 'round';
+            ctx.strokeStyle = 'black';
+
+            ctx.lineTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
+            ctx.stroke();
+            ctx.beginPath();
+            ctx.moveTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
+        });
+
+        document.getElementById('clear').addEventListener('click', () => {
+            ctx.clearRect(0, 0, canvas.width, canvas.height);
+        });
+
+        document.getElementById('submit').addEventListener('click', () => {
+            const dataURL = canvas.toDataURL('image/png');
+            fetch('/predict', {
+                method: 'POST',
+                body: JSON.stringify({ image: dataURL }),
+                headers: { 'Content-Type': 'application/json' }
+            })
+            .then(response => response.json())
+            .then(data => {
+                alert('Predicted digit: ' + data.prediction);
+            });
+        });
+    </script>
+</body>
+</html>
train.py (42 changed lines)
@@ -4,13 +4,18 @@ import torchvision
 from torchvision import datasets, transforms
 from torch import nn, optim

+lr = 0.8
+bs = 64
+epochs = 50
+
+
 # Set device
 device = 'cuda' if torch.cuda.is_available() else 'cpu'

 # Define transformations
 transform = transforms.Compose([
     transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
+    #transforms.Normalize((0.5,), (0.5,))
 ])

 # Load MNIST dataset
@@ -18,29 +23,29 @@ train_dataset = datasets.MNIST(root='data', train=True, download=True, transform
 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

 # Define the neural network model
-class SimpleNN(nn.Module):
+class MNISTNN(nn.Module):
     def __init__(self):
-        super(SimpleNN, self).__init__()
-        self.fc1 = nn.Linear(28 * 28, 128)
-        self.fc2 = nn.Linear(128, 64)
-        self.fc3 = nn.Linear(64, 10)
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28 * 28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10)
+        )
     def forward(self, x):
-        x = x.view(x.shape[0], -1)
-        x = torch.relu(self.fc1(x))
-        x = torch.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
+        x = self.flatten(x)
+        logits = self.linear_relu_stack(x)
+        return logits

 # Instantiate the model
-model = SimpleNN().to(device)
+model = MNISTNN().to(device)

 # Define loss function and optimizer
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.SGD(model.parameters(), lr=0.01)
+loss_fn = nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=lr)

 # Check if the model already exists
-model_path = 'mnist_model.pth'
+model_path = 'mnist_model_2.pth'
 start_epoch = 0
 if os.path.isfile(model_path):
     model.load_state_dict(torch.load(model_path))
@@ -49,14 +54,13 @@ if os.path.isfile(model_path):
     # start_epoch = <load_saved_epoch>

 # Train the model
-epochs = 20
 for epoch in range(start_epoch, start_epoch + epochs):
     running_loss = 0
     for images, labels in train_loader:
         images, labels = images.to(device), labels.to(device)
         optimizer.zero_grad()
-        output = model(images)
-        loss = criterion(output, labels)
+        preds = model(images)
+        loss = loss_fn(preds, labels)
         loss.backward()
         optimizer.step()
         running_loss += loss.item()

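train.py now trains for 50 epochs at lr = 0.8 but never measures accuracy. A minimal evaluation sketch against the MNIST test split, reusing the architecture and the mnist_model_2.pth checkpoint path from the diff; the script itself is not part of the commit and assumes that checkpoint is present in the working directory:

# Sketch of checking the retrained weights on the MNIST test split (not part of the commit).
# Mirrors the training setup above: ToTensor only (Normalize is commented out) and the
# MNISTNN architecture / mnist_model_2.pth path from the diff.
import torch
from torch import nn
from torchvision import datasets, transforms


class MNISTNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))


device = 'cuda' if torch.cuda.is_available() else 'cpu'
test_set = datasets.MNIST(root='data', train=False, download=True,
                          transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)

model = MNISTNN().to(device)
model.load_state_dict(torch.load('mnist_model_2.pth', map_location=device))
model.eval()

correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()

print(f"test accuracy: {correct / len(test_set):.4f}")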