Skip to content

Commit

Permalink
[unit] Add more corner cases and extend BTAS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajay-mk committed Feb 13, 2025
1 parent 6e9f5bf commit e670c48
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/unit/test_eval_btas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,24 @@ TEST_CASE("TEST_EVAL_USING_BTAS", "[eval]") {
btas::scal(0.5, man1);

REQUIRE(norm(eval1) == Catch::Approx(norm(man1)));

auto expr2 = parse_antisymm(L"R_{a1,a2}^{i1}");
auto tidx2 = tidxs(L"a_1,a_2,i_1");
auto eval2 = eval_antisymm(expr2, tidx2);

auto const& r = yield(L"R{v,v;o}");
BTensorD man2{r.range()}, temp2{r.range()};
man2.fill(0.0);
temp2.fill(0.0);

man2 += BTensorD{permute(r, {0, 1, 2})};

temp2 += BTensorD{permute(r, {1, 0, 2})};
btas::scal(-1.0, temp2);
man2 += temp2;
temp2.clear();

REQUIRE(norm(eval2) == Catch::Approx(norm(man2)));
}

SECTION("Symmetrization") {
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_eval_ta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,17 @@ TEST_CASE("TEST_EVAL_USING_TA", "[eval]") {
zero2("0,1,2,3,4") = man2("0,1,2,3,4") - eval2("0,1,2,3,4");
REQUIRE(norm(zero2) == Catch::Approx(0).margin(
100 * std::numeric_limits<double>::epsilon()));

auto expr3 = parse_antisymm(L"R_{a1,a2}^{}");
auto eval3 = eval_antisymm(expr3, "a_1,a_2");
auto const& arr3 = yield(L"R{a1,a2;}");
auto man3 = TArrayD{};
man3("0,1") = arr3("0,1") - arr3("1,0");

TArrayD zero3;
zero3("0,1") = man3("0,1") - eval3("0,1");
REQUIRE(norm(zero3) == Catch::Approx(0).margin(
100 * std::numeric_limits<double>::epsilon()));
}

SECTION("Symmetrization") {
Expand Down Expand Up @@ -615,6 +626,17 @@ TEST_CASE("TEST_EVAL_USING_TA_COMPLEX", "[eval]") {
zero2("0,1,2,3,4") = man2("0,1,2,3,4") - eval2("0,1,2,3,4");
REQUIRE(norm(zero2) == Catch::Approx(0).margin(
100 * std::numeric_limits<double>::epsilon()));

auto expr3 = parse_expr(L"R_{a1,a2}^{}");
auto eval3 = eval_antisymm(expr3, "a_1,a_2");
auto const& arr3 = yield(L"R{a1,a2;}");
auto man3 = TArrayC{};
man3("0,1") = arr3("0,1") - arr3("1,0");

TArrayC zero3;
zero3("0,1") = man3("0,1") - eval3("0,1");
REQUIRE(norm(zero3) == Catch::Approx(0).margin(
100 * std::numeric_limits<double>::epsilon()));
}

SECTION("Symmetrization") {
Expand Down

0 comments on commit e670c48

Please sign in to comment.