Skip to content

Commit

Permalink
feat: add color selector
Browse files Browse the repository at this point in the history
  • Loading branch information
kerlomz committed Dec 20, 2018
1 parent 288ef2c commit a6a1ae0
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 21 deletions.
179 changes: 179 additions & 0 deletions character.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self, conf_path: str, graph_path: str = None, model_path: str = Non
self.logger_tag = self.sys_cf['System'].get('LoggerTag')
self.logger_tag = self.logger_tag if self.logger_tag else "coriander"
self.logger = logging.getLogger(self.logger_tag)
self.static_path = self.sys_cf['System'].get('StaticPath')
self.static_path = self.static_path if self.static_path else 'static'
self.use_default_authorization = False
self.authorization = None
self.init_logger()
Expand Down
12 changes: 7 additions & 5 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
DEFAULT_HOST = "localhost"


def _image(_path, model_type=None, model_site=None):
def _image(_path, model_type=None, model_site=None, need_color=None):
with open(_path, "rb") as f:
img_bytes = f.read()

b64 = base64.b64encode(img_bytes).decode()
return {
'image': b64,
'model_type': model_type,
'model_site': model_site
'model_site': model_site,
'need_color': need_color,
}


Expand Down Expand Up @@ -225,13 +226,14 @@ def press_testing(self, image_list: dict, model_type=None, model_site=None):
_path.split('_')[0].lower(): _image(
os.path.join(path, _path),
model_type=None,
model_site=None
model_site=None,
need_color=None,
)
for i, _path in enumerate(path_list)
if i < 1000
}
print(batch)
# NoAuth(DEFAULT_HOST, ServerType.TORNADO).press_testing(batch)
# print(batch)
# NoAuth(DEFAULT_HOST, ServerType.TORNADO).local_iter(batch)
# NoAuth(DEFAULT_HOST, ServerType.FLASK).local_iter(batch)
# NoAuth(DEFAULT_HOST, ServerType.SANIC).local_iter(batch)
GoogleRPC(DEFAULT_HOST).local_iter(batch, model_site=None, model_type=None)
Expand Down
1 change: 1 addition & 0 deletions grpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ message PredictRequest {
string model_name = 3;
string model_type = 4;
string model_site = 5;
string need_color = 6;
}

message PredictResult {
Expand Down
21 changes: 14 additions & 7 deletions grpc_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def predict(self, request, context):
if not interface:
logger.info('Service is not ready!')
return {"result": "", "success": False, "code": 999}
image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=request.need_color)

if not image_batch:
return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code'])
Expand Down
6 changes: 4 additions & 2 deletions tornado_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def post(self):
model_site = ParamUtils.filter(data.get('model_site'))
model_name = ParamUtils.filter(data.get('model_name'))
split_char = ParamUtils.filter(data.get('split_char'))
need_color = ParamUtils.filter(data.get('need_color'))

if not bytes_batch:
logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
Expand All @@ -101,7 +102,7 @@ def post(self):

split_char = split_char if 'split_char' in data else interface.model_conf.split_char

image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=need_color)

if not image_batch:
logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
Expand Down Expand Up @@ -138,6 +139,7 @@ def post(self):
model_site = ParamUtils.filter(data.get('model_site'))
model_name = ParamUtils.filter(data.get('model_name'))
split_char = ParamUtils.filter(data.get('split_char'))
need_color = ParamUtils.filter(data.get('need_color'))

bytes_batch, response = ImageUtils.get_bytes_batch(data['image'])

Expand All @@ -162,7 +164,7 @@ def post(self):

split_char = split_char if 'split_char' in data else interface.model_conf.split_char

image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=need_color)

if not image_batch:
logger.error('[{}] - Size[{}] - Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
Expand Down
48 changes: 42 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,59 @@ def get_bytes_batch(base64_img):
return bytes_batch, response.SUCCESS

@staticmethod
def get_image_batch(model: ModelConfig, bytes_batch):
def get_image_batch(model: ModelConfig, bytes_batch, color=None):
# Note that there are two return objects here.
# 1.image_batch, 2.response

response = Response()

def load_image(image_bytes):
data_stream = io.BytesIO(image_bytes)
pil_image = PIL_Image.open(data_stream).convert('RGB')
image = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2GRAY)
hsv_map = {
"blue": {
"lower_hsv": np.array([100, 128, 46]),
"high_hsv": np.array([124, 255, 255])
},
"red": {
"lower_hsv": np.array([0, 128, 46]),
"high_hsv": np.array([5, 255, 255])
},
"yellow": {
"lower_hsv": np.array([15, 128, 46]),
"high_hsv": np.array([34, 255, 255])
},
"green": {
"lower_hsv": np.array([35, 128, 46]),
"high_hsv": np.array([77, 255, 255])
},
"black": {
"lower_hsv": np.array([0, 0, 0]),
"high_hsv": np.array([180, 255, 46])
}
}

def separate_color(pil_image, color):
hsv = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_BGR2HSV)
lower_hsv = hsv_map[color]['lower_hsv']
high_hsv = hsv_map[color]['high_hsv']
mask = cv2.inRange(hsv, lowerb=lower_hsv, upperb=high_hsv)
return mask

def load_image(image_bytes, color=None):

if color and color in ['red', 'blue', 'black', 'green', 'yellow']:
image = np.asarray(bytearray(image_bytes), dtype="uint8")
image = cv2.imdecode(image, -1)
image = separate_color(image, color)
else:
data_stream = io.BytesIO(image_bytes)
pil_image = PIL_Image.open(data_stream).convert('RGB')
image = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2GRAY)
image = preprocessing(image, model.binaryzation, model.smooth, model.blur).astype(np.float32)
image = cv2.resize(image, (model.resize[0], model.resize[1]))
image = image.swapaxes(0, 1)
return image[:, :, np.newaxis] / 255.

try:
image_batch = [load_image(i) for i in bytes_batch]
image_batch = [load_image(i, color=color) for i in bytes_batch]
return image_batch, response.SUCCESS
except OSError:
return None, response.IMAGE_DAMAGE
Expand Down

0 comments on commit a6a1ae0

Please sign in to comment.