55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
|
from flask import Flask, request, jsonify, render_template
|
||
|
import numpy as np
|
||
|
import io
|
||
|
import base64
|
||
|
from PIL import Image
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torchvision import transforms
|
||
|
|
||
|
# The WSGI application object; the route handlers below are registered on it.
app = Flask(import_name=__name__)
|
||
|
|
||
|
class SimpleNN(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 logits.

    Each sample is flattened to 28 * 28 = 784 features, passed through two
    ReLU hidden layers, and mapped to 10 unnormalized class scores (no
    softmax is applied).

    The attribute names ``fc1`` / ``fc2`` / ``fc3`` determine the state-dict
    keys, so they must match the checkpoint loaded with ``load_state_dict``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, 10)`` for the batch ``x``.

        Every dimension after the first is flattened, so ``x`` may be e.g.
        ``(batch, 784)``, ``(batch, 28, 28)`` or ``(batch, 1, 28, 28)``.
        """
        # torch.flatten (unlike Tensor.view) also accepts non-contiguous
        # tensors and empty batches, where view(batch, -1) raises.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
|
||
|
|
||
|
# Load the trained model once at import time and switch it to inference mode.
# map_location='cpu': the /predict handler builds CPU tensors, and a checkpoint
# saved from a GPU run would otherwise fail to load on a CPU-only host.
# weights_only=True: restrict unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code (requires torch >= 1.13).
# NOTE(review): the path is relative to the current working directory, so the
# server must be started from the directory that contains mnist_model.pth.
model = SimpleNN()
model.load_state_dict(
    torch.load('mnist_model.pth', map_location='cpu', weights_only=True)
)
model.eval()
|
||
|
|
||
|
@app.route("/")
def index():
    """Render and return the ``index.html`` template for the site root."""
    page = render_template('index.html')
    return page
|
||
|
|
||
|
def _decode_image(image_field):
    """Decode a base64 image string into a grayscale ("L" mode) PIL image.

    Accepts either a data URL (``"data:image/png;base64,<payload>"``) or a
    bare base64 payload.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``
            is a ``ValueError`` subclass).
        OSError: if the bytes are not an image Pillow can read
            (``UnidentifiedImageError`` is an ``OSError`` subclass).
        Image.DecompressionBombError: if the image is implausibly large.
    """
    header, sep, payload = image_field.partition(",")
    if not sep:
        # No data-URL prefix: the whole string is the base64 payload.
        payload = header
    raw = base64.b64decode(payload)
    with Image.open(io.BytesIO(raw)) as img:
        # convert() forces the lazy decode and returns an independent image,
        # so the source can be closed when the with-block exits.
        return img.convert('L')


def _to_model_input(image):
    """Resize a grayscale PIL image to 28x28 and return a (1, 1, 28, 28) tensor.

    ``transforms.ToTensor`` scales pixel values to [0, 1]; no further
    normalization is applied.
    """
    # NOTE(review): MNIST digits are light strokes on a dark background. If the
    # client canvas draws dark strokes on a light or transparent background,
    # the image may need inverting (transparent pixels become black under
    # convert('L')). Also confirm that the training pipeline used no extra
    # Normalize step. Check against index.html and the training script.
    resized = image.resize((28, 28))
    return transforms.ToTensor()(resized).unsqueeze(0)  # add batch dimension


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image posted as JSON ``{"image": "<data URL or base64>"}``.

    Returns:
        ``{"prediction": <int class index>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 when the body is not a JSON
        object with a string ``"image"`` field, or when that field does not
        decode to an image.
    """
    # silent=True: return None on a missing or invalid JSON body, so the
    # client gets a 400 JSON error instead of Flask's default error page.
    data = request.get_json(silent=True)
    image_field = data.get('image') if isinstance(data, dict) else None
    if not isinstance(image_field, str):
        return jsonify(error="expected a JSON object with a string 'image' field"), 400

    try:
        image = _decode_image(image_field)
    except (ValueError, OSError, Image.DecompressionBombError):
        return jsonify(error="could not decode 'image' as a base64-encoded image"), 400

    tensor = _to_model_input(image)
    with torch.no_grad():
        logits = model(tensor)
    # Index of the highest logit for the single sample in the batch.
    predicted = logits.argmax(dim=1).item()
    return jsonify(prediction=predicted)
|
||
|
|
||
|
if __name__ == '__main__':
    # Debug mode is opt-in: Flask's app.run() reads the FLASK_DEBUG environment
    # variable (e.g. FLASK_DEBUG=1) when no explicit ``debug`` argument is
    # passed. Hard-coding debug=True would always enable Werkzeug's interactive
    # debugger, which can run arbitrary Python code from the browser.
    app.run()
|