Skip to content

Commit

Permalink
Support file uploading in Advanced Customization tab (#59)
Browse files Browse the repository at this point in the history
Implement the file-upload feature of #58.

Users can upload a single image or multi-modal MRI images.

Demo video:
#35 (reply in thread)

---------

Signed-off-by: Mingxin Zheng <[email protected]>
  • Loading branch information
mingxin-zheng authored Jan 16, 2025
1 parent e888c3c commit 2b0e4f4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
2 changes: 0 additions & 2 deletions m3/demo/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@ RUN ln -s /usr/bin/python3 /usr/bin/python

RUN git clone https://github.com/Project-MONAI/VLM --recursive
WORKDIR /VLM
###RUN python3.10 -m venv .venv
####RUN source .venv/bin/activate
RUN make demo_m3
4 changes: 4 additions & 0 deletions m3/demo/experts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def image_to_data_url(image, format="JPEG", max_size=None):
# Create a BytesIO buffer to save the image
buffered = BytesIO()
# Save the image to the buffer in the specified format
if img.mode == 'RGBA':
img = img.convert('RGB')
img.save(buffered, format=format)
# Convert the buffer content into bytes
img_byte = buffered.getvalue()
Expand All @@ -337,6 +339,8 @@ def resize_data_url(data_url, max_size):
# Create a BytesIO buffer to save the image
buffered = BytesIO()
# Save the image to the buffer in the specified format
if img.mode == 'RGBA':
img = img.convert('RGB')
img.save(buffered, format="JPEG")
# Convert the buffer content into bytes
img_byte = buffered.getvalue()
Expand Down
23 changes: 19 additions & 4 deletions m3/demo/gradio_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def __init__(self):
self.interactive = False
self.sys_msgs_to_hide = []
self.modality_prompt = "Auto"
self.img_urls_or_paths = IMG_URLS_OR_PATHS

def restore_from_backup(self, attr):
"""Retrieve the attribute from the backup"""
Expand Down Expand Up @@ -524,6 +525,7 @@ def process_prompt(self, prompt, sv, chat_history):
interactive=True,
sys_msgs_to_hide=sv.sys_msgs_to_hide,
backup={"image_url": sv.image_url, "slice_index": sv.slice_index},
img_urls_or_paths=sv.img_urls_or_paths,
)
return (
None,
Expand All @@ -544,7 +546,7 @@ def input_image(image, sv: SessionVariables):
def update_image_selection(selected_image, sv: SessionVariables, slice_index=None):
"""Update the gradio components based on the selected image"""
logger.debug(f"Updating display image for {selected_image}")
sv.image_url = IMG_URLS_OR_PATHS.get(selected_image, None)
sv.image_url = sv.img_urls_or_paths.get(selected_image, None)
img_file = CACHED_IMAGES.get(sv.image_url, None)

if sv.image_url is None or img_file is None:
Expand Down Expand Up @@ -642,7 +644,7 @@ def clear_all_convs(sv: SessionVariables):
logger.debug(f"Clearing all conversations")
if sv.temp_working_dir is not None:
rmtree(sv.temp_working_dir)
new_sv = new_session_variables()
new_sv = new_session_variables(img_urls_or_paths=sv.img_urls_or_paths)
# Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_checkbox, model_cards_text, modality_prompt_dropdown
return (
new_sv,
Expand Down Expand Up @@ -710,6 +712,17 @@ def download_file():
"""Download the file."""
return [gr.DownloadButton(visible=False)]

def upload_file(files, sv):
    """Register an uploaded file as a new selectable image for the session.

    Args:
        files: Value delivered by the upload button. It is stored unchanged
            under a new "User Data N" key (presumably a file path or a list
            of paths -- confirm against the UploadButton configuration).
        sv: Session variables whose ``img_urls_or_paths`` mapping is extended.

    Returns:
        A tuple of the updated session variables and a rebuilt image dropdown
        listing every key of ``sv.img_urls_or_paths``.
    """
    logger.debug(f"Uploading the file {files}")

    # The session mapping may still be the module-level IMG_URLS_OR_PATHS
    # object itself. Detach it before mutating; otherwise the upload would be
    # written into the shared defaults and the index below would stay at 1,
    # overwriting "User Data 1" on every upload.
    # NOTE(review): assumes IMG_URLS_OR_PATHS is a plain dict -- confirm.
    if sv.img_urls_or_paths is IMG_URLS_OR_PATHS:
        sv.img_urls_or_paths = dict(IMG_URLS_OR_PATHS)

    # Number of user entries so far, plus one for the entry being added.
    idx = len(sv.img_urls_or_paths) + 1 - len(IMG_URLS_OR_PATHS)
    sv.img_urls_or_paths[f"User Data {idx}"] = files

    new_image_dropdown = gr.Dropdown(
        label="Select an image",
        choices=["Please select .."] + list(sv.img_urls_or_paths.keys()),
    )
    # Hand the full mapping to the cache so the new entry can be looked up
    # when it is selected from the dropdown.
    CACHED_IMAGES.cache(sv.img_urls_or_paths)
    return sv, new_image_dropdown


def create_demo(source, model_path, conv_mode, server_port):
"""Main function to create the Gradio interface"""
Expand All @@ -723,7 +736,7 @@ def create_demo(source, model_path, conv_mode, server_port):
with gr.Row():
with gr.Column():
image_dropdown = gr.Dropdown(
label="Select an image", choices=["Please select .."] + list(IMG_URLS_OR_PATHS.keys())
label="Select an image", choices=["Please select .."] + list(sv.value.img_urls_or_paths.keys())
)
image_input = gr.Image(
label="Image", sources=[], placeholder="Please select an image from the dropdown list."
Expand All @@ -746,7 +759,8 @@ def create_demo(source, model_path, conv_mode, server_port):
label="Max Tokens", minimum=1, maximum=1024, step=1, value=sv.value.max_tokens, interactive=True
)

with gr.Accordion("System Prompt and Message", open=False):
with gr.Accordion("Advanced Customization", open=False):
upload_button = gr.UploadButton("Click to Upload Files")
modality_prompt_dropdown = gr.Dropdown(
label="Select Modality",
choices=["Auto", "CT", "MRI", "CXR", "Unknown"],
Expand Down Expand Up @@ -811,6 +825,7 @@ def create_demo(source, model_path, conv_mode, server_port):
model_cards_checkbox.change(fn=update_model_cards_checkbox, inputs=[model_cards_checkbox, sv], outputs=[sv])
model_cards_text.change(fn=update_model_cards_text, inputs=[model_cards_text, sv], outputs=[sv])
modality_prompt_dropdown.change(fn=update_modality_prompt, inputs=[modality_prompt_dropdown, sv], outputs=[sv])
upload_button.upload(fn=upload_file, inputs=[upload_button, sv], outputs=[sv, image_dropdown])
# Reset button
clear_btn.click(
fn=clear_all_convs,
Expand Down

0 comments on commit 2b0e4f4

Please sign in to comment.