first commit
This commit is contained in:
commit
2bb2afa0e6
|
@ -0,0 +1,54 @@
|
|||
from flask import Flask, request, jsonify, render_template
|
||||
import numpy as np
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Load the trained model
|
||||
class SimpleNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(28 * 28, 128)
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(x.shape[0], -1)
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
model = SimpleNN()
|
||||
model.load_state_dict(torch.load('mnist_model.pth'))
|
||||
model.eval()
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
data = request.get_json()
|
||||
image_data = data['image'].split(",")[1]
|
||||
|
||||
# Decode the image
|
||||
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
|
||||
image = image.convert('L') # Convert to grayscale
|
||||
image = image.resize((28, 28)) # Resize to 28x28
|
||||
image = transforms.ToTensor()(image)
|
||||
image = image.unsqueeze(0) # Add batch dimension
|
||||
# Predict using the model
|
||||
with torch.no_grad():
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output, 1)
|
||||
|
||||
return jsonify(prediction=predicted.item())
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
from torch import nn, optim
|
||||
|
||||
# Set device
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# Define transformations
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,))
|
||||
])
|
||||
|
||||
# Load MNIST dataset
|
||||
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
|
||||
|
||||
# Define the neural network model
|
||||
class SimpleNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(28 * 28, 128)
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(x.shape[0], -1)
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
# Instantiate the model
|
||||
model = SimpleNN().to(device)
|
||||
|
||||
# Define loss function and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# Check if the model already exists
|
||||
model_path = 'mnist_model.pth'
|
||||
start_epoch = 0
|
||||
if os.path.isfile(model_path):
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
print("Loaded existing model.")
|
||||
# Optionally load the epoch if you save that too
|
||||
# start_epoch = <load_saved_epoch>
|
||||
|
||||
# Train the model
|
||||
epochs = 20
|
||||
for epoch in range(start_epoch, start_epoch + epochs):
|
||||
running_loss = 0
|
||||
for images, labels in train_loader:
|
||||
images, labels = images.to(device), labels.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(images)
|
||||
loss = criterion(output, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss += loss.item()
|
||||
print(f"Epoch {epoch + 1}/{start_epoch + epochs}, Loss: {running_loss/len(train_loader)}")
|
||||
|
||||
# Save the trained model
|
||||
torch.save(model.state_dict(), model_path)
|
||||
print("Model saved!")
|
Loading…
Reference in New Issue