Skip to content

Commit

Permalink
Update eval tests due to evaluation node construction logic changed.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Feb 17, 2025
1 parent 5dfe736 commit a317c95
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
14 changes: 12 additions & 2 deletions tests/unit/test_eval_btas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@
namespace {

auto eval_node(sequant::ExprPtr const& expr) {
return sequant::eval_node<sequant::EvalExprBTAS>(expr);
using namespace sequant;
auto node = binarize(expr);
return transform_node(node, [](auto&& val) {
if (val.is_tensor()) {
return EvalExprBTAS(
val.op_type(), val.result_type(), val.expr(),
val.as_tensor().indices() | ranges::to<EvalExpr::index_vector>(),
val.canon_phase(), val.hash_value());
} else
return EvalExprBTAS(val);
});
}

static auto const idx_rgx = boost::wregex{L"([ia])([↑↓])?_?(\\d+)"};
Expand Down Expand Up @@ -168,7 +178,7 @@ container::svector<long> tidxs(std::wstring const& csv) noexcept {

} // namespace

TEST_CASE("TEST_EVAL_USING_BTAS", "[eval]") {
TEST_CASE("TEST_EVAL_USING_BTAS", "[eval_btas]") {
using ranges::views::transform;
using namespace sequant;
using namespace sequant;
Expand Down
57 changes: 55 additions & 2 deletions tests/unit/test_eval_ta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,61 @@
#include <vector>

namespace {

///
/// \brief Represents the outer indices and the inner indices of a nested
/// tensor.
///
/// \note The nested tensor is a concept that generalizes the sequant::Tensor
/// with and without proto indices. sequant::Tensors with proto indices have
/// outer and inner indices, whereas, those without proto indices only have
/// outer indices.
///
struct NestedTensorIndices {
sequant::container::svector<sequant::Index> outer, inner;

explicit NestedTensorIndices(sequant::Tensor const& tnsr) {

using ranges::views::join;
using ranges::views::transform;
using namespace sequant;

for (auto&& ix : tnsr.aux()) {
assert(!ix.has_proto_indices() &&
"Aux indices with proto indices not supported");
outer.emplace_back(ix);
}

auto append_unique = [](auto& cont, auto const& el) {
if (!ranges::contains(cont, el)) cont.emplace_back(el);
};

for (Index const& ix : tnsr.const_braket())
append_unique(ix.has_proto_indices() ? inner : outer, ix);

for (Index const& ix :
tnsr.const_braket() | transform(&Index::proto_indices) | join)
append_unique(outer, ix);
}

[[nodiscard]] auto outer_inner() const noexcept {
return ranges::views::concat(outer, inner);
}
};


auto eval_node(sequant::ExprPtr const& expr) {
return sequant::eval_node<sequant::EvalExprTA>(expr);
using namespace sequant;
auto node = binarize(expr);
return transform_node(node, [](auto&& val) {
if (val.is_tensor()) {
return EvalExprTA(
val.op_type(), val.result_type(), val.expr(),
NestedTensorIndices(val.as_tensor()).outer_inner() | ranges::to<EvalExpr::index_vector>(),
val.canon_phase(), val.hash_value());
} else
return EvalExprTA(val);
});
}

auto tensor_to_key(sequant::Tensor const& tnsr) {
Expand All @@ -27,7 +80,7 @@ auto tensor_to_key(sequant::Tensor const& tnsr) {
return (mo[1].str() == L"i" ? L"o" : L"v") + mo[2].str();
};

sequant::NestedTensorIndices oixs{tnsr};
NestedTensorIndices oixs{tnsr};
if (oixs.inner.empty()) {
auto const tnsr_deparsed = sequant::deparse(tnsr.clone(), false);
return boost::regex_replace(tnsr_deparsed, idx_rgx, formatter);
Expand Down

0 comments on commit a317c95

Please sign in to comment.