Skip to content

Commit

Permalink
Use the default algorithm as a fallback when the requested cuBLASLt algorithm ID is not found
Browse files Browse the repository at this point in the history
  • Loading branch information
fsx950223 authored and i-chaochen committed Jan 29, 2025
1 parent 89dbb32 commit 38ed029
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
11 changes: 7 additions & 4 deletions tensorflow/compiler/xla/service/gpu/autotuner_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ void CSVLegend(std::ostream& os) {

os << kCsvComment << " m" << kCsvSep << "n" << kCsvSep << "k" << kCsvSep
<< "batch_count" << kCsvSep << "trans_a" << kCsvSep
<< "trans_b" << kCsvSep
<< "type_a" << kCsvSep << "type_b" << kCsvSep
<< "trans_b" << kCsvSep << "type_a" << kCsvSep << "type_b" << kCsvSep
<< "type_c" << kCsvSep << "lda" << kCsvSep << "ldb" << kCsvSep
<< "ldc" << kCsvSep << "stride_a" << kCsvSep
<< "stride_b" << kCsvSep << "stride_c" << kCsvSep
<< "alg_index" << std::endl;
<< "stride_b" << kCsvSep << "stride_c";
if (full_string) {
os << kCsvSep << "alpha_re" << kCsvSep << "alpha_im" << kCsvSep
<< "beta" << kCsvSep << "epilogue";
}
os << kCsvSep << "alg_index" << std::endl;
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ auto CublasLtMatmulThunk::GetCachedMatmulPlan(
return std::move(plan);
}
}
return InternalError("Wrong algorithm ID: %d", algorithm_id);
TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0]));
LOG(WARNING) << "Wrong algorithm ID: " << algorithm_id << " use default instead.";
return std::move(plan);
};
return cache.GetOrCreate(canonical_hlo_, create);
}
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,7 @@ std::string ToCSVString(const GemmConfig& cfg, bool full_string) {

if (full_string) {
// NOTE: epilogue is required for MatmulPlan caching !
oss //<< kCsvSep << cfg.alpha << kCsvSep << cfg.beta
<< kCsvSep << (int64_t)cfg.epilogue;
oss << kCsvSep << cfg.alpha.real() << kCsvSep << cfg.alpha.imag() << kCsvSep << cfg.beta << kCsvSep << (int64_t)cfg.epilogue;
}

return oss.str();
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/input.hlo
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
HloModule Test

// Scalar f32 add reduction applied by the reduce instructions below.
// NOTE(review): the ENTRY computation references to_apply=region_0.14357 but
// the original file never defined it, so the module failed to parse as a
// standalone reproducer. The metadata (op_type="Sum" / "BiasAddGrad") indicates
// these reduces were lowered from TF summations, whose canonical reducer is a
// scalar add, defined here.
region_0.14357 {
  Arg_0.1 = f32[] parameter(0)
  Arg_1.2 = f32[] parameter(1)
  ROOT add.1 = f32[] add(Arg_0.1, Arg_1.2)
}

// Gradient fusion: combines sigmoid-grad / mul-grad products, pads the
// 11000-row operands to 11008 (= 64 * 172) so they can be bitcast to a
// [64,172,256] view, then row-block-reduces to two [64,256] outputs.
ENTRY %fused_computation.1408 {
  param_2.3691 = f32[11000,256]{1,0} parameter(2)
  param_3.2585 = f32[11000,256]{1,0} parameter(3)
  add.3118 = f32[11000,256]{1,0} add(param_2.3691, param_3.2585), metadata={op_type="AddN" op_name="gradients/AddN_741" source_file="dummy_file_name" source_line=10}
  param_0.4676 = f32[11000,256]{1,0} parameter(0)
  multiply.14683 = f32[11000,256]{1,0} multiply(add.3118, param_0.4676), metadata={op_type="Mul" op_name="gradients/home_network/dense_home_ctr_combine_1/mul_1_grad/Mul" source_file="dummy_file_name" source_line=10}
  param_1.5886 = f32[11000,256]{1,0} parameter(1)
  multiply.14681 = f32[11000,256]{1,0} multiply(add.3118, param_1.5886), metadata={op_type="Mul" op_name="gradients/home_network/dense_home_ctr_combine_1/mul_1_grad/Mul_1" source_file="dummy_file_name" source_line=10}
  multiply.14680 = f32[11000,256]{1,0} multiply(multiply.14681, param_0.4676), metadata={op_type="SigmoidGrad" op_name="gradients/home_network/dense_home_ctr_combine_1/Sigmoid_grad/SigmoidGrad" source_file="dummy_file_name" source_line=10}
  constant_2074 = f32[] constant(1), metadata={op_type="SigmoidGrad" op_name="gradients/home_network/gates_network_43/home_sl_interact_dtr_idp_native_gates/Sigmoid_grad/SigmoidGrad" source_file="dummy_file_name" source_line=10}
  broadcast.6645 = f32[11000,256]{1,0} broadcast(constant_2074), dimensions={}, metadata={op_type="SigmoidGrad" op_name="gradients/home_network/dense_home_ctr_combine_1/Sigmoid_grad/SigmoidGrad" source_file="dummy_file_name" source_line=10}
  subtract.1534 = f32[11000,256]{1,0} subtract(broadcast.6645, param_0.4676), metadata={op_type="SigmoidGrad" op_name="gradients/home_network/dense_home_ctr_combine_1/Sigmoid_grad/SigmoidGrad" source_file="dummy_file_name" source_line=10}
  multiply.14679 = f32[11000,256]{1,0} multiply(multiply.14680, subtract.1534), metadata={op_type="SigmoidGrad" op_name="gradients/home_network/dense_home_ctr_combine_1/Sigmoid_grad/SigmoidGrad" source_file="dummy_file_name" source_line=10}
  add.3117 = f32[11000,256]{1,0} add(multiply.14683, multiply.14679), metadata={op_type="AddN" op_name="gradients/AddN_748" source_file="dummy_file_name" source_line=10}
  constant_1446 = f32[] constant(0), metadata={op_type="BiasAddGrad" op_name="gradients/home_network/gates_network_43/home_sl_interact_dtr_idp_native_gates/BiasAdd_grad/BiasAddGrad" source_file="dummy_file_name" source_line=10}
  // Pad rows 11000 -> 11008 so the bitcast to [64,172,256] is shape-compatible.
  pad.1112 = f32[11008,256]{1,0} pad(add.3117, constant_1446), padding=0_8x0_0, metadata={op_type="AddN" op_name="gradients/AddN_748" source_file="dummy_file_name" source_line=10}
  bitcast.53539 = f32[64,172,256]{2,1,0} bitcast(pad.1112), metadata={op_type="AddN" op_name="gradients/AddN_748" source_file="dummy_file_name" source_line=10}
  reduce.1970 = f32[64,256]{1,0} reduce(bitcast.53539, constant_1446), dimensions={1}, to_apply=region_0.14357, metadata={op_type="Sum" op_name="gradients/home_network/dense_home_ctr_combine_1/batch_normalization/batchnorm/add_1_grad/Sum" source_file="dummy_file_name" source_line=10}
  param_4.6582 = f32[11000,256]{1,0} parameter(4)
  multiply.12670.clone.1 = f32[11000,256]{1,0} multiply(add.3117, param_4.6582), metadata={op_type="Mul" op_name="gradients/home_network/dense_home_ctr_combine_1/batch_normalization/batchnorm/mul_1_grad/Mul_1" source_file="dummy_file_name" source_line=10}
  pad.1107.clone.1 = f32[11008,256]{1,0} pad(multiply.12670.clone.1, constant_1446), padding=0_8x0_0, metadata={op_type="Mul" op_name="gradients/home_network/dense_home_ctr_combine_1/batch_normalization/batchnorm/mul_1_grad/Mul_1" source_file="dummy_file_name" source_line=10}
  bitcast.53531.clone.1 = f32[64,172,256]{2,1,0} bitcast(pad.1107.clone.1), metadata={op_type="Mul" op_name="gradients/home_network/dense_home_ctr_combine_1/batch_normalization/batchnorm/mul_1_grad/Mul_1" source_file="dummy_file_name" source_line=10}
  reduce.1965.clone.1 = f32[64,256]{1,0} reduce(bitcast.53531.clone.1, constant_1446), dimensions={1}, to_apply=region_0.14357, metadata={op_type="Sum" op_name="gradients/home_network/dense_home_ctr_combine_1/batch_normalization/batchnorm/mul_1_grad/Sum" source_file="dummy_file_name" source_line=10}
  ROOT tuple.283 = (f32[64,256]{1,0}, f32[64,256]{1,0}) tuple(reduce.1970, reduce.1965.clone.1)
}

0 comments on commit 38ed029

Please sign in to comment.