Skip to content

Commit

Permalink
[Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstru…
Browse files Browse the repository at this point in the history
…ct (llvm#87247)

Emit a special leaf construct table in DirectiveEmitter.cpp, which will
allow both decomposition of a construct into leafs, and composition of
constituent constructs into a single compound construct (if possible).
The function `getLeafConstructs` is no longer auto-generated, but
implemented in OMP.cpp.

The table contains a row for each directive, and each row has the
following format
`dir_id, num_leafs, leaf1, leaf2, ..., leafN, -1, ...`
The rows are sorted lexicographically with respect to the leaf
constructs. This allows a binary search for the row corresponding to the
given list of leafs.

There is an auxiliary table that for each directive contains the index
of the row corresponding to that directive.

Looking up leaf constructs for a directive `dir_id` is of constant time,
and and consists of two lookups: `LeafTable[Auxiliary[dir_id]]`.
Finding a compound directive given the set of leafs is of time O(logn),
and is roughly represented by
`row = binary_search(LeafTable); return row[0]`.

The functions `getLeafConstructs` and `getCompoundConstruct` use these
lookup methods internally.
  • Loading branch information
kparzysz authored Apr 22, 2024
1 parent 89c95ef commit 40137ff
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 86 deletions.
7 changes: 7 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@

#include "llvm/Frontend/OpenMP/OMP.h.inc"

#include "llvm/ADT/ArrayRef.h"

namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D);
Directive getCompoundConstruct(ArrayRef<Directive> Parts);
} // namespace llvm::omp

#endif // LLVM_FRONTEND_OPENMP_OMP_H
70 changes: 69 additions & 1 deletion llvm/lib/Frontend/OpenMP/OMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,80 @@

#include "llvm/Frontend/OpenMP/OMP.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"

#include <algorithm>
#include <iterator>
#include <type_traits>

using namespace llvm;
using namespace omp;
using namespace llvm::omp;

#define GEN_DIRECTIVES_IMPL
#include "llvm/Frontend/OpenMP/OMP.inc"

namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
auto Idx = static_cast<std::size_t>(D);
if (Idx >= Directive_enumSize)
std::nullopt;
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
return ArrayRef(&Row[2], static_cast<int>(Row[1]));
}

Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
if (Parts.empty())
return OMPD_unknown;

// Parts don't have to be leafs, so expand them into leafs first.
// Store the expanded leafs in the same format as rows in the leaf
// table (generated by tablegen).
SmallVector<Directive> RawLeafs(2);
for (Directive P : Parts) {
ArrayRef<Directive> Ls = getLeafConstructs(P);
if (!Ls.empty())
RawLeafs.append(Ls.begin(), Ls.end());
else
RawLeafs.push_back(P);
}

// RawLeafs will be used as key in the binary search. The search doesn't
// guarantee that the exact same entry will be found (since RawLeafs may
// not correspond to any compound directive). Because of that, we will
// need to compare the search result with the given set of leafs.
// Also, if there is only one leaf in the list, it corresponds to itself,
// no search is necessary.
auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
if (GivenLeafs.size() == 1)
return GivenLeafs.front();
RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());

auto Iter = std::lower_bound(
LeafConstructTable, LeafConstructTableEndDirective,
static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
[](const llvm::omp::Directive *RowA, const llvm::omp::Directive *RowB) {
const auto *BeginA = &RowA[2];
const auto *EndA = BeginA + static_cast<int>(RowA[1]);
const auto *BeginB = &RowB[2];
const auto *EndB = BeginB + static_cast<int>(RowB[1]);
if (BeginA == EndA && BeginB == EndB)
return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
});

if (Iter == std::end(LeafConstructTable))
return OMPD_unknown;

// Verify that we got a match.
Directive Found = (*Iter)[0];
ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
if (FoundLeafs == GivenLeafs)
return Found;
return OMPD_unknown;
}
} // namespace llvm::omp
21 changes: 13 additions & 8 deletions llvm/test/TableGen/directive1.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
// CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h"
// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
Expand Down Expand Up @@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: AKind getAKind(StringRef);
// CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind);
Expand Down Expand Up @@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: default:
// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
// IMPL-NEXT: } // switch (Dir)
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
Expand All @@ -374,4 +368,15 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
// IMPL-NEXT: };
// IMPL-EMPTY:
// IMPL-NEXT: {{.*}} static auto LeafConstructTableEndDirective = LeafConstructTable + 1;
// IMPL-EMPTY:
// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
// IMPL-NEXT: 0,
// IMPL-NEXT: };
// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
21 changes: 13 additions & 8 deletions llvm/test/TableGen/directive2.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: #define LLVM_Tdl_INC
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
Expand Down Expand Up @@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: } // namespace tdl
// CHECK-NEXT: } // namespace llvm
Expand Down Expand Up @@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: default:
// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
// IMPL-NEXT: } // switch (Dir)
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
Expand All @@ -305,4 +299,15 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
// IMPL-NEXT: };
// IMPL-EMPTY:
// IMPL-NEXT: {{.*}} static auto LeafConstructTableEndDirective = LeafConstructTable + 1;
// IMPL-EMPTY:
// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
// IMPL-NEXT: 0,
// IMPL-NEXT: };
// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
1 change: 1 addition & 0 deletions llvm/unittests/Frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPContextTest.cpp
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
OpenMPComposeTest.cpp

DEPENDS
acc_gen
Expand Down
41 changes: 41 additions & 0 deletions llvm/unittests/Frontend/OpenMPComposeTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "gtest/gtest.h"

using namespace llvm;
using namespace llvm::omp;

TEST(Composition, GetLeafConstructs) {
ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
ASSERT_EQ(L1, (ArrayRef<Directive>{}));
ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
}

TEST(Composition, GetCompoundConstruct) {
Directive C1 =
getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
ASSERT_EQ(C1, OMPD_target_teams_distribute);
Directive C2 = getCompoundConstruct({OMPD_target});
ASSERT_EQ(C2, OMPD_target);
Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
ASSERT_EQ(C3, OMPD_unknown);
Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
ASSERT_EQ(C4, OMPD_target_teams_distribute);
Directive C5 = getCompoundConstruct({});
ASSERT_EQ(C5, OMPD_unknown);
Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
ASSERT_EQ(C6, OMPD_parallel_for_simd);
Directive C7 = getCompoundConstruct({OMPD_do, OMPD_simd});
ASSERT_EQ(C7, OMPD_do_simd); // Make sure it's not OMPD_end_do_simd
}
Loading

0 comments on commit 40137ff

Please sign in to comment.