From 5af9bb1b0dd2608655ef927c8749d107bd0d6a08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Mon, 22 Apr 2024 11:03:23 +0800 Subject: [PATCH] controlnet_aux.hed has update! Change Network to ControlNetHED_Apache2 Grateful for your work! As mentioned in https://github.com/microsoft/JARVIS/issues/233, running from controlnet_aux.hed import Network now raises an error. This is because starting from controlnet_aux version 0.06, the controlnet_aux/hed file has undergone changes. To address this, I referred to the code provided in https://github.com/huggingface/controlnet_aux/issues/66 and made modifications to the original code accordingly. It's worth noting that the network-bsds500.pth model has also been replaced with ControlNetHED.pth, which can be downloaded from https://huggingface.co/lllyasviel/Annotators. --- hugginggpt/server/models_server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/hugginggpt/server/models_server.py b/hugginggpt/server/models_server.py index 2d7c2a3..7b4e44b 100644 --- a/hugginggpt/server/models_server.py +++ b/hugginggpt/server/models_server.py @@ -29,7 +29,7 @@ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector from controlnet_aux.open_pose.body import Body from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large -from controlnet_aux.hed import Network +from controlnet_aux.hed import ControlNetHED_Apache2 from transformers import DPTForDepthEstimation, DPTFeatureExtractor import warnings import time @@ -279,8 +279,10 @@ def mlsd_control_network(): model.load_state_dict(torch.load(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"), strict=True) return MLSDdetector(model) - - hed_network = Network(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth") + model_path = f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/ControlNetHED.pth" + hed_network = ControlNetHED_Apache2() + hed_network.load_state_dict(torch.load(model_path, map_location="cpu")) + hed_network.float().eval() controlnet_sd_pipes = { "openpose-control": { @@ -632,4 +634,4 @@ def models(model_id): if not os.path.exists("public/videos"): os.makedirs("public/videos") - waitress.serve(app, host="0.0.0.0", port=port) \ No newline at end of file + waitress.serve(app, host="0.0.0.0", port=port)