main.py
import io
import sys

import numpy as np
import onnxruntime as ort
from typing import List, Tuple, Union

# stdout/stderr must be replaced BEFORE eel is imported, otherwise the app
# crashes when packaged with --noconsole (no real console streams exist).
sys.stdout = io.StringIO()
sys.stderr = io.StringIO()

import eel
import base64
from io import BytesIO
from PIL import Image
from utilities.class_names import get_classes_for_model
from utilities.prepare_images import replace_background, resize_and_pad_image, fix_image, convert_mask
import pooch
from rembg import new_session

# Model cache: ONNX inference sessions are created lazily on first use.
models: dict[str, Union[None, ort.InferenceSession]] = {
    "car_type": None,
    "all_specific_model_variants": None,
    "specific_model_variants": None,
    "pre_filter": None,
}

# Background-removal session (rembg U2Net), shared across all requests.
session = new_session("u2net")
def load_model(model_name: str) -> ort.InferenceSession:
    """
    Load a specific model from a set of predefined models.

    This function downloads a model from a remote URL based on the given model name.
    After the model is downloaded, an ONNX inference session is initialized with it
    and returned.

    Args:
        model_name (str): Name of the model to be loaded. Valid model names are 'car_type',
            'all_specific_model_variants', 'specific_model_variants' and 'pre_filter'.

    Returns:
        ort.InferenceSession: The initialized ONNX inference session for the loaded model.
    """
    if model_name == "car_type":
        url = "https://github.com/Flippchen/PorscheInsight-CarClassification-AI/releases/download/v.0.1/vgg16-pretrained-car-types.onnx"
        md5 = "7c42a075ab9ca1a2a198e5cd241a06f7"
    elif model_name == "all_specific_model_variants":
        url = "https://github.com/Flippchen/PorscheInsight-CarClassification-AI/releases/download/v.0.1/efficientnet-old-head-all-model-variants-full_best_model.onnx"
        md5 = "c54797cf92974c9ec962842e7ecd515c"
    elif model_name == "specific_model_variants":
        url = "https://github.com/Flippchen/PorscheInsight-CarClassification-AI/releases/download/v.0.1/efficientnet-model-variants_best_model.onnx"
        md5 = "3de16b8cf529dc90f66c962a1c93a904"
    elif model_name == "pre_filter":
        url = "https://github.com/Flippchen/PorscheInsight-CarClassification-AI/releases/download/v.0.1/efficientnet-pre-filter_best_model.onnx"
        md5 = "b70e531f5545afc66551c58f85d6694a"
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    # Show the loading notification in the UI while the model downloads
    eel.showLoading()
    # Download and cache the model using Pooch; the MD5 hash guards against
    # corrupted or tampered downloads.
    model_path = pooch.retrieve(
        url,
        f"md5:{md5}",
        fname=model_name + ".onnx",
        progressbar=True,
    )
    print("Model downloaded to:", model_path)
    # Hide the loading notification
    eel.hideLoading()
    return ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
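
# Usage sketch (illustrative only, not part of the app flow): models are meant
# to be loaded lazily through the `models` cache defined above. Note that
# load_model calls eel.showLoading()/eel.hideLoading(), so it assumes a running
# eel frontend.
#
#   if models["car_type"] is None:
#       models["car_type"] = load_model("car_type")
#   print(models["car_type"].get_inputs()[0].shape)  # e.g. [None, H, W, 3]
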
def prepare_image(image_data: Image.Image, target_size: Tuple[int, int], remove_background: bool, show_mask: bool) -> Tuple[np.ndarray, Image.Image]:
    """
    Prepare image data for prediction.

    This function applies background removal, resizing and padding as required.
    If remove_background is True, the U2Net session is used to remove the image
    background. If show_mask is also True, the mask of the removed background is
    returned as well; otherwise a 1x1 placeholder mask is returned.

    Args:
        image_data (Image.Image): Input image to be processed.
        target_size (Tuple[int, int]): Target size to resize the input image to.
        remove_background (bool): Whether to remove the background from the image.
        show_mask (bool): Whether to return the mask of the removed background.

    Returns:
        Tuple[np.ndarray, Image.Image]: The processed image as a float32 numpy array
        with a leading batch dimension, and the (possibly placeholder) RGBA mask.
    """
    if remove_background and show_mask:
        image, mask = replace_background(image_data, session=session)
    elif remove_background:
        image, _ = replace_background(image_data, session=session)
        mask = None
    else:
        image = resize_and_pad_image(image_data, target_size)
        mask = None

    # Add a batch dimension: (H, W, C) -> (1, H, W, C)
    img_array = np.array(image).astype('float32')
    img_array = np.expand_dims(img_array, 0)

    # Fall back to a 1x1 placeholder so callers always receive a mask image
    if mask is None:
        mask = Image.fromarray(np.zeros((1, 1), dtype=np.uint8))
    mask = mask.convert("RGBA")
    return img_array, mask
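
# Example sketch (assuming a hypothetical local file "example.jpg" and a model
# that expects 300x300 RGB input):
#
#   img = Image.open("example.jpg")
#   batch, mask = prepare_image(img, (300, 300), remove_background=False, show_mask=False)
#   # batch.shape -> (1, 300, 300, 3); mask is a 1x1 RGBA placeholder here
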
def get_top_n_predictions(prediction: np.ndarray, model_name: str, n: int = 3) -> List[Tuple[str, float]]:
    """
    Get the top n predictions from the model output.

    Args:
        prediction (np.ndarray): Output prediction from a model.
        model_name (str): Name of the model that produced the prediction.
        n (int, optional): Number of top predictions to retrieve. Defaults to 3.

    Returns:
        List[Tuple[str, float]]: The top n class names with their scores in percent.
    """
    # Ensure that n does not exceed the total number of classes
    n = min(n, prediction[0].shape[0])
    # argsort is ascending, so take the last n indices and reverse them
    top_n_indices = prediction[0].argsort()[-n:][::-1]
    classes = get_classes_for_model(model_name)
    top_n_predictions = [(classes[i], round(prediction[0][i] * 100, 2)) for i in top_n_indices]
    return top_n_predictions
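
# Worked example (sketch): argsort returns indices in ascending score order, so
# slicing the last two indices and reversing yields the highest scores first.
#
#   scores = np.array([[0.10, 0.65, 0.25]])
#   scores[0].argsort()[-2:][::-1]   # -> array([1, 2]): 0.65 first, then 0.25
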
def get_pre_filter_prediction(image_data: np.ndarray, model_name: str) -> List[Tuple[str, float]]:
    """
    Get pre-filter prediction results from the model.

    This function loads the pre-filter model if it has not been loaded already,
    runs it on the given image data and returns the top 3 prediction results.

    Args:
        image_data (np.ndarray): Image data to run the model on.
        model_name (str): Name of the pre-filter model.

    Returns:
        List[Tuple[str, float]]: Top 3 predictions from the pre-filter model.
    """
    # Lazily load and cache the model on first use
    if models[model_name] is None:
        models[model_name] = load_model(model_name)
    input_name = models[model_name].get_inputs()[0].name
    prediction = models[model_name].run(None, {input_name: image_data})
    filter_names = get_top_n_predictions(prediction[0], "pre_filter")
    return filter_names
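
# The ONNX Runtime call pattern above is generic: look up the graph's input
# name, then feed a dict mapping it to the batch. Sketch against any cached
# session (assuming it has been loaded):
#
#   sess = models["pre_filter"]
#   out = sess.run(None, {sess.get_inputs()[0].name: batch})
#   # out[0] is the (1, num_classes) score array consumed by get_top_n_predictions
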
@eel.expose
def classify_image(image_data: str, model_name: str, show_mask: bool = False) -> Tuple[List[Tuple[str, float]], str] | List[List[Tuple[str, float]]]:
    """
    Classify an image using a specified model.

    This function loads the specified model if it has not been loaded already.
    It then decodes the base64 image data, fixes the image orientation and color,
    and prepares the image for processing and prediction. Depending on the
    show_mask option, it also returns the mask of the processed image.

    Args:
        image_data (str): Base64 encoded image data.
        model_name (str): Name of the model to use for classification.
        show_mask (bool): Whether to return the mask of the processed image. Defaults to False.

    Returns:
        Tuple[List[Tuple[str, float]], str] | List[List[Tuple[str, float]]]: If show_mask is True,
        a tuple containing the top 3 predictions with their scores and the mask of the
        processed image as a base64 encoded string. If show_mask is False, only the
        top 3 predictions.
    """
    # Load the model if it is not already cached
    model = models.get(model_name)
    if model is None:
        model = load_model(model_name)
        models[model_name] = model

    # Decode the base64 image and fix image orientation/color
    image = Image.open(BytesIO(base64.b64decode(image_data)))
    image = fix_image(image)

    # Retrieve the required input size for the specified model
    input_size = model.get_inputs()[0].shape[1:3]

    # Prepare the image for processing and prediction.
    # FIXME: Currently, the background is removed prior to prediction. This is a workaround,
    # as predictions seem to be better with a black background.
    filter_image, mask = prepare_image(image, input_size, remove_background=True, show_mask=show_mask)

    # Convert the mask and encode it as a base64 PNG for the frontend
    mask = convert_mask(mask)
    buffer = io.BytesIO()
    mask.save(buffer, format="PNG")
    mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Run the pre-filter model first; if it doesn't predict a Porsche,
    # skip the specific model entirely
    pre_filter_predictions = get_pre_filter_prediction(filter_image, "pre_filter")
    if pre_filter_predictions[0][0] != "porsche":
        return (pre_filter_predictions, mask_base64) if show_mask else [pre_filter_predictions]

    # Run the specific model and retrieve the top 3 predictions
    input_name = model.get_inputs()[0].name
    prediction = model.run(None, {input_name: filter_image})
    top_3_predictions = get_top_n_predictions(prediction[0], model_name)
    return (top_3_predictions, mask_base64) if show_mask else [top_3_predictions]
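
# Calling sketch (assuming a hypothetical local file "example.jpg"); in the app
# the base64 payload comes from the JavaScript side via eel:
#
#   with open("example.jpg", "rb") as f:
#       payload = base64.b64encode(f.read()).decode("utf-8")
#   predictions, mask_b64 = classify_image(payload, "car_type", show_mask=True)
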
# Initialize eel with the web assets directory and open the app window
eel.init("web")
eel.start("index.html", size=(1000, 800), mode="default")