-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
86 lines (67 loc) · 3.08 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
from PIL import Image
import random
import torch
import torchvision.transforms as transforms
import numpy as np
import requests
import os
# NOTE(review): st.cache is deprecated in newer Streamlit releases
# (st.cache_resource is the replacement); kept as-is for compatibility with
# whatever Streamlit version this app is pinned to — confirm before changing.
@st.cache(allow_output_mutation=True)
def download_models():
    """Download the pretrained model checkpoints into the working directory.

    Each checkpoint is fetched from the project's public GCS bucket and saved
    as ``<MODEL>.pt`` (``VGG.pt``, ``RESNET.pt``). Files that already exist
    are skipped.

    The body is streamed into a temporary ``.part`` file that is renamed into
    place only after the whole response was received successfully. Otherwise
    a failed or interrupted request could leave behind a truncated file or an
    HTTP error page named ``*.pt``. The existence check would then treat that
    file as valid on every later run.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    models = ["VGG", "RESNET"]
    for model in models:
        path = f"{model}.pt"
        if os.path.exists(path):
            continue
        url = f"https://storage.googleapis.com/model-weight/{model}.pt"
        tmp_path = f"{path}.part"
        # stream=True avoids holding the whole checkpoint in memory. The
        # timeout (per connect/read) stops a stalled connection from hanging
        # the app.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Rename atomically, so the final file only appears once fully written.
        os.replace(tmp_path, path)
download_models()


@st.cache(allow_output_mutation=True)
def _load_models():
    """Load both checkpoints onto the CPU once per server process.

    Streamlit re-executes this whole script on every widget interaction.
    Caching keeps the models from being re-read from disk on each rerun.
    ``allow_output_mutation=True`` skips hashing the returned model objects.

    Returns:
        tuple: ``(resnet_model, vgg_model)`` as returned by ``torch.load``.
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. These files come from a remote bucket, so only load
    # checkpoints from a source you trust.
    # NOTE(review): these appear to be whole pickled modules (called directly
    # later). torch >= 2.6 defaults to weights_only=True, which may reject
    # them. If loading fails on a newer torch, pass weights_only=False
    # explicitly. Confirm the pinned torch version.
    cpu = torch.device("cpu")
    resnet_model = torch.load("RESNET.pt", map_location=cpu)
    vgg_model = torch.load("VGG.pt", map_location=cpu)
    return resnet_model, vgg_model


resnet, vgg = _load_models()
# Page header: title, an author line rendered as raw HTML (for the larger
# font size), and the project description in Markdown.
st.title("CSE 881 - Road Sign Detection Project")
st.markdown("""
<div style="font-size:20px;">
Authors: Bao Hoang and Tanawan Premsri
</div>
""", unsafe_allow_html=True)
# Blank lines separate the Markdown paragraphs. Without them, consecutive
# lines are merged into one paragraph when rendered.
st.markdown("""
### About the Project

With the rapid progress in autonomous driving technology, detecting and classifying road signs has become a critical task. Road signs provide essential information for safe and efficient navigation, making their accurate detection indispensable for modern autonomous vehicles.

This project leverages cutting-edge **Computer Vision** and **Deep Learning** techniques to build and evaluate high-performance road sign detection models. The models are trained on diverse road sign images collected from Google Images, Google Shopping, and Kaggle, covering 4 categories: **Stop**, **Speed Limit**, **Traffic Light**, and **Cross Walk**. For more details, please refer to our source code and the final report at [https://github.com/hoangcaobao/CSE881](https://github.com/hoangcaobao/CSE881).

Below, you can upload an image of a road sign to see how well our fine-tuned models (ResNet and VGG) can classify it!
""")
# User inputs. The selected option string ("VGG" or "ResNet") picks the model
# in the inference step, and the uploader accepts a single JPEG/PNG file
# (or returns None until one is provided).
option = st.selectbox(
    label="Which Computer Vision Architectures you want to use?",
    options=("VGG", "ResNet"),
)
uploaded_image = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png"],
)

# Class index -> display name. It is looked up with the index of the largest
# model output, so the tuple order must match the model's output order.
label_mapping = dict(enumerate(
    ("Crosswalk Sign", "Speed Limit Sign", "Stop Sign", "Traffic Light")
))
if uploaded_image is not None:
    # Decode the upload and normalize every colour mode (RGBA, LA, P, L,
    # CMYK, ...) to 3-channel RGB. The 3-value Normalize below needs exactly
    # three channels. Previously only RGBA was converted, so grayscale and
    # palette images crashed.
    # Pillow raises UnidentifiedImageError (an OSError subclass) for files it
    # cannot parse and OSError for truncated data. ValueError covers modes
    # Pillow cannot convert. Report these instead of crashing the run.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, ValueError):
        image = None
        st.error("Could not read the uploaded file as an image.")

    if image is not None:
        # Fixed 256x256 size fed to the models (presumably the training
        # resolution — confirm against the training code).
        image = image.resize((256, 256))
        st.image(image, caption="Uploaded Image")

        # Pick the model for the selected architecture. Indexing a dict fails
        # loudly on an unexpected option instead of leaving `model` unbound.
        models = {"VGG": vgg, "ResNet": resnet}
        model = models[option]
        model.eval()

        image_prep = transforms.Compose([
            transforms.ToTensor(),  # PIL uint8 HWC in [0, 255] -> float CHW in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
        ])
        image_tensor = image_prep(image).unsqueeze(0)  # add batch dim -> (1, 3, 256, 256)

        with torch.no_grad():
            output = model(image_tensor)
        # Assumes output is (1, num_classes) logits/scores with num_classes
        # matching label_mapping — TODO confirm against the trained models.
        pred = output.argmax(dim=1)

        label = label_mapping[pred.item()]
        st.write(f"Uploaded Image Is {label}")