Skip to content

Commit

Permalink
Fix comfy Ci error
Browse files Browse the repository at this point in the history
  • Loading branch information
ccssu committed Jun 11, 2024
1 parent 7487bb2 commit ea724a1
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions onediff_comfy_nodes/modules/oneflow/hijack_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options):
out_conds = []
out_counts = []
to_run = []

for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
Expand All @@ -40,6 +40,7 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options):
to_batch_temp.reverse()
# to_batch = to_batch_temp[:1]
to_batch = to_batch_temp

# free_memory = model_management.get_free_memory(x_in.device)
# for i in range(1, len(to_batch_temp) + 1):
# batch_amount = to_batch_temp[:len(to_batch_temp)//i]
Expand Down Expand Up @@ -93,27 +94,37 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options):
transformer_options["cond_or_uncond"] = cond_or_uncond[:]

diff_model = model.diffusion_model

if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch(diff_model):
transformer_options["sigmas"] = timestep[0].item()
patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions)
transformer_options["_attn2"] = patch_executor.get_patch(diff_model)["attn2"]
else:
transformer_options["sigmas"] = timestep
# transformer_options["sigmas"] = timestep


c['transformer_options'] = transformer_options

if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)


for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]

a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]

for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]

Expand Down

0 comments on commit ea724a1

Please sign in to comment.