Skip to content

Commit

Permalink
generalize IndexSwapper by SwapCounter<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Feb 15, 2025
1 parent 3ada27d commit 3b1aa88
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 39 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ set(SeQuant_src
SeQuant/core/utility/string.hpp
SeQuant/core/utility/string.cpp
SeQuant/core/utility/tuple.hpp
SeQuant/core/utility/swap.hpp
SeQuant/core/wick.hpp
SeQuant/core/wick.impl.hpp
SeQuant/core/wolfram.hpp
Expand Down Expand Up @@ -415,7 +416,7 @@ if (SEQUANT_IWYU)
endif()

if (SEQUANT_BENCHMARKS)
add_subdirectory(benchmarks)
add_subdirectory(benchmarks)
endif()

### unit tests
Expand Down
28 changes: 4 additions & 24 deletions SeQuant/core/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@
#define SEQUANT_INDEX_H

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/hash.hpp>
#include <iostream>
// #include <SeQuant/core/space.hpp>
#include <SeQuant/core/context.hpp>
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/tag.hpp>
#include <SeQuant/core/utility/string.hpp>
// Only needed due to a (likely) compiler bug in Apple Clang
// #include <SeQuant/core/attr.hpp>
#include <SeQuant/core/utility/swap.hpp>

#include <algorithm>
#include <atomic>
Expand All @@ -23,6 +20,7 @@
#include <cwchar>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <map>
#include <mutex>
Expand Down Expand Up @@ -868,28 +866,10 @@ void Index::canonicalize_proto_indices() {
std::stable_sort(begin(proto_indices_), end(proto_indices_));
}

class IndexSwapper {
public:
IndexSwapper() : even_num_of_swaps_(true) {}
static IndexSwapper &thread_instance() {
static thread_local IndexSwapper instance_{};
return instance_;
}

bool even_num_of_swaps() const { return even_num_of_swaps_; }
void reset() { even_num_of_swaps_ = true; }

private:
std::atomic<bool> even_num_of_swaps_;
void toggle() { even_num_of_swaps_ = !even_num_of_swaps_; }

friend inline void swap(Index &, Index &);
};

/// swap operator helps tracking # of swaps
inline void swap(Index &first, Index &second) {
std::swap(first, second);
IndexSwapper::thread_instance().toggle();
detail::count_swap<Index>();
}

/// Generates temporary indices
Expand Down
5 changes: 2 additions & 3 deletions SeQuant/core/tensor_canonicalizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,13 @@ class DefaultTensorCanonicalizer : public TensorCanonicalizer {
auto _bra = bra_range(t);
auto _ket = ket_range(t);
// std::wcout << "canonicalizing " << to_latex(t);
IndexSwapper::thread_instance().reset();
reset_ts_swap_counter<Index>();
// std::{stable_}sort does not necessarily use swap! so must implement
// sort outselves .. thankfully ranks will be low so can stick with
// bubble
bubble_sort(begin(_bra), end(_bra), idxcmp);
bubble_sort(begin(_ket), end(_ket), idxcmp);
if (is_antisymm)
even = IndexSwapper::thread_instance().even_num_of_swaps();
if (is_antisymm) even = ts_swap_counter_is_even<Index>();
// std::wcout << " is " << (even ? "even" : "odd") << " and
// produces " << to_latex(t) << std::endl;
} break;
Expand Down
125 changes: 125 additions & 0 deletions SeQuant/core/utility/swap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//
// Created by Eduard Valeyev on 2/14/25.
//

#ifndef SEQUANT_CORE_UTILITY_SWAPPABLE_HPP
#define SEQUANT_CORE_UTILITY_SWAPPABLE_HPP

#include <atomic>

namespace sequant {

/// use this as a swap-countable deep wrapper for types whose swaps are not
/// counted
template <typename T>
class SwapCountable;

/// use this as a swap-countable shallow wrapper for types whose swaps are not
/// counted
template <typename T>
class SwapCountableRef;

/// counted swap for types whose default swap is not counted
template <typename T>
inline void counted_swap(T& a, T& b);

namespace detail {

/// use this for implementing custom swap that counts swaps by default
template <typename T>
inline void count_swap();

/// atomic counter, used by swap overloads for SwapCountable and
/// SwapCountableRef
template <typename T>
struct SwapCounter {
SwapCounter() : even_num_of_swaps_(true) {}
static SwapCounter& thread_instance() {
static thread_local SwapCounter instance_{};
return instance_;
}

bool even_num_of_swaps() const { return even_num_of_swaps_; }
void reset() { even_num_of_swaps_ = true; }

private:
std::atomic<bool> even_num_of_swaps_;
void toggle() { even_num_of_swaps_ = !even_num_of_swaps_; }

friend class SwapCountable<T>;
friend class SwapCountableRef<T>;
friend inline void counted_swap<T>(T& a, T& b);
friend inline void count_swap<T>();
};

template <typename T>
inline void count_swap() {
detail::SwapCounter<T>::thread_instance().toggle();
}

} // namespace detail

/// Wraps `T` to make its swap countable
template <typename T>
class SwapCountable {
public:
template <typename U>
explicit SwapCountable(U&& v) : value_(std::forward<U>(v)) {}

private:
T value_;

friend inline void swap(SwapCountable& a, SwapCountable& b) {
using std::swap;
swap(a.value_, b.value_);
detail::SwapCounter<T>::thread_instance().toggle();
}

friend inline bool operator<(const SwapCountable& a, const SwapCountable& b) {
return a.value_ < b.value_;
}
};

/// Wraps `T&` to make its swap countable
template <typename T>
class SwapCountableRef {
public:
explicit SwapCountableRef(T& ref) : ref_(ref) {}

private:
T& ref_;

// NB swapping const wrappers swaps the payload
friend inline void swap(const SwapCountableRef& a,
const SwapCountableRef& b) {
using std::swap;
swap(a.ref_, b.ref_);
detail::SwapCounter<T>::thread_instance().toggle();
}

friend inline bool operator<(const SwapCountableRef& a,
const SwapCountableRef& b) {
return a.ref_ < b.ref_;
}
};

template <typename T>
void reset_ts_swap_counter() {
detail::SwapCounter<T>::thread_instance().reset();
}

template <typename T>
bool ts_swap_counter_is_even() {
return detail::SwapCounter<T>::thread_instance().even_num_of_swaps();
}

template <typename T>
inline void counted_swap(T& a, T& b) {
using std::swap;
swap(a, b);
detail::SwapCounter<T>::thread_instance().toggle();
}

} // namespace sequant

#endif // SEQUANT_CORE_UTILITY_SWAPPABLE_HPP
21 changes: 10 additions & 11 deletions SeQuant/domain/mbpt/spin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ ExprPtr expand_antisymm(const Tensor& tensor, bool skip_spinsymm) {
auto get_phase = [](const Tensor& t) {
container::svector<Index> bra(t.bra().begin(), t.bra().end());
container::svector<Index> ket(t.ket().begin(), t.ket().end());
IndexSwapper::thread_instance().reset();
bubble_sort(std::begin(bra), std::end(bra), std::less<Index>{});
bubble_sort(std::begin(ket), std::end(ket), std::less<Index>{});
return IndexSwapper::thread_instance().even_num_of_swaps() ? 1 : -1;
reset_ts_swap_counter<Index>();
bubble_sort(std::begin(bra), std::end(bra));
bubble_sort(std::begin(ket), std::end(ket));
return ts_swap_counter_is_even<Index>() ? 1 : -1;
};

// Generate a sum of asymmetric tensors if the input tensor is antisymmetric
Expand Down Expand Up @@ -476,10 +476,9 @@ ExprPtr expand_A_op(const Product& product) {
container::svector<Index> transformed_list;
for (const auto& [key, val] : map) transformed_list.push_back(val);

IndexSwapper::thread_instance().reset();
bubble_sort(std::begin(transformed_list), std::end(transformed_list),
std::less<Index>{});
phase = IndexSwapper::thread_instance().even_num_of_swaps() ? 1 : -1;
reset_ts_swap_counter<Index>();
bubble_sort(std::begin(transformed_list), std::end(transformed_list));
phase = ts_swap_counter_is_even<Index>() ? 1 : -1;
}

Product new_product{};
Expand Down Expand Up @@ -564,9 +563,9 @@ ExprPtr symmetrize_expr(const Product& product) {
auto get_phase = [](const container::map<Index, Index>& map) {
container::svector<Index> idx_list;
for (const auto& [key, val] : map) idx_list.push_back(val);
IndexSwapper::thread_instance().reset();
bubble_sort(std::begin(idx_list), std::end(idx_list), std::less<Index>{});
return IndexSwapper::thread_instance().even_num_of_swaps() ? 1 : -1;
reset_ts_swap_counter<Index>();
bubble_sort(std::begin(idx_list), std::end(idx_list));
return ts_swap_counter_is_even<Index>() ? 1 : -1;
};

container::svector<container::map<Index, Index>> maps;
Expand Down

0 comments on commit 3b1aa88

Please sign in to comment.