Skip to content

Commit

Permalink
Update for MistV1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas0228 committed Apr 30, 2023
1 parent fef2a64 commit 56c375e
Show file tree
Hide file tree
Showing 31 changed files with 311 additions and 631 deletions.
11 changes: 6 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@

hub/
src/
openai/
test/vangogh_*
test/man_*
<<<<<<< HEAD
test/man_*.png
test/sample_*.png
=======
outputs/dirs/*/*.png
outputs/dirs/*/*.jpg
outputs/images/*.png
outputs/images/*.jpg
models/
misted_*.png
>>>>>>> b9c5f8acd3dcb6bedb7b281f92f29b277d33698a
*.bin
*.ckpt
*.ipynb_checkpoints*
Expand Down
106 changes: 0 additions & 106 deletions mist-webui-size-mask.py

This file was deleted.

54 changes: 44 additions & 10 deletions mist-webui.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import numpy as np
import gradio as gr
from mist_v2 import init, infer, load_image_from_path
from mist_v3 import init, infer
from mist_utils import load_image_from_path, closing_resize
import os
from tqdm import tqdm
from PIL import Image
import PIL
from PIL import Image, ImageOps


def reverse_mask(mask):
r, g, b, a = mask.split()
mask = PIL.Image.merge('RGB', (r,g,b))
return ImageOps.invert(mask)


config = init()
target_image_path = os.path.join(os.getcwd(), 'MIST.png')


def process_image(image, eps, steps, input_size, rate, mode, block_mode):
def process_image(image, eps, steps, input_size, rate, mode, block_mode, no_resize):
print('Processing....')
if mode == 'Textural':
mode_value = 1
Expand All @@ -19,8 +28,11 @@ def process_image(image, eps, steps, input_size, rate, mode, block_mode):
mode_value = 2
if image is None:
raise ValueError
tar_img = load_image_from_path(target_image_path, input_size)
img = image.resize((input_size, input_size), resample=Image.BICUBIC)

processed_mask = reverse_mask(image['mask'])

image = image['image']

print('tar_img loading fin')
config['parameters']['epsilon'] = eps / 255.0 * (1 - (-1))
config['parameters']['steps'] = steps
Expand All @@ -29,15 +41,33 @@ def process_image(image, eps, steps, input_size, rate, mode, block_mode):

config['parameters']['mode'] = mode_value
block_num = len(block_mode) + 1
resize = len(no_resize)
bls = input_size // block_num
if resize:
img, target_size = closing_resize(image, input_size, block_num, True)
bls_h = target_size[0]//block_num
bls_w = target_size[1]//block_num
tar_img = load_image_from_path(target_image_path, target_size[0],
target_size[1])
else:
img = load_image_from_path(image_path, input_size)
tar_img = load_image_from_path(target_image_path, input_size)
bls_h = bls_w = bls
target_size = [input_size, input_size]
processed_mask = load_image_from_path(processed_mask, target_size[0], target_size[1], True)
config['parameters']['input_size'] = bls
print(config['parameters'])
output_image = np.zeros([input_size, input_size, 3])
for i in tqdm(range(block_num)):
for j in tqdm(range(block_num)):
img_block = Image.fromarray(np.array(img)[bls * i: bls * i + bls, bls * j: bls * j + bls])
tar_block = Image.fromarray(np.array(tar_img)[bls * i: bls * i + bls, bls * j: bls * j + bls])
output_image[bls * i: bls * i + bls, bls * j: bls * j + bls] = infer(img_block, config, tar_block)
if processed_mask is not None:
input_mask = Image.fromarray(np.array(processed_mask)[bls_w*i: bls_w*i+bls_w, bls_h*j: bls_h*j + bls_h])
else:
input_mask = None
img_block = Image.fromarray(np.array(img)[bls_w*i: bls_w*i+bls_w, bls_h*j: bls_h*j + bls_h])
tar_block = Image.fromarray(np.array(tar_img)[bls_w*i: bls_w*i+bls_w, bls_h*j: bls_h*j + bls_h])

output_image[bls_w*i: bls_w*i+bls_w, bls_h*j: bls_h*j + bls_h] = infer(img_block, config, tar_block, input_mask)
output = Image.fromarray(output_image.astype(np.uint8))
return output

Expand All @@ -48,7 +78,7 @@ def process_image(image, eps, steps, input_size, rate, mode, block_mode):
gr.Image("MIST_logo.png", show_label=False)
with gr.Row():
with gr.Column():
image = gr.Image(type='pil')
image = gr.Image(type='pil', tool='sketch')
eps = gr.Slider(0, 32, step=4, value=16, label='Strength',
info="Larger strength results in stronger defense at the cost of more visible noise.")
steps = gr.Slider(0, 1000, step=1, value=100, label='Steps',
Expand All @@ -66,7 +96,11 @@ def process_image(image, eps, steps, input_size, rate, mode, block_mode):
block_mode = gr.CheckboxGroup(["Low VRAM usage mode"],
info="Use this mode if the VRAM of your device is not enough. Check the documentation for more information.",
label='VRAM mode')
inputs = [image, eps, steps, input_size, rate, mode, block_mode]
with gr.Accordion("Experimental option for non-square input", open=False):
no_resize = gr.CheckboxGroup(["No-resize mode"],
info="Use this mode if you donot want your image to be resized in square shape. This option is still experimental and may reduce the strength of MIST.",
label='No-resize mode')
inputs = [image, eps, steps, input_size, rate, mode, block_mode, no_resize]
image_button = gr.Button("Mist")
outputs = gr.Image(label='Misted image')
image_button.click(process_image, inputs=inputs, outputs=outputs)
Expand Down
157 changes: 157 additions & 0 deletions mist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import argparse
import PIL
from PIL import Image
import numpy as np


def parse_args():
parser = argparse.ArgumentParser(description="Configs for Mist V1.2")
parser.add_argument(
"-img",
"--input_image_path",
type=str,
default="test/sample.png",
help="path of input image",
)
parser.add_argument(
"--output_name",
type=str,
default="misted_sample",
help="path of saved image",
)
parser.add_argument(
"--output_dir",
type=str,
default="vangogh",
help="path of output dir",
)
parser.add_argument(
"-inp",
"--input_dir_path",
type=str,
default=None,
help="Path of the dir of images to be processed.",
)
parser.add_argument(
"-e",
"--epsilon",
type=int,
default=16,
help=(
"The strength of Mist"
),
)
parser.add_argument(
"-s",
"--steps",
type=int,
default=100,
help=(
"The step of Mist"
),
)
parser.add_argument(
"-in_size",
"--input_size",
type=int,
default=512,
help=(
"The input_size of Mist"
),
)
parser.add_argument(
"-b",
"--block_num",
type=int,
default=1,
help=(
"The number of partitioned blocks"
),
)
parser.add_argument(
"--mode",
type=int,
default=2,
help=(
"The mode of MIST."
),
)
parser.add_argument(
"--rate",
type=int,
default=1,
help=(
"The fused weight under the fused mode."
),
)
parser.add_argument(
"--mask",
default=False,
action="store_true",
help=(
"Whether to mask certain region of Mist or not. Work only when input_dir_path is None. "
),
)
parser.add_argument(
"--mask_path",
type=str,
default="test/processed_mask.png",
help="Path of the mask.",
)
parser.add_argument(
"--non_resize",
default=False,
action="store_true",
help=(
"Whether to keep the original shape of the image or not."
),
)
args = parser.parse_args()
return args


def load_mask(mask):
mask = np.array(mask)[:,:,0:3]
for p in range(mask.shape[0]):
for q in range(mask.shape[1]):
# if np.sum(mask[p][q]) != 0:
# mask[p][q] = 255
if mask[p][q][0] != 255:
mask[p][q] = 0
else:
mask[p][q] = 255
return mask


def closing_resize(image_path: str, input_size: int, block_num: int = 1, no_load: bool = False) -> PIL.Image.Image:
if no_load:
im = image_path
else:
im = Image.open(image_path)
target_size = list(im.size)

resize_parameter = min(target_size[0], target_size[1])/input_size
block_size = 8 * block_num
target_size[0] = int(target_size[0] / resize_parameter) // block_size * block_size
target_size[1] = int(target_size[1] / resize_parameter) // block_size * block_size
img = im.resize(target_size)
return img, target_size


def load_image_from_path(image_path: str, input_width: int, input_height: int = 0, no_load: bool = False) -> PIL.Image.Image:
"""
Load image form the path and reshape in the input size.
:param image_path: Path of the input image
:param input_size: The requested size in int.
:returns: An :py:class:`~PIL.Image.Image` object.
"""
if input_height == 0:
input_height = input_width
if no_load:
img = image_path.resize((input_width, input_height),
resample=PIL.Image.BICUBIC)
else:
img = Image.open(image_path).resize((input_width, input_height),
resample=PIL.Image.BICUBIC)
return img

Loading

0 comments on commit 56c375e

Please sign in to comment.