"""Flask app that classifies a hand-drawn digit image with a small MNIST MLP."""

import base64
import io

# NOTE(review): numpy, pandas, matplotlib and ImageOps are not used in this
# module; kept in case other code relies on them being imported here.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from flask import Flask, jsonify, render_template, request
from PIL import Image, ImageOps
from torch import nn
from torchvision import transforms

app = Flask(__name__)

# Side length (pixels) of the square input the model was built for.
_IMAGE_SIZE = 28

# Converts a PIL image to a float tensor of shape (C, H, W) scaled to [0, 1].
# Built once instead of on every request.
_TO_TENSOR = transforms.ToTensor()


class MNISTNN(nn.Module):
    """Two-layer MLP: flattened 28*28 input -> 512 hidden (ReLU) -> 10 logits."""

    def __init__(self):
        super().__init__()
        # nn.Flatten keeps dim 0 (batch) and flattens the rest, so each sample
        # must contain exactly 28 * 28 elements.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, 10)."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# Load the trained weights once at import time.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
# The model itself lives on the CPU either way, because load_state_dict copies
# the values into its existing parameters. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs
# (requires torch >= 1.13).
model = MNISTNN()
model.load_state_dict(
    torch.load('mnist_model_2.pth', map_location='cpu', weights_only=True)
)
model.eval()


def _decode_image(image_field):
    """Decode a data URL ("data:image/png;base64,...") or bare base64 into a PIL image.

    Raises ValueError (including binascii.Error) for invalid base64 and OSError
    when Pillow cannot identify the bytes as an image.
    """
    # The base64 alphabet has no commas, so everything after the last comma is
    # the payload. When there is no comma the whole string is the payload.
    _, _, payload = image_field.rpartition(",")
    return Image.open(io.BytesIO(base64.b64decode(payload)))


def _to_single_band(image):
    """Reduce *image* to one band suitable for the model.

    If the image has an alpha band, that band is used on its own. This
    assumes a canvas drawing of strokes on a transparent background, so alpha
    marks the digit. Otherwise the image is converted to grayscale ("L").
    """
    bands = image.getbands()
    if bands[-1] in ("A", "a"):
        return image.getchannel(len(bands) - 1)
    # NOTE(review): non-alpha images are used as plain grayscale. If clients
    # send dark-on-light digits, ImageOps.invert may be needed to match the
    # polarity of the alpha path. Confirm against the training data.
    return image.convert("L")


def _preprocess(image):
    """Turn a PIL image into a (1, 1, 28, 28) float tensor batch for the model."""
    single_band = _to_single_band(image)
    resized = single_band.resize((_IMAGE_SIZE, _IMAGE_SIZE))
    # unsqueeze(0) adds the batch dimension in front of (C, H, W).
    return _TO_TENSOR(resized).unsqueeze(0)


@app.route('/')
def index():
    """Serve the drawing page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image in the JSON body's ``image`` field.

    Expects a JSON object whose ``image`` value is a base64 image, optionally
    prefixed as a data URL. On success returns ``{"prediction": <int>}``.
    A missing, malformed, or undecodable image returns ``{"error": ...}``
    with HTTP 400 instead of an unhandled 500.
    """
    # silent=True returns None instead of raising on a missing or invalid JSON body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        return jsonify(error="expected a JSON object with an 'image' string field"), 400

    try:
        # Pillow decodes lazily, so decoding errors can surface during
        # preprocessing as well as in Image.open.
        image_tensor = _preprocess(_decode_image(data['image']))
    except (ValueError, OSError):
        return jsonify(error="could not decode 'image' as a base64-encoded image"), 400

    with torch.no_grad():
        output = model(image_tensor)
    # argmax over the class dimension gives the index of the highest logit.
    predicted = output.argmax(dim=1)
    return jsonify(prediction=predicted.item())


if __name__ == '__main__':
    # NOTE: debug=True enables the interactive Werkzeug debugger. It is meant
    # for local development only and must not be exposed publicly.
    app.run(debug=True)