diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 08290e028..7475a383b 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -63,7 +63,6 @@ namespace clad { std::vector m_LoopBlock; unsigned outputArrayCursor = 0; unsigned numParams = 0; - bool enableTBR = false; // FIXME: Should we make this an object instead of a pointer? // Downside of making it an object: We will need to include // 'MultiplexExternalRMVSource.h' file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 54f0f80cf..6b946d94a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -292,10 +292,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString(); } - // Check if DiffRequest asks for TBR analysis to be enabled - if (request.EnableTBRAnalysis) - enableTBR = true; - auto derivativeBaseName = request.BaseFunctionName; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if @@ -475,10 +471,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivativeAndOverload ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, const DiffRequest& request) { - if (request.EnableTBRAnalysis) - enableTBR = true; - TBRAnalyzer analyzer(m_Context); - if (enableTBR) { + if (request.EnableTBRAnalysis) { + TBRAnalyzer analyzer(m_Context); analyzer.Analyze(FD); m_ToBeRecorded = analyzer.getResult(); } @@ -602,8 +596,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DifferentiateWithClad() { - TBRAnalyzer analyzer(m_Context); - if (enableTBR) { + if (m_DiffReq.EnableTBRAnalysis) { + TBRAnalyzer analyzer(m_Context); analyzer.Analyze(m_DiffReq.Function); m_ToBeRecorded = analyzer.getResult(); } @@ -1695,7 +1689,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.Mode = DiffMode::experimental_pullback; // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; - pullbackRequest.EnableTBRAnalysis = enableTBR; + pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (DerivedCallOutputArgs[i + isaMethod]) @@ -2943,7 +2937,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isa(B) || isa(B) || isa(B)) { // If TBR analysis is off, assume E is useful to store. - if (!enableTBR) + if (!m_DiffReq.EnableTBRAnalysis) return true; // FIXME: currently, we allow all pointer operations to be stored. // This is not correct, but we need to implement a more advanced analysis