[Mlir-commits] [mlir] [MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignVarKind (PR #76397)
Bharathi Ramana Joshi
llvmlistbot at llvm.org
Sat Dec 30 23:10:23 PST 2023
https://github.com/iambrj updated https://github.com/llvm/llvm-project/pull/76397
>From 293644ebc223a71a8f8da948d669f65e67ef9978 Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Fri, 22 Dec 2023 13:35:10 +0530
Subject: [PATCH] [MLIR][Presburger] Implement
PresburgerSpace::mergeAndAlignSymbols
---
.../Analysis/Presburger/PresburgerSpace.h | 5 ++
.../Analysis/Presburger/PresburgerSpace.cpp | 32 +++++++++
.../Presburger/PresburgerSpaceTest.cpp | 65 +++++++++++++++++++
3 files changed, 102 insertions(+)
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 9fe2abafd36bad..91ed349f461c69 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -290,6 +290,11 @@ class PresburgerSpace {
/// the symbols in two spaces are aligned.
bool isAligned(const PresburgerSpace &other, VarKind kind) const;
+ /// Merge and align the symbol variables of `this` and `other` with respect
+ /// to their identifiers. After this operation the symbol variables of both
+ /// spaces have the same identifiers in the same order: the symbols of `this`
+ /// in their original order, followed by the symbols that were found only in
+ /// `other`. Symbols of `this` missing from `other` are inserted into
+ /// `other`, and vice versa. Only symbol variables are inserted or reordered;
+ /// variables of other kinds are not touched.
+ ///
+ /// Both spaces must be using identifiers (checked by an assertion).
+ void mergeAndAlignSymbols(PresburgerSpace &other);
+
void print(llvm::raw_ostream &os) const;
void dump() const;
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index 185da462aa4453..f11747b58165c5 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -294,6 +294,38 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
// `identifiers` remains same.
}
+/// Merge and align the symbol identifiers of `this` and `other`.
+///
+/// The first pass walks the symbols of `this` in order. For the i-th one, a
+/// matching identifier among the not-yet-aligned symbols of `other` (symbol
+/// positions >= i) is swapped into position i; if there is none, a new symbol
+/// carrying the identifier is inserted into `other` at position i. After this
+/// pass the leading symbols of `other` equal the symbols of `this`, and the
+/// remaining ones are the symbols found only in `other`; the second pass
+/// inserts copies of those into `this`. Only symbol variables are inserted or
+/// reordered.
+void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
+  assert(usingIds && other.usingIds &&
+         "Both spaces need to have identifiers to merge & align");
+
+  // First merge & align identifiers into `other` from `this`. Nothing is
+  // inserted into `this` during this pass, so its symbol range is loop
+  // invariant and can be computed once.
+  const unsigned kindBeginOffset = other.getVarKindOffset(VarKind::Symbol);
+  const Identifier *symBegin =
+      identifiers.begin() + getVarKindOffset(VarKind::Symbol);
+  const Identifier *symEnd =
+      identifiers.begin() + getVarKindEnd(VarKind::Symbol);
+  unsigned i = 0;
+  for (const Identifier *identifier = symBegin; identifier != symEnd;
+       ++identifier, ++i) {
+    // If the identifier exists in `other`, then align it; otherwise insert it
+    // assuming it is a new identifier. Search in `other` starting at position
+    // `i` since the symbols to the left of `i` are already aligned. The search
+    // bounds are recomputed every iteration because `insertVar` below changes
+    // `other.identifiers`.
+    auto *findBegin = other.identifiers.begin() + kindBeginOffset + i;
+    auto *findEnd =
+        other.identifiers.begin() + other.getVarKindEnd(VarKind::Symbol);
+    auto *itr = std::find(findBegin, findEnd, *identifier);
+    if (itr != findEnd) {
+      std::iter_swap(findBegin, itr);
+    } else {
+      other.insertVar(VarKind::Symbol, i);
+      other.getId(VarKind::Symbol, i) = *identifier;
+    }
+  }
+
+  // Finally add identifiers that are in `other`, but not in `this`, to `this`.
+  // After the first pass these sit at symbol positions [i, e) of `other`.
+  for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
+    insertVar(VarKind::Symbol, i);
+    getId(VarKind::Symbol, i) = other.getId(VarKind::Symbol, i);
+  }
+}
+
void PresburgerSpace::print(llvm::raw_ostream &os) const {
os << "Domain: " << getNumDomainVars() << ", "
<< "Range: " << getNumRangeVars() << ", "
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index 8229199b233471..ceef2d6d8e4630 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -193,3 +193,68 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
}
+
+TEST(PresburgerSpaceTest, mergeAndAlignSymbols) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note the common identifier.
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note the common identifier.
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  space.mergeAndAlignSymbols(otherSpace);
+
+  // Check if merge & align is successful.
+  // Check symbol var identifiers.
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+  // Check that domain and range var identifiers are not affected.
+  EXPECT_EQ(3u, space.getNumDomainVars());
+  EXPECT_EQ(3u, space.getNumRangeVars());
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+  EXPECT_EQ(3u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+            Identifier(&otherIdentifiers[0]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+            Identifier(&otherIdentifiers[1]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
+            Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+            Identifier(&otherIdentifiers[3]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+            Identifier(&otherIdentifiers[4]));
+}
+
+// In the test above the common symbol is already in place when it is reached,
+// so no identifier is ever moved by a swap. Here the common symbols appear in
+// `otherSpace` in a different order, which exercises the swap path, an
+// insertion into `otherSpace` between aligned symbols, and an insertion into
+// `space`.
+TEST(PresburgerSpaceTest, mergeAndAlignSymbols2) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(1, 1, 3, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(1, 1, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers: `space` has symbols [0, 1, 2] and `otherSpace` has
+  // symbols [3, 2, 0].
+  int identifiers[4] = {0, 1, 2, 3};
+
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[2]);
+
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&identifiers[2]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&identifiers[0]);
+
+  space.mergeAndAlignSymbols(otherSpace);
+
+  // Both spaces end up with symbols [0, 1, 2, 3]: the symbols of `space` in
+  // order, followed by the symbol that was only present in `otherSpace`.
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  for (unsigned i = 0; i < 4; ++i) {
+    EXPECT_EQ(space.getId(VarKind::Symbol, i), Identifier(&identifiers[i]));
+    EXPECT_EQ(otherSpace.getId(VarKind::Symbol, i),
+              Identifier(&identifiers[i]));
+  }
+  // Check that domain and range vars are not affected.
+  EXPECT_EQ(1u, space.getNumDomainVars());
+  EXPECT_EQ(1u, space.getNumRangeVars());
+  EXPECT_EQ(1u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(1u, otherSpace.getNumRangeVars());
+}
More information about the Mlir-commits
mailing list