Skip to content

Commit

Permalink
Merge pull request #186 from BCC168/main
Browse files Browse the repository at this point in the history
feat:wd14-tagger support
  • Loading branch information
mix1009 authored Nov 29, 2024
2 parents 84f5b85 + 415fdd4 commit 6747783
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
6 changes: 4 additions & 2 deletions webuiapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
SegmentAnythingGinoResult,
SegmentAnythingControlNetSegRandomResult,
SegmentAnythingControlNetSegNotRandomResult,
SegmentAnythingSemanticSegWithCatIdResult
SegmentAnythingSemanticSegWithCatIdResult,
TaggerInterface
)

__version__ = "0.9.16"
Expand Down Expand Up @@ -52,5 +53,6 @@
"SegmentAnythingGinoResult",
"SegmentAnythingControlNetSegRandomResult",
"SegmentAnythingControlNetSegNotRandomResult",
"SegmentAnythingSemanticSegWithCatIdResult"
"SegmentAnythingSemanticSegWithCatIdResult",
"TaggerInterface"
]
30 changes: 29 additions & 1 deletion webuiapi/webuiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def interrogate(self, image, model="clip"):

response = self.session.post(url=f"{self.baseurl}/interrogate", json=payload)
return self._to_api_result(response)

def list_prompt_gen_models(self):
r = self.custom_get("promptgen/list_models")
return r['available_models']
Expand Down Expand Up @@ -2002,3 +2002,31 @@ def sam_and_semantic_seg_with_cat_id(
masked_image=Image.open(io.BytesIO(base64.b64decode(r["masked_image"]))),
resized_input=Image.open(io.BytesIO(base64.b64decode(r["resized_input"])))
)

# https://github.com/Akegarasu/sd-webui-wd14-tagger

class TaggerInterface:
def __init__(self, webuiapi: WebUIApi):
self.api = webuiapi

def tagger_interrogate(self, image, model="wd14-vit-v2-git", threshold=0.0, use_async=False):
"""
Interrogates the tagger model with the provided image and parameters.
Args:
image (Image.Image or str): The image to be interrogated. Can be a PIL Image object or a base64 encoded string.
model (str, optional): The model to use for interrogation. Defaults to "wd14-vit-v2-git".
threshold (float, optional): The threshold value for the model. Defaults to 0.
use_async (bool, optional): Whether to use asynchronous processing. Defaults to False.
Returns:
WebUIApiResult.info: The information returned by the web API.
"""
payload = {
"image": b64_img(image) if isinstance(image, Image.Image) else image,
"model": model,
"threshold": threshold
}
return self.api.custom_post("tagger/v1/interrogate", payload=payload, use_async=use_async)
def tagger_interrogators(self):
return self.api.custom_get("tagger/v1/interrogators")

0 comments on commit 6747783

Please sign in to comment.