# MNIST-NN-WebGUI/predict.py
# Flask web GUI that serves MNIST digit predictions from a trained PyTorch model.
import base64
import io
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from flask import Flask, jsonify, render_template, request
from PIL import Image, ImageOps
from torch import nn
from torchvision import transforms
# Flask application object. The import name tells Flask where this app lives,
# which is also where it looks for the template rendered by index().
app = Flask(import_name=__name__)
class MNISTNN(nn.Module):
    """Fully connected MNIST classifier: 28*28 inputs -> 512 -> ReLU -> 10 logits.

    The attribute names ``flatten`` / ``linear_relu_stack`` and the order of the
    layers inside the Sequential determine the ``state_dict`` keys
    (``linear_relu_stack.0.weight``, ``linear_relu_stack.2.bias``, ...), so they
    must not change or the saved checkpoint will no longer load.
    """

    def __init__(self):
        super(MNISTNN, self).__init__()
        hidden_units = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, 10).

        ``x`` is flattened from dim 1 onward, so each sample must contain
        28*28 values (e.g. a (batch, 1, 28, 28) tensor).
        """
        return self.linear_relu_stack(self.flatten(x))
# Build the network and load the trained weights once, at import time, so
# every request reuses the same model instance.
model = MNISTNN()
# Look for the checkpoint next to this file first (Flask likewise resolves its
# templates relative to the app's module), falling back to the current working
# directory, which is where the bare filename was resolved before.
_checkpoint_path = Path(__file__).resolve().parent / 'mnist_model_2.pth'
if not _checkpoint_path.is_file():
    _checkpoint_path = Path('mnist_model_2.pth')
# map_location='cpu': a checkpoint saved from CUDA tensors would otherwise fail
# to load on a CPU-only server; on a CPU-saved checkpoint it is a no-op.
# NOTE(review): torch.load unpickles the file; on torch >= 1.13 consider also
# passing weights_only=True, since only a state_dict is expected here.
model.load_state_dict(torch.load(_checkpoint_path, map_location='cpu'))
# Evaluation mode: a no-op for this Linear/ReLU stack, but keeps inference
# correct if dropout or batch-norm layers are ever added.
model.eval()
@app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route('/predict', methods=['POST'])
def predict():
    """Classify a hand-drawn digit posted as a base64 image data URL.

    Request body (JSON): ``{"image": "<data URL>"}``. Everything after the
    first comma of the data URL is decoded as base64 image bytes, and only the
    image's alpha channel is fed to the model.

    Returns:
        ``{"prediction": <class index 0-9>}`` on success, or
        ``({"error": <message>}, 400)`` when the body, the data URL or the
        image itself cannot be used.
    """
    # silent=True yields None (instead of raising) for a missing/invalid JSON body.
    payload = request.get_json(silent=True)
    data_url = payload.get('image') if isinstance(payload, dict) else None
    if not isinstance(data_url, str):
        return jsonify(error="expected a JSON body with a string 'image' field"), 400

    # "data:<mime>;base64,<payload>" -> "<payload>"
    _, comma, image_data = data_url.partition(",")
    if not comma:
        return jsonify(error="'image' must be a data URL"), 400

    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_data)))
        # Image.open is lazy; decode now so corrupt/truncated data is caught here.
        image.load()
    except (ValueError, OSError):
        # binascii.Error (bad base64) is a ValueError; PIL's
        # UnidentifiedImageError and truncated-file errors are OSErrors.
        return jsonify(error="'image' is not a valid base64-encoded image"), 400

    # Keep only the alpha channel. Presumably the page draws opaque strokes on a
    # transparent canvas, so alpha acts as the digit mask (light digit on dark,
    # like MNIST) -- NOTE(review): confirm against the client page. Converting to
    # RGBA gives the same channel for RGBA input and also handles LA, palette and
    # colour-keyed images instead of failing on anything without four bands.
    if 'A' not in image.getbands() and 'transparency' not in image.info:
        return jsonify(error="image must have an alpha channel"), 400
    alpha = image.convert('RGBA').getchannel('A')

    # ToTensor turns the 28x28 single-band image into a (1, 28, 28) float tensor
    # scaled to [0, 1]; unsqueeze adds the batch dimension -> (1, 1, 28, 28).
    image_tensor = transforms.ToTensor()(alpha.resize((28, 28))).unsqueeze(0)

    with torch.no_grad():
        logits = model(image_tensor)
    predicted = logits.argmax(dim=1)  # index of the highest of the 10 logits
    return jsonify(prediction=predicted.item())
if __name__ == "__main__":
    # Running this file directly starts Flask's built-in development server.
    # NOTE(review): debug=True enables the reloader and the interactive Werkzeug
    # debugger, which can execute arbitrary code -- never expose it publicly.
    app.run(debug=True)