Skip to content

Commit

Permalink
controlnet_aux.hed has update! Change Network to ControlNetHED_Apache2
Browse files Browse the repository at this point in the history
Grateful for your work! As mentioned in microsoft#233, running from controlnet_aux.hed import Network now raises an error. This is because starting from controlnet_aux version 0.06, the controlnet_aux/hed file has undergone changes.

To address this, I referred to the code provided in huggingface/controlnet_aux#66 and made modifications to the original code accordingly. It's worth noting that the network-bsds500.pth model has also been replaced with ControlNetHED.pth, which can be downloaded from https://huggingface.co/lllyasviel/Annotators.
  • Loading branch information
princepride authored Apr 22, 2024
1 parent 8d38e28 commit 5af9bb1
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions hugginggpt/server/models_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
from controlnet_aux.open_pose.body import Body
from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
from controlnet_aux.hed import Network
from controlnet_aux.hed import ControlNetHED_Apache2
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
import warnings
import time
Expand Down Expand Up @@ -279,8 +279,10 @@ def mlsd_control_network():
model.load_state_dict(torch.load(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"), strict=True)
return MLSDdetector(model)


hed_network = Network(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth")
model_path = f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/ControlNetHED.pth"
hed_network = ControlNetHED_Apache2()
hed_network.load_state_dict(torch.load(model_path, map_location="cpu"))
hed_network.float().eval()

controlnet_sd_pipes = {
"openpose-control": {
Expand Down Expand Up @@ -632,4 +634,4 @@ def models(model_id):
if not os.path.exists("public/videos"):
os.makedirs("public/videos")

waitress.serve(app, host="0.0.0.0", port=port)
waitress.serve(app, host="0.0.0.0", port=port)

0 comments on commit 5af9bb1

Please sign in to comment.