Unexpected performance regression with outer reduction #3815

Open
naoyam opened this issue Feb 3, 2025 · 1 comment

@naoyam
Collaborator

naoyam commented Feb 3, 2025

This fusion is one of the segments of the Qwen2 RoPE backward benchmark.

```cpp
TEST_F(OuterReductionTest, Qwen2RopePerfRegression) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  std::vector<int64_t> shape1{1, 28, 4096, 128};
  std::vector<int64_t> shape2{1, 4, 4096, 128};

  auto T0 = makeContigConcreteTensor(shape1, DataType::BFloat16);
  fusion.addInput(T0);
  // Second input: dim 1 (extent 4) is an expanded broadcast, so only one
  // 4096x128 slice is physically stored (see the stride-0 t1 below).
  auto T28 = TensorViewBuilder()
                 .shape(shape2)
                 .dtype(DataType::BFloat16)
                 .contiguity(std::vector<std::optional<bool>>{
                     std::nullopt, std::nullopt, true, true})
                 .expanded({false, true, false, false})
                 .build();
  fusion.addInput(T28);

  auto T34 = castOp(DataType::Float, T28);
  // View the 28 heads as 4 groups of 7 and sum over each group. The reduced
  // axis (extent 7) is not innermost, which makes this an outer reduction.
  auto T7 = reshape(T0, shape1, {1, 4, 7, 4096, 128});
  auto T8 = castOp(DataType::Float, T7);
  auto T9 = squeeze(T8, {0});
  auto T10 = sum(T9, {1});
  auto T11 = castOp(DataType::BFloat16, T10);
  auto T163 = broadcast(T11, {true, false, false, false});
  auto T33 = castOp(DataType::Float, T163);
  auto T38 = mul(T34, T33);
  auto T42 = castOp(DataType::BFloat16, T38);
  // Output either the float T33 (SAVE_FP32 set) or the bfloat16 T163.
  if (getenv("SAVE_FP32")) {
    fusion.addOutput(T33);
  } else {
    fusion.addOutput(T163);
  }
  fusion.addOutput(T42);

  fusion.printMath();

  auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
  auto t0 = at::randn(shape1, options);
  auto t1 = at::randn({1, 1, 4096, 128}, options)
                .as_strided(shape2, {4096L * 128L, 0L, 128L, 1L});
  std::vector<c10::IValue> inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}
```

This is an outer reduction fusion. I'm hitting an unexpected performance regression in this segment with the segmentation optimization of #3776. Specifically, the PR allows saving T163 instead of T33. The former is bfloat16 while the latter is float, so I naturally expected this to be a real optimization that improves performance. One of the outputs is half the size, and, because all inputs and outputs are now bfloat16, the vectorization factor is 8 instead of 4.

To my surprise, that isn't the case, at least on H100. With SAVE_FP32, the kernel takes about 15 us, whereas without the flag it takes 22 us (measured with ncu). Both times are small and the impact on overall performance is not significant (the overall time is around 140 us), but it's still sad to see a 45% slowdown from this seemingly obvious optimization.

These are the heuristics with the bfloat16 output:

```
===== Outer Reduction Stats ========
total_reduction_numel: 7
total_iteration_numel: 2097152
vectorize_factor: 8
redu_unroll_factor: 4
grid(256, 1, 1)
block(1024, 1, 1)


===== Reduction Parameters ========

Red On Slow Dim

Iteration Domain: blockIdx.x / threadIdx.x / multiple reductions per block / vectorize / factor 8
Inner Reduction Domain: unroll / factor 4
Launch Parameters: BlockDim.x = 1024, BlockDim.y = 1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = NotSet, maxrregcount = 64, enable_magic_zero = 1, enable_ptxas_verbose = 0

====================================
```

With the float output:

```
===== Outer Reduction Stats ========
total_reduction_numel: 7
total_iteration_numel: 2097152
vectorize_factor: 4
redu_unroll_factor: 7
grid(512, 1, 1)
block(1024, 1, 1)


===== Reduction Parameters ========

Red On Slow Dim

Iteration Domain: blockIdx.x / threadIdx.x / multiple reductions per block / vectorize / factor 4
Inner Reduction Domain: unroll / factor 7
Launch Parameters: BlockDim.x = 1024, BlockDim.y = 1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = NotSet, maxrregcount = 64, enable_magic_zero = 1, enable_ptxas_verbose = 0

====================================
```

It appears that the regression comes from the difference in `redu_unroll_factor`. In the bfloat16 case, because it's 4, we have `Split: iS82{7} by factor 4 -> iS179{2}, iS180{4}`, as seen below:

```
T0_g___bfloat[iS177{( ceilDiv(262144, blockDim.x) )}, iS176{blockDim.x}, iS178{1}, iS174{8}, iS181{2}, iS182{1}, iS180{4}, bS0{1}]
 logical domain : (bS0{1}, iS1{28}, iS2{4096}, iS3{128})
 contiguity: n t t t
  Outer split: iS1{28} by factor 4 -> iS81{4}, iS82{7}
  Merge: iS2{4096} and iS3{128} -> iS171{524288}
  Merge: iS81{4} and iS171{524288} -> iS172{2097152}
  Split: iS172{2097152} by factor 8 -> iS173{262144}, iS174{8}
  Split: iS173{262144} by factor blockDim.x -> iS175{( ceilDiv(262144, blockDim.x) )}, iS176{blockDim.x}
  Split: iS175{( ceilDiv(262144, blockDim.x) )} by factor 1 -> iS177{( ceilDiv(262144, blockDim.x) )}, iS178{1}
  Split: iS82{7} by factor 4 -> iS179{2}, iS180{4}
  Split: iS179{2} by factor 1 -> iS181{2}, iS182{1}
 loop domain : (iS177{( ceilDiv(262144, blockDim.x) )}, iS176{blockDim.x}, iS178{1}, iS174{8}, iS181{2}, iS182{1}, iS180{4}, bS0{1})
```

In the case of float, the scheduler chooses to use 7 for redu_unroll_factor, resulting in:

```
T0_g___bfloat[iS177{( ceilDiv(524288, blockDim.x) )}, iS176{blockDim.x}, iS178{1}, iS174{4}, iS181{1}, iS182{1}, iS180{7}, bS0{1}]
 logical domain : (bS0{1}, iS1{28}, iS2{4096}, iS3{128})
 contiguity: n t t t
  Outer split: iS1{28} by factor 4 -> iS81{4}, iS82{7}
  Merge: iS2{4096} and iS3{128} -> iS171{524288}
  Merge: iS81{4} and iS171{524288} -> iS172{2097152}
  Split: iS172{2097152} by factor 4 -> iS173{524288}, iS174{4}
  Split: iS173{524288} by factor blockDim.x -> iS175{( ceilDiv(524288, blockDim.x) )}, iS176{blockDim.x}
  Split: iS175{( ceilDiv(524288, blockDim.x) )} by factor 1 -> iS177{( ceilDiv(524288, blockDim.x) )}, iS178{1}
  Split: iS82{7} by factor 7 -> iS179{1}, iS180{7}
  Split: iS179{1} by factor 1 -> iS181{1}, iS182{1}
 loop domain : (iS177{( ceilDiv(524288, blockDim.x) )}, iS176{blockDim.x}, iS178{1}, iS174{4}, iS181{1}, iS182{1}, iS180{7}, bS0{1})
```

Notice that iS82 is simply split by 7, which avoids the non-divisible split of 7 by 4.

I manually forced the same split by 7 in the bfloat16 case, and it recovered the lost performance.

@liqiangxl
Collaborator

Thanks for the details.
The difference in unroll factors between the two cases comes from a constraint on unroll factor × vect factor, which keeps the kernel from using too many registers to hold data loaded from global memory. In this case, the constraint is unroll factor × vect factor ≤ 32.

For instance, when vect factor = 4, the unroll factor must satisfy unroll factor ≤ 32 / 4 = 8, which, with a reduction extent of only 7, leads to a selection of 7. Similarly, when vect factor = 8, the unroll factor is limited to unroll factor ≤ 32 / 8 = 4.

For thread-local reductions, I am considering disabling unrolling when the reduction dimension is not divisible by the unroll factor. I will run some benchmarks.
