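# MNIST-NN-WebGUI/predict.py
#
# Flask app serving the index.html front end and a /predict endpoint that
# classifies a submitted digit image with a small fully connected PyTorch
# network (MNISTNN).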

from flask import Flask, request, jsonify, render_template
import numpy as np
import io
import base64
from PIL import Image, ImageOps
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms
# Flask application object; every route in this module is registered on it.
app = Flask(import_name=__name__)

# Network definition. Its layer layout must match the checkpoint that is
# loaded after the class, because load_state_dict() is strict by default.
class MNISTNN(nn.Module):
    """Fully connected digit classifier: 784 -> 512 -> ReLU -> 10 logits.

    The submodule attribute names (``flatten``, ``linear_relu_stack``) and the
    layer order inside the ``Sequential`` determine the ``state_dict`` keys
    (``linear_relu_stack.0.weight``, ``linear_relu_stack.2.bias``, ...).
    Renaming or reordering them breaks loading of previously saved weights.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        hidden_units = 512
        num_classes = 10
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, 10)``.

        ``x`` is flattened from dim 1 onward, so each sample must contain
        28 * 28 values (e.g. a ``(batch, 1, 28, 28)`` tensor).
        """
        return self.linear_relu_stack(self.flatten(x))
# Module-level model instance shared by the request handlers.
model = MNISTNN()
# map_location='cpu': the model is never moved to a GPU in this module and
# request tensors are built on the CPU. Without it, a checkpoint saved from a
# CUDA device fails to deserialize on a machine without CUDA.
# NOTE(review): the path is resolved against the current working directory,
# not this file's directory — run the app from the folder holding the .pth.
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load('mnist_model_2.pth', map_location='cpu'))
# Switch to evaluation mode before serving predictions.
model.eval()
@app.route('/')
def index():
    """Render and return the front-end page ``index.html``."""
    page = render_template('index.html')
    return page
def _image_from_payload(payload):
    """Decode a base64 image string into a fully loaded PIL image.

    ``payload`` may be a data URL (``"data:image/png;base64,<b64>"``) or bare
    base64 containing no comma.

    Raises:
        ValueError: if the base64 is malformed (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: if the bytes are not a readable image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    head, comma, tail = payload.partition(",")
    encoded = tail if comma else head
    image = Image.open(io.BytesIO(base64.b64decode(encoded)))
    image.load()  # decode now so truncated/corrupt data fails here, not later
    return image


def _alpha_tensor(image):
    """Return ``image``'s alpha channel as a ``(1, 1, 28, 28)`` float tensor in [0, 1].

    Raises:
        ValueError: if the image carries no transparency information, so
            there is no alpha channel to read the drawing from.
    """
    if 'A' not in image.getbands() and 'transparency' not in image.info:
        raise ValueError("image has no alpha channel")
    # Assumes the digit is drawn as opaque strokes on a transparent background,
    # so alpha is the ink intensity (bright digit on black). TODO confirm
    # against the front end. convert('RGBA') also handles LA/PA images and
    # palette images with a transparency key; for RGBA it is a plain copy.
    alpha = image.convert('RGBA').getchannel('A').resize((28, 28))
    # ToTensor: single-band image -> float tensor (1, 28, 28) scaled to [0, 1];
    # unsqueeze(0) adds the batch dimension.
    return transforms.ToTensor()(alpha).unsqueeze(0)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify a digit image posted as JSON.

    Request body: ``{"image": "<data URL or bare base64 image>"}``. The
    image's alpha channel is resized to 28x28 and passed through ``model``.

    Returns:
        ``{"prediction": <class index 0-9>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 when the body, the base64
        payload, or the image itself is unusable.
    """
    # silent=True: a missing/invalid JSON body yields None instead of an
    # exception, so it can be reported with the same 400 response below.
    data = request.get_json(silent=True)
    payload = data.get('image') if isinstance(data, dict) else None
    if not isinstance(payload, str):
        return jsonify(error="request body must be a JSON object with an 'image' string"), 400

    try:
        image_tensor = _alpha_tensor(_image_from_payload(payload))
    except (ValueError, OSError):
        return jsonify(error="'image' must be a base64-encoded image with an alpha channel"), 400

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(image_tensor)
    return jsonify(prediction=logits.argmax(dim=1).item())
if __name__ == '__main__':
    # Start Flask's built-in development server in debug mode (interactive
    # debugger and auto-reloader enabled).
    # NOTE(review): never expose debug mode on a reachable host — the
    # Werkzeug debugger permits arbitrary code execution. Use a production
    # WSGI server for deployment.
    app.run(debug=True)