[Mlir-commits] [mlir] [MLIR][Presburger] Use Identifiers outside Presburger library (PR #77316)
Bharathi Ramana Joshi
llvmlistbot at llvm.org
Fri Apr 5 10:59:47 PDT 2024
https://github.com/iambrj updated https://github.com/llvm/llvm-project/pull/77316
>From 513f62eb8dc3ce7cb96e4b9ee451f7c265cd4042 Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Sun, 21 Jan 2024 01:39:04 +0530
Subject: [PATCH 1/2] [MLIR][Presburger] Preserve identifiers in
IntegerRelation::convertVarKind
---
.../Analysis/Presburger/IntegerRelation.cpp | 26 ++--
.../Presburger/IntegerRelationTest.cpp | 136 ++++++++++++++++++
2 files changed, 146 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index a3f971db4bd428..69b815a137b7a5 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1451,28 +1451,22 @@ void IntegerRelation::removeRedundantLocalVars() {
void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
unsigned varLimit, VarKind dstKind,
unsigned pos) {
- assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range");
+ assert(varLimit <= getNumVarKind(srcKind) && "invalid id range");
if (varStart >= varLimit)
return;
- // Append new local variables corresponding to the dimensions to be converted.
+ const unsigned srcOffset = getVarKindOffset(srcKind);
+ const unsigned dstOffset = getVarKindOffset(dstKind);
unsigned convertCount = varLimit - varStart;
- unsigned newVarsBegin = insertVar(dstKind, pos, convertCount);
+ // Column at which the first converted variable ends up. The converted
+ // columns move as one contiguous block; when the block moves to higher
+ // columns, the `convertCount` columns it vacates shift the target left.
+ // Kept in unsigned arithmetic: the subtraction cannot wrap, because a
+ // destination kind placed after the source kind starts at least
+ // `convertCount` columns past `srcOffset`.
+ unsigned dstPos = dstOffset + pos;
+ if (dstOffset > srcOffset)
+ dstPos -= convertCount;
- // Swap the new local variables with dimensions.
- //
- // Essentially, this moves the information corresponding to the specified ids
- // of kind `srcKind` to the `convertCount` newly created ids of kind
- // `dstKind`. In particular, this moves the columns in the constraint
- // matrices, and zeros out the initially occupied columns (because the newly
- // created ids we're swapping with were zero-initialized).
- unsigned offset = getVarKindOffset(srcKind);
- for (unsigned i = 0; i < convertCount; ++i)
- swapVar(offset + varStart + i, newVarsBegin + i);
-
- // Complete the move by deleting the initially occupied columns.
- removeVarRange(srcKind, varStart, varLimit);
+ equalities.moveColumns(srcOffset + varStart, convertCount, dstPos);
+ inequalities.moveColumns(srcOffset + varStart, convertCount, dstPos);
+
+ // Keep the space, including any attached identifiers, in sync with the
+ // moved columns.
+ space.convertVarKind(srcKind, varStart, convertCount, dstKind, pos);
}
void IntegerRelation::addBound(BoundType type, unsigned pos,
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 00d2204c9c8ef1..945b3d502f6973 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -485,3 +485,139 @@ TEST(IntegerRelationTest, setId) {
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
}
+
+TEST(IntegerRelationTest, convertVarKind) {
+ PresburgerSpace space = PresburgerSpace::getSetSpace(3, 3, 0);
+ space.resetIds();
+
+ // Attach identifiers: identifiers[0..2] to the set dims and
+ // identifiers[3..5] to the symbols.
+ int identifiers[6] = {0, 1, 2, 3, 4, 5};
+ space.getId(VarKind::SetDim, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::SetDim, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::SetDim, 2) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+
+ // Cannot call parseIntegerRelation to test convertVarKind as
+ // parseIntegerRelation uses convertVarKind.
+ IntegerRelation rel = parseIntegerPolyhedron(
+ // Identifier index of each variable: x=0, y=1, a=2, U=3, V=4, W=5.
+ "(x, y, a)[U, V, W] : (x - U == 0, y + a - W == 0, U - V >= 0,"
+ "y - a >= 0)");
+ rel.setSpace(space);
+
+ // Make a few kind conversions. The layout after each step is shown as
+ // (domain) -> (range) [symbols], starting from () -> (x, y, a) [U, V, W].
+ rel.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 0);
+ // (V) -> (x, y, a) [U, W]
+ rel.convertVarKind(VarKind::Range, 2, 3, VarKind::Domain, 0);
+ // (a, V) -> (x, y) [U, W]
+ rel.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 1);
+ // (a, V) -> () [U, x, y, W]
+ rel.convertVarKind(VarKind::Domain, 1, 2, VarKind::Range, 0);
+ // (a) -> (V) [U, x, y, W]
+ rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 1);
+ // () -> (V, a) [U, x, y, W]
+
+ space = rel.getSpace();
+
+ // Expected rel: the same constraints, with the variables in the final
+ // layout above.
+ IntegerRelation expectedRel = parseIntegerPolyhedron(
+ "(V, a)[U, x, y, W] : (x - U == 0, y + a - W == 0, U - V >= 0,"
+ "y - a >= 0)");
+ expectedRel.setSpace(space);
+
+ EXPECT_TRUE(rel.isEqual(expectedRel));
+
+ // Each identifier must have followed its variable through the conversions.
+ EXPECT_EQ(space.getId(VarKind::SetDim, 0), Identifier(&identifiers[4]));
+ EXPECT_EQ(space.getId(VarKind::SetDim, 1), Identifier(&identifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[5]));
+}
+
+TEST(IntegerRelationTest, convertVarKindToLocal) {
+ // Convert all range variables to local variables.
+ // `rel` has 1 domain var (x), 2 range vars (y, z) and 2 symbols (N, M).
+ IntegerRelation rel = parseRelationFromSet(
+ "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+ 1);
+ PresburgerSpace space = rel.getSpace();
+ space.resetIds();
+ // Attach identifiers. Positions are relative to each kind, so y and z are
+ // Range 0 and Range 1 (Range 2 does not exist and would alias Symbol 0).
+ char identifiers[5] = {'x', 'y', 'z', 'N', 'M'};
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ rel.setSpace(space);
+ rel.convertToLocal(VarKind::Range, 0, rel.getNumRangeVars());
+ IntegerRelation expectedRel =
+ parseRelationFromSet("(x)[N, M] : (x - N >= 0, 2 * M - 5 >= 0)", 1);
+ EXPECT_TRUE(rel.isEqual(expectedRel));
+ space = rel.getSpace();
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+ // Convert all domain variables to local variables.
+ IntegerRelation rel2 = parseRelationFromSet(
+ "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+ 2);
+ space = rel2.getSpace();
+ space.resetIds();
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ rel2.setSpace(space);
+ rel2.convertToLocal(VarKind::Domain, 0, rel2.getNumDomainVars());
+ expectedRel =
+ parseIntegerPolyhedron("(z)[N, M] : (3 - z >= 0, 2 * M - 5 >= 0)");
+ EXPECT_TRUE(rel2.isEqual(expectedRel));
+ space = rel2.getSpace();
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+ // Convert a prefix of range variables to local variables.
+ IntegerRelation rel3 = parseRelationFromSet(
+ "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+ 1);
+ space = rel3.getSpace();
+ space.resetIds();
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ rel3.setSpace(space);
+ rel3.convertToLocal(VarKind::Range, 0, 1);
+ expectedRel = parseRelationFromSet(
+ "(x, z)[N, M] : (x - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)", 1);
+ EXPECT_TRUE(rel3.isEqual(expectedRel));
+ space = rel3.getSpace();
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+ // Convert a suffix of domain variables to local variables.
+ IntegerRelation rel4 = parseRelationFromSet(
+ "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+ 2);
+ space = rel4.getSpace();
+ space.resetIds();
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ rel4.setSpace(space);
+ rel4.convertToLocal(VarKind::Domain, rel4.getNumDomainVars() - 1,
+ rel4.getNumDomainVars());
+ // Eliminating y yields the same relation as in the previous case, so the
+ // previous expectedRel is reused.
+ EXPECT_TRUE(rel4.isEqual(expectedRel));
+ space = rel4.getSpace();
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+}
>From a8aad6f8d33b12c99161e8c329af34cb7f85c4a7 Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Fri, 5 Apr 2024 23:23:44 +0530
Subject: [PATCH 2/2] [MLIR][Presburger] Use Identifiers outside Presburger
library
---
.../Analysis/FlatLinearValueConstraints.h | 107 ++++++-------
.../Analysis/Presburger/IntegerRelation.h | 7 +
.../Analysis/Presburger/PresburgerSpace.h | 5 +-
.../Dialect/Affine/Analysis/AffineAnalysis.h | 3 +-
.../Affine/Analysis/AffineStructures.h | 13 +-
.../Analysis/FlatLinearValueConstraints.cpp | 140 ++++++------------
.../Analysis/Presburger/IntegerRelation.cpp | 58 ++++++++
.../Affine/Analysis/AffineAnalysis.cpp | 72 +++++----
.../Affine/Analysis/AffineStructures.cpp | 42 +++---
.../Presburger/IntegerRelationTest.cpp | 4 +-
10 files changed, 253 insertions(+), 198 deletions(-)
diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index e4de5b0661571c..6994813bf9a7ff 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -205,6 +205,10 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
/// where each non-local variable can have an SSA Value attached to it.
class FlatLinearValueConstraints : public FlatLinearConstraints {
public:
+ /// Non-local variables carry their SSA Values as identifiers of the
+ /// constraint system's space.
+ using Identifier = presburger::Identifier;
+
/// Constructs a constraint system reserving memory for the specified number
/// of constraints and variables. `valArgs` are the optional SSA values
/// associated with each dimension/symbol. These must either be empty or match
@@ -217,11 +221,11 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
: FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
numReservedCols, numDims, numSymbols, numLocals) {
assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
- values.reserve(numReservedCols);
- if (valArgs.empty())
- values.resize(getNumDimAndSymbolVars(), std::nullopt);
- else
- values.append(valArgs.begin(), valArgs.end());
+ // Reset the space's identifiers, then attach every Value that is present.
+ space.resetIds();
+ for (unsigned pos = 0, numVals = valArgs.size(); pos < numVals; ++pos) {
+ if (std::optional<Value> maybeVal = valArgs[pos])
+ setValue(pos, *maybeVal);
+ }
}
/// Constructs a constraint system reserving memory for the specified number
@@ -236,11 +240,11 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
: FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
numReservedCols, numDims, numSymbols, numLocals) {
assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
- values.reserve(numReservedCols);
- if (valArgs.empty())
- values.resize(getNumDimAndSymbolVars(), std::nullopt);
- else
- values.append(valArgs.begin(), valArgs.end());
+ // Reset the space's identifiers, then attach every non-null Value.
+ space.resetIds();
+ for (unsigned pos = 0, numVals = valArgs.size(); pos < numVals; ++pos) {
+ if (Value val = valArgs[pos])
+ setValue(pos, val);
+ }
}
/// Constructs a constraint system with the specified number of dimensions
@@ -272,11 +276,15 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
FlatLinearValueConstraints(const IntegerPolyhedron &fac,
ArrayRef<std::optional<Value>> valArgs = {})
: FlatLinearConstraints(fac) {
- assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
+ // Without explicit Values, keep whatever identifiers the base constructor
+ // set up rather than resetting them.
if (valArgs.empty())
- values.resize(getNumDimAndSymbolVars(), std::nullopt);
- else
- values.append(valArgs.begin(), valArgs.end());
+ return;
+ assert(valArgs.size() == getNumDimAndSymbolVars());
+ // Reset the space's identifiers, then attach every Value that is present.
+ space.resetIds();
+ for (unsigned pos = 0, numVals = valArgs.size(); pos < numVals; ++pos) {
+ if (std::optional<Value> maybeVal = valArgs[pos])
+ setValue(pos, *maybeVal);
+ }
}
/// Creates an affine constraint system from an IntegerSet.
@@ -290,9 +298,6 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
cst->getKind() <= Kind::FlatAffineRelation;
}
- /// Replaces the contents of this FlatLinearValueConstraints with `other`.
- void clearAndCopyFrom(const IntegerRelation &other) override;
-
/// Adds a constant bound for the variable associated with the given Value.
void addBound(presburger::BoundType type, Value val, int64_t value);
using FlatLinearConstraints::addBound;
@@ -302,7 +307,9 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
inline Value getValue(unsigned pos) const {
assert(pos < getNumDimAndSymbolVars() && "Invalid position");
assert(hasValue(pos) && "variable's Value not set");
- return *values[pos];
+ VarKind kind = getVarKindAt(pos);
+ unsigned relativePos = pos - getVarKindOffset(kind);
+ return space.getId(kind, relativePos).getValue<Value>();
}
/// Returns the Values associated with variables in range [start, end).
@@ -313,25 +320,44 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
assert(start <= end && "invalid start position");
values->clear();
values->reserve(end - start);
- for (unsigned i = start; i < end; i++)
+ for (unsigned i = start; i < end; ++i)
values->push_back(getValue(i));
}
- inline ArrayRef<std::optional<Value>> getMaybeValues() const {
- return {values.data(), values.size()};
+ /// Returns, in column order, the Value attached to each non-local variable,
+ /// or std::nullopt for variables without one. The vector is returned by
+ /// value: do not bind the result to an ArrayRef that outlives the
+ /// full-expression, or the ArrayRef will dangle.
+ inline SmallVector<std::optional<Value>> getMaybeValues() const {
+ SmallVector<std::optional<Value>> maybeValues;
+ maybeValues.reserve(getNumDimAndSymbolVars());
+ for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
+ maybeValues.push_back(hasValue(i) ? std::optional<Value>(getValue(i))
+ : std::nullopt);
+ return maybeValues;
}
- inline ArrayRef<std::optional<Value>>
+ /// Same as getMaybeValues(), restricted to the variables of kind `kind`
+ /// (which must not be Local). Also returned by value.
+ inline SmallVector<std::optional<Value>>
getMaybeValues(presburger::VarKind kind) const {
assert(kind != VarKind::Local &&
"Local variables do not have any value attached to them.");
- return {values.data() + getVarKindOffset(kind), getNumVarKind(kind)};
+ SmallVector<std::optional<Value>> maybeValues;
+ maybeValues.reserve(getNumVarKind(kind));
+ const unsigned offset = getVarKindOffset(kind);
+ for (unsigned i = 0, e = getNumVarKind(kind); i < e; ++i)
+ maybeValues.push_back(hasValue(offset + i)
+ ? std::optional<Value>(getValue(offset + i))
+ : std::nullopt);
+ return maybeValues;
}
/// Returns true if the pos^th variable has an associated Value.
inline bool hasValue(unsigned pos) const {
assert(pos < getNumDimAndSymbolVars() && "Invalid position");
- return values[pos].has_value();
+ // Identifiers are only populated once the space is using them (setValue
+ // enables this lazily, and the IntegerPolyhedron constructor may skip it).
+ // Until then, report "no Value" instead of reading identifier storage.
+ if (!space.isUsingIds())
+ return false;
+ VarKind kind = getVarKindAt(pos);
+ unsigned relativePos = pos - getVarKindOffset(kind);
+ return space.getId(kind, relativePos).hasValue();
}
unsigned appendDimVar(ValueRange vals);
@@ -358,9 +384,15 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
using IntegerPolyhedron::removeVarRange;
/// Sets the Value associated with the pos^th variable.
+ /// The Value is recorded as the variable's identifier in the space.
inline void setValue(unsigned pos, Value val) {
assert(pos < getNumDimAndSymbolVars() && "invalid var position");
- values[pos] = val;
+ // Enable identifiers on the space the first time a Value is attached.
+ if (!space.isUsingIds())
+ space.resetIds();
+ VarKind kind = getVarKindAt(pos);
+ presburger::Identifier &id = space.getId(kind, pos - getVarKindOffset(kind));
+ id = presburger::Identifier(val);
}
/// Sets the Values associated with the variables in the range [start, end).
@@ -387,9 +419,6 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
void projectOut(Value val);
using IntegerPolyhedron::projectOut;
- /// Swap the posA^th variable with the posB^th variable.
- void swapVar(unsigned posA, unsigned posB) override;
-
/// Prints the number of constraints, dimensions, symbols and locals in the
/// FlatAffineValueConstraints. Also, prints for each variable whether there
/// is an SSA Value attached to it.
@@ -444,28 +473,6 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
/// output = {0 <= d0 <= 6, 1 <= d1 <= 15}
LogicalResult unionBoundingBox(const FlatLinearValueConstraints &other);
using IntegerPolyhedron::unionBoundingBox;
-
-protected:
- /// Eliminates the variable at the specified position using Fourier-Motzkin
- /// variable elimination, but uses Gaussian elimination if there is an
- /// equality involving that variable. If the result of the elimination is
- /// integer exact, `*isResultIntegerExact` is set to true. If `darkShadow` is
- /// set to true, a potential under approximation (subset) of the rational
- /// shadow / exact integer shadow is computed.
- // See implementation comments for more details.
- void fourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
- bool *isResultIntegerExact = nullptr) override;
-
- /// Returns false if the fields corresponding to various variable counts, or
- /// equality/inequality buffer sizes aren't consistent; true otherwise. This
- /// is meant to be used within an assert internally.
- bool hasConsistentState() const override;
-
- /// Values corresponding to the (column) non-local variables of this
- /// constraint system appearing in the order the variables correspond to
- /// columns. Variables that aren't associated with any Value are set to
- /// std::nullopt.
- SmallVector<std::optional<Value>, 8> values;
};
/// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 27dc382c1d5dbe..7efbbd003b0371 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -127,6 +127,8 @@ class IntegerRelation {
/// the variable.
void setId(VarKind kind, unsigned i, Identifier id);
+ /// Resets the identifiers of this relation's space by forwarding to
+ /// PresburgerSpace::resetIds(). Typically called before attaching
+ /// identifiers (e.g. with setId) to a relation whose space may not be using
+ /// identifiers yet.
+ void resetIds() { space.resetIds(); }
+
/// Returns a copy of the space without locals.
PresburgerSpace getSpaceWithoutLocals() const {
return PresburgerSpace::getRelationSpace(space.getNumDomainVars(),
@@ -674,6 +676,11 @@ class IntegerRelation {
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);
+ /// Given a relation `other: (A -> B)`, this operation merges the symbol and
+ /// local variables and then takes the composition of `other` on
+ /// `this: (B -> C)`. The resulting relation represents tuples of the form
+ /// `A -> C`. Requires the number of domain variables of `this` to equal the
+ /// number of range variables of `other`.
+ void mergeAndCompose(const IntegerRelation &other);
+
/// Compute an equivalent representation of the same set, such that all local
/// vars in all disjuncts have division representations. This representation
/// may involve local vars that correspond to divisions, and may also be a
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 91ed349f461c69..b4f2cf3970e5aa 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -250,11 +250,12 @@ class PresburgerSpace {
/// locals).
bool isEqual(const PresburgerSpace &other) const;
- /// Get the identifier of the specified variable.
+ /// Get a mutable reference to the identifier of the specified variable.
Identifier &getId(VarKind kind, unsigned pos) {
assert(kind != VarKind::Local && "Local variables have no identifiers");
return identifiers[getVarKindOffset(kind) + pos];
}
+
+ /// Get a copy of the identifier of the specified (non-local) variable.
Identifier getId(VarKind kind, unsigned pos) const {
assert(kind != VarKind::Local && "Local variables have no identifiers");
return identifiers[getVarKindOffset(kind) + pos];
@@ -265,6 +266,8 @@ class PresburgerSpace {
return {identifiers.data() + getVarKindOffset(kind), getNumVarKind(kind)};
}
+ /// Returns the identifiers of all non-local variables, indexed by absolute
+ /// variable position: getId(kind, pos) is the entry at
+ /// getVarKindOffset(kind) + pos. NOTE(review): presumably only meaningful
+ /// when isUsingIds() is true -- confirm.
+ ArrayRef<Identifier> getIds() const { return identifiers; }
+
/// Returns if identifiers are being used.
bool isUsingIds() const { return usingIds; }
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
index a27583877b603c..4134aef8174bc1 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
@@ -15,6 +15,7 @@
#ifndef MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
#define MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
@@ -115,7 +116,7 @@ struct MemRefAccess {
///
/// Returns failure for yet unimplemented/unsupported cases (see docs of
/// mlir::getIndexSet and mlir::getRelationFromMap for these cases).
- LogicalResult getAccessRelation(FlatAffineRelation &accessRel) const;
+ LogicalResult getAccessRelation(presburger::IntegerRelation &accessRel) const;
/// Populates 'accessMap' with composition of AffineApplyOps reachable from
/// 'indices'.
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
index 7c500f13895af1..c9d3dc1abfd1df 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
@@ -154,6 +154,14 @@ class FlatAffineValueConstraints : public FlatLinearValueConstraints {
/// represented as a FlatAffineValueConstraints with separation of dimension
/// variables into domain and range. The variables are stored as:
/// [domainVars, rangeVars, symbolVars, localVars, constant].
+///
+/// Deprecated: use IntegerRelation instead and attach SSA Values as
+/// identifiers of the relation's PresburgerSpace. Note that
+/// FlatAffineRelation::numDomainDims and FlatAffineRelation::numRangeDims are
+/// tracked independently of the number of domain and range variables in the
+/// relation's space. In particular, operations such as
+/// FlatAffineRelation::compose do not keep numDomainDims/numRangeDims
+/// consistent with the space's domain/range counts, which may lead to
+/// unexpected behaviour.
class FlatAffineRelation : public FlatAffineValueConstraints {
public:
FlatAffineRelation(unsigned numReservedInequalities,
@@ -251,9 +259,10 @@ class FlatAffineRelation : public FlatAffineValueConstraints {
/// For AffineValueMap, the domain and symbols have Value set corresponding to
/// the Value in `map`. Returns failure if the AffineMap could not be flattened
/// (i.e., semi-affine is not yet handled).
-LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
+LogicalResult getRelationFromMap(AffineMap &map,
+ presburger::IntegerRelation &rel);
LogicalResult getRelationFromMap(const AffineValueMap &map,
- FlatAffineRelation &rel);
+ presburger::IntegerRelation &rel);
} // namespace affine
} // namespace mlir
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 69846a356e0cc4..8a09c741979ffb 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -9,6 +9,7 @@
#include "mlir/Analysis//FlatLinearValueConstraints.h"
#include "mlir/Analysis/Presburger/LinearTransform.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -817,13 +818,13 @@ FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
set.getNumDims() + set.getNumSymbols() + 1,
set.getNumDims(), set.getNumSymbols(),
/*numLocals=*/0) {
- // Populate values.
- if (operands.empty()) {
- values.resize(getNumDimAndSymbolVars(), std::nullopt);
- } else {
- assert(set.getNumInputs() == operands.size() && "operand count mismatch");
- values.assign(operands.begin(), operands.end());
- }
+ // Parenthesize the disjunction so the grouping is explicit: the message
+ // applies to the whole condition and -Wparentheses stays quiet.
+ assert((operands.empty() || set.getNumInputs() == operands.size()) &&
+ "operand count mismatch");
+ // Values are stored as identifiers of the space; start from a reset space.
+ space.resetIds();
+ // Attach the operand Values to the dimension and symbol variables.
+ for (unsigned i = 0, e = operands.size(); i < e; ++i)
+ setValue(i, operands[i]);
// Flatten expressions and add them to the constraint system.
std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -873,11 +874,6 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
unsigned num) {
unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
- if (kind != VarKind::Local) {
- values.insert(values.begin() + absolutePos, num, std::nullopt);
- assert(values.size() == getNumDimAndSymbolVars());
- }
-
+ // Values are stored as identifiers in the space (see setValue), so this
+ // override keeps no state of its own and simply defers to the base class.
return absolutePos;
}
@@ -890,11 +886,10 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
- // If a Value is provided, insert it; otherwise use std::nullopt.
- for (unsigned i = 0; i < num; ++i)
- values.insert(values.begin() + absolutePos + i,
- vals[i] ? std::optional<Value>(vals[i]) : std::nullopt);
+ // Attach each non-null Value in `vals` to the corresponding new variable;
+ // variables whose entry is null are not given a Value here.
+ for (unsigned i = 0, e = vals.size(); i < e; ++i)
+ if (vals[i])
+ setValue(absolutePos + i, vals[i]);
- assert(values.size() == getNumDimAndSymbolVars());
return absolutePos;
}
@@ -902,10 +897,14 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
/// associated with the same set of variables, appearing in the same order.
static bool areVarsAligned(const FlatLinearValueConstraints &a,
const FlatLinearValueConstraints &b) {
- return a.getNumDimVars() == b.getNumDimVars() &&
- a.getNumSymbolVars() == b.getNumSymbolVars() &&
- a.getNumVars() == b.getNumVars() &&
- a.getMaybeValues().equals(b.getMaybeValues());
+ // Compare the count of every variable kind, including locals. Locals carry
+ // no Values but still occupy columns; the pre-existing check of the total
+ // variable count covered them.
+ if (a.getNumDomainVars() != b.getNumDomainVars() ||
+ a.getNumRangeVars() != b.getNumRangeVars() ||
+ a.getNumSymbolVars() != b.getNumSymbolVars() ||
+ a.getNumLocalVars() != b.getNumLocalVars())
+ return false;
+ // SmallVector's operator== compares both sizes and elements.
+ return a.getMaybeValues() == b.getMaybeValues();
}
/// Calls areVarsAligned to check if two constraint systems have the same set
@@ -928,12 +927,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
return true;
SmallPtrSet<Value, 8> uniqueVars;
- ArrayRef<std::optional<Value>> maybeValues =
- cst.getMaybeValues().slice(start, end - start);
- for (std::optional<Value> val : maybeValues) {
+ // getMaybeValues() returns a temporary vector: keep it alive in a local and
+ // only inspect its [start, end) window.
+ const SmallVector<std::optional<Value>> allMaybeValues = cst.getMaybeValues();
+ const ArrayRef<std::optional<Value>> window(allMaybeValues.data() + start,
+ end - start);
+ for (const std::optional<Value> &val : window)
if (val && !uniqueVars.insert(*val).second)
return false;
- }
return true;
}
@@ -1058,20 +1059,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
"expected same number of symbols");
}
-bool FlatLinearValueConstraints::hasConsistentState() const {
- return IntegerPolyhedron::hasConsistentState() &&
- values.size() == getNumDimAndSymbolVars();
-}
-
void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
unsigned varLimit) {
+ // Values live in the space's identifiers (see setValue). This relies on the
+ // base class dropping the identifiers of the removed variables together
+ // with their columns. NOTE(review): confirm in PresburgerSpace.
IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
- unsigned offset = getVarKindOffset(kind);
-
- if (kind != VarKind::Local) {
- values.erase(values.begin() + varStart + offset,
- values.begin() + varLimit + offset);
- }
}
AffineMap
@@ -1089,14 +1079,14 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
dims.reserve(getNumDimVars());
syms.reserve(getNumSymbolVars());
- for (unsigned i = getVarKindOffset(VarKind::SetDim),
- e = getVarKindEnd(VarKind::SetDim);
- i < e; ++i)
- dims.push_back(values[i] ? *values[i] : Value());
- for (unsigned i = getVarKindOffset(VarKind::Symbol),
- e = getVarKindEnd(VarKind::Symbol);
- i < e; ++i)
- syms.push_back(values[i] ? *values[i] : Value());
+ // Value attached at absolute position `pos`, or a null Value if none.
+ auto valueOrNull = [&](unsigned pos) {
+ return hasValue(pos) ? getValue(pos) : Value();
+ };
+ for (unsigned pos = getVarKindOffset(VarKind::SetDim),
+ end = getVarKindEnd(VarKind::SetDim);
+ pos < end; ++pos)
+ dims.push_back(valueOrNull(pos));
+ for (unsigned pos = getVarKindOffset(VarKind::Symbol),
+ end = getVarKindEnd(VarKind::Symbol);
+ pos < end; ++pos)
+ syms.push_back(valueOrNull(pos));
AffineMap alignedMap =
alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
@@ -1109,38 +1099,18 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
unsigned offset) const {
- unsigned i = offset;
- for (const auto &mayBeVar :
- ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
- if (mayBeVar && *mayBeVar == val) {
+ // Scan positions directly instead of materializing getMaybeValues(), which
+ // would allocate and fill a vector of all Values on every lookup.
+ for (unsigned i = offset, e = getNumDimAndSymbolVars(); i < e; ++i)
+ if (hasValue(i) && getValue(i) == val) {
*pos = i;
return true;
}
- i++;
- }
return false;
}
bool FlatLinearValueConstraints::containsVar(Value val) const {
- return llvm::any_of(values, [&](const std::optional<Value> &mayBeVar) {
- return mayBeVar && *mayBeVar == val;
- });
-}
-
-void FlatLinearValueConstraints::swapVar(unsigned posA, unsigned posB) {
- IntegerPolyhedron::swapVar(posA, posB);
-
- if (getVarKindAt(posA) == VarKind::Local &&
- getVarKindAt(posB) == VarKind::Local)
- return;
-
- // Treat value of a local variable as std::nullopt.
- if (getVarKindAt(posA) == VarKind::Local)
- values[posB] = std::nullopt;
- else if (getVarKindAt(posB) == VarKind::Local)
- values[posA] = std::nullopt;
- else
- std::swap(values[posA], values[posB]);
+ // Only the existence of a match matters; the position is discarded.
+ unsigned pos;
+ return findVar(val, &pos, 0);
}
void FlatLinearValueConstraints::addBound(BoundType type, Value val,
@@ -1180,31 +1150,6 @@ void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
os << "const)\n";
}
-void FlatLinearValueConstraints::clearAndCopyFrom(
- const IntegerRelation &other) {
-
- if (auto *otherValueSet =
- dyn_cast<const FlatLinearValueConstraints>(&other)) {
- *this = *otherValueSet;
- } else {
- *static_cast<IntegerRelation *>(this) = other;
- values.clear();
- values.resize(getNumDimAndSymbolVars(), std::nullopt);
- }
-}
-
-void FlatLinearValueConstraints::fourierMotzkinEliminate(
- unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
- SmallVector<std::optional<Value>, 8> newVals = values;
- if (getVarKindAt(pos) != VarKind::Local)
- newVals.erase(newVals.begin() + pos);
- // Note: Base implementation discards all associated Values.
- IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
- isResultIntegerExact);
- values = newVals;
- assert(values.size() == getNumDimAndSymbolVars());
-}
-
void FlatLinearValueConstraints::projectOut(Value val) {
unsigned pos;
bool ret = findVar(val, &pos);
@@ -1216,9 +1161,16 @@ void FlatLinearValueConstraints::projectOut(Value val) {
LogicalResult FlatLinearValueConstraints::unionBoundingBox(
const FlatLinearValueConstraints &otherCst) {
assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
- assert(otherCst.getMaybeValues()
- .slice(0, getNumDimVars())
- .equals(getMaybeValues().slice(0, getNumDimVars())) &&
+#ifndef NDEBUG
+  // The value lists are only needed by the assertion below; do not copy them
+  // in release builds where the assertion compiles away.
+  SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
+  SmallVector<std::optional<Value>> otherMaybeValues =
+      otherCst.getMaybeValues();
+  assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
+                    otherMaybeValues.begin(),
+                    otherMaybeValues.begin() + getNumDimVars()) &&
"dim values mismatch");
+#endif
assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
assert(getNumLocalVars() == 0 && "local vars not supported yet here");
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 69b815a137b7a5..4f243391f6aca9 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2516,6 +2516,75 @@ bool IntegerRelation::isFullDim() {
});
}
+/// Replaces `this` with the composition of `this` after `other`, i.e. the
+/// relation from the domain of `other` to the range of `this`. Symbol and
+/// local variables of both relations are merged first; identifiers are
+/// carried over from the domain of `other` and the range of `this`.
+void IntegerRelation::mergeAndCompose(const IntegerRelation &other) {
+  assert(getNumDomainVars() == other.getNumRangeVars() &&
+         "Domain of this and range of other do not match");
+  // TODO: Also assert that the identifiers of this domain match those of the
+  // range of `other` once callers attach matching identifiers to both.
+
+  IntegerRelation result = other;
+
+  const unsigned thisDomain = getNumDomainVars();
+  const unsigned thisRange = getNumRangeVars();
+  const unsigned otherDomain = other.getNumDomainVars();
+  const unsigned otherRange = other.getNumRangeVars();
+
+  // Add dimension variables temporarily to merge symbol and local vars.
+  // Convert `this` from
+  //    [thisDomain] -> [thisRange]
+  // to
+  //    [otherDomain thisDomain] -> [otherRange thisRange]
+  // and `result` from
+  //    [otherDomain] -> [otherRange]
+  // to
+  //    [otherDomain thisDomain] -> [otherRange thisRange].
+  insertVar(VarKind::Domain, 0, otherDomain);
+  insertVar(VarKind::Range, 0, otherRange);
+  result.insertVar(VarKind::Domain, otherDomain, thisDomain);
+  result.insertVar(VarKind::Range, otherRange, thisRange);
+
+  // Merge symbol and local variables.
+  mergeAndAlignSymbols(result);
+  mergeLocalVars(result);
+
+  // Convert `result` from [otherDomain thisDomain] -> [otherRange thisRange] to
+  // [otherDomain] -> [thisRange]: drop the thisDomain vars inserted above and
+  // turn otherRange into local vars.
+  result.removeVarRange(VarKind::Domain, otherDomain, otherDomain + thisDomain);
+  result.convertToLocal(VarKind::Range, 0, otherRange);
+  // Convert `this` from [otherDomain thisDomain] -> [otherRange thisRange] to
+  // [otherDomain] -> [thisRange]: turn thisDomain into local vars and drop the
+  // otherRange vars inserted above.
+  // NOTE: This relies on convertToLocal placing the new locals at matching
+  // positions in both relations, so that append() below identifies the range
+  // of `other` with the domain of `this`.
+  convertToLocal(VarKind::Domain, otherDomain, otherDomain + thisDomain);
+  removeVarRange(VarKind::Range, 0, otherRange);
+
+  // Match the domain identifiers of `this` to those of `result`.
+  for (unsigned i = 0, e = result.getNumDomainVars(); i < e; ++i) {
+    const Identifier &id = result.getSpace().getId(VarKind::Domain, i);
+    if (id.hasValue())
+      setId(VarKind::Domain, i, id);
+  }
+  // Match the range identifiers of `result` to those of `this`.
+  for (unsigned i = 0, e = getNumRangeVars(); i < e; ++i) {
+    const Identifier &id = getSpace().getId(VarKind::Range, i);
+    if (id.hasValue())
+      result.setId(VarKind::Range, i, id);
+  }
+
+  // Append `this` to `result` and simplify constraints.
+  result.append(*this);
+  result.removeRedundantLocalVars();
+
+  *this = result;
+}
+
void IntegerRelation::print(raw_ostream &os) const {
assert(hasConsistentState());
printSpace(os);
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 69b3d41e17c2d4..6df3e868ecc62e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
@@ -379,11 +381,10 @@ static bool srcAppearsBeforeDstInAncestralBlock(const MemRefAccess &srcAccess,
// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' > j + 1
// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
-static void
-addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
- const FlatAffineValueConstraints &dstDomain,
- unsigned loopDepth,
- FlatAffineValueConstraints *dependenceDomain) {
+static void addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
+ const FlatAffineValueConstraints &dstDomain,
+ unsigned loopDepth,
+ IntegerRelation *dependenceDomain) {
unsigned numCols = dependenceDomain->getNumCols();
SmallVector<int64_t, 4> eq(numCols);
unsigned numSrcDims = srcDomain.getNumDimVars();
@@ -409,7 +410,7 @@ addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
static void computeDirectionVector(
const FlatAffineValueConstraints &srcDomain,
const FlatAffineValueConstraints &dstDomain, unsigned loopDepth,
- FlatAffineValueConstraints *dependenceDomain,
+ IntegerPolyhedron *dependenceDomain,
SmallVector<DependenceComponent, 2> *dependenceComponents) {
// Find the number of common loops shared by src and dst accesses.
SmallVector<AffineForOp, 4> commonLoops;
@@ -421,7 +422,8 @@ static void computeDirectionVector(
unsigned numIdsToEliminate = dependenceDomain->getNumVars();
// Add new variables to 'dependenceDomain' to represent the direction
// constraints for each shared loop.
- dependenceDomain->insertDimVar(/*pos=*/0, /*num=*/numCommonLoops);
+ dependenceDomain->insertVar(VarKind::SetDim, /*pos=*/0,
+ /*num=*/numCommonLoops);
// Add equality constraints for each common loop, setting newly introduced
// variable at column 'j' to the 'dst' IV minus the 'src IV.
@@ -455,7 +457,7 @@ static void computeDirectionVector(
}
}
-LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
+LogicalResult MemRefAccess::getAccessRelation(IntegerRelation &rel) const {
// Create set corresponding to domain of access.
FlatAffineValueConstraints domain;
if (failed(getOpIndexSet(opInst, &domain)))
@@ -467,28 +469,35 @@ LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
if (failed(getRelationFromMap(accessValueMap, rel)))
return failure();
- FlatAffineRelation domainRel(rel.getNumDomainDims(), /*numRangeDims=*/0,
- domain);
-
- // Merge and align domain ids of `ret` and ids of `domain`. Since the domain
+ // Merge and align domain ids of `rel` with ids of `domain`. Since the domain
// of the access map is a subset of the domain of access, the domain ids of
- // `ret` are guranteed to be a subset of ids of `domain`.
+  // `rel` are guaranteed to be a subset of ids of `domain`.
+  unsigned inserts = 0; // Count of `domain` dims missing from `rel`.
for (unsigned i = 0, e = domain.getNumDimVars(); i < e; ++i) {
- unsigned loc;
- if (rel.findVar(domain.getValue(i), &loc)) {
- rel.swapVar(i, loc);
+    const Identifier domainId = Identifier(domain.getValue(i));
+ const PresburgerSpace &relSpace = rel.getSpace();
+ const Identifier *findBegin = relSpace.getIds(VarKind::SetDim).begin() + i;
+ const Identifier *findEnd = relSpace.getIds(VarKind::SetDim).end();
+    const Identifier *itr = std::find(findBegin, findEnd, domainId);
+ if (itr != findEnd) {
+ rel.swapVar(i, i + std::distance(findBegin, itr));
} else {
- rel.insertDomainVar(i);
- rel.setValue(i, domain.getValue(i));
+ ++inserts;
+ rel.insertVar(VarKind::SetDim, i);
+      rel.setId(VarKind::SetDim, i, domainId);
}
}
// Append domain constraints to `rel`.
- domainRel.appendRangeVar(rel.getNumRangeDims());
- domainRel.mergeSymbolVars(rel);
+ IntegerRelation domainRel = domain;
+ domainRel.appendVar(VarKind::Range, accessValueMap.getNumResults());
+ domainRel.mergeAndAlignSymbols(rel);
domainRel.mergeLocalVars(rel);
rel.append(domainRel);
+ rel.convertVarKind(VarKind::SetDim, 0, accessValueMap.getNumDims() + inserts,
+ VarKind::Domain);
+
return success();
}
@@ -622,14 +631,15 @@ DependenceResult mlir::affine::checkMemrefAccessDependence(
return DependenceResult::Failure;
// Create access relation from each MemRefAccess.
- FlatAffineRelation srcRel, dstRel;
+ PresburgerSpace space = PresburgerSpace::getRelationSpace();
+ IntegerRelation srcRel(space), dstRel(space);
if (failed(srcAccess.getAccessRelation(srcRel)))
return DependenceResult::Failure;
if (failed(dstAccess.getAccessRelation(dstRel)))
return DependenceResult::Failure;
- FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
- FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();
+ FlatAffineValueConstraints srcDomain(srcRel.getDomainSet());
+ FlatAffineValueConstraints dstDomain(dstRel.getDomainSet());
// Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
// operation of 'srcAccess' does not properly dominate the ancestor
@@ -648,25 +658,29 @@ DependenceResult mlir::affine::checkMemrefAccessDependence(
// `srcAccess` to the iteration domain of `dstAccess` which access the same
// memory locations.
dstRel.inverse();
- dstRel.compose(srcRel);
+ dstRel.mergeAndCompose(srcRel);
+ dstRel.convertVarKind(VarKind::Domain, 0, dstRel.getNumDomainVars(),
+ VarKind::Range, 0);
+ IntegerPolyhedron dependenceDomain(dstRel);
// Add 'src' happens before 'dst' ordering constraints.
- addOrderingConstraints(srcDomain, dstDomain, loopDepth, &dstRel);
+ addOrderingConstraints(srcDomain, dstDomain, loopDepth, &dependenceDomain);
// Return 'NoDependence' if the solution space is empty: no dependence.
- if (dstRel.isEmpty())
+ if (dependenceDomain.isEmpty())
return DependenceResult::NoDependence;
// Compute dependence direction vector and return true.
if (dependenceComponents != nullptr)
- computeDirectionVector(srcDomain, dstDomain, loopDepth, &dstRel,
+ computeDirectionVector(srcDomain, dstDomain, loopDepth, &dependenceDomain,
dependenceComponents);
LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
- LLVM_DEBUG(dstRel.dump());
+ LLVM_DEBUG(dependenceDomain.dump());
+  // Only build the value constraints when the caller asked for them.
if (dependenceConstraints)
- *dependenceConstraints = dstRel;
+    *dependenceConstraints = FlatAffineValueConstraints(dependenceDomain);
return DependenceResult::HasDependence;
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 469298d3e8f43f..89fea34bad4e84 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/LinearTransform.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
@@ -372,9 +373,8 @@ FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
void FlatAffineRelation::compose(const FlatAffineRelation &other) {
assert(getNumDomainDims() == other.getNumRangeDims() &&
"Domain of this and range of other do not match");
- assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
- other.values.begin() + other.getNumDomainDims()) &&
- "Domain of this and range of other do not match");
+ assert(space.getDomainSpace().isAligned(other.getSpace().getRangeSpace()) &&
+ "Values of domain of this and range of other do not match");
FlatAffineRelation rel = other;
@@ -491,21 +491,27 @@ void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart,
}
LogicalResult mlir::affine::getRelationFromMap(AffineMap &map,
- FlatAffineRelation &rel) {
+ IntegerRelation &rel) {
// Get flattened affine expressions.
std::vector<SmallVector<int64_t, 8>> flatExprs;
FlatAffineValueConstraints localVarCst;
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
return failure();
- unsigned oldDimNum = localVarCst.getNumDimVars();
- unsigned oldCols = localVarCst.getNumCols();
- unsigned numRangeVars = map.getNumResults();
- unsigned numDomainVars = map.getNumDims();
+ const unsigned oldDimNum = localVarCst.getNumDimVars();
+ const unsigned oldCols = localVarCst.getNumCols();
+ const unsigned numRangeVars = map.getNumResults();
+ const unsigned numDomainVars = map.getNumDims();
// Add range as the new expressions.
+ localVarCst.resetIds();
localVarCst.appendDimVar(numRangeVars);
+  // getFlattenedAffineExprs creates `localVarCst` without identifiers. Attach
+  // a null Value identifier to every dim and symbol var as a placeholder.
+ for (unsigned i = 0, e = localVarCst.getNumDimAndSymbolVars(); i < e; ++i)
+ localVarCst.setValue(i, Value());
+
// Add equalities between source and range.
SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -521,24 +527,26 @@ LogicalResult mlir::affine::getRelationFromMap(AffineMap &map,
localVarCst.addEquality(eq);
}
- // Create relation and return success.
- rel = FlatAffineRelation(numDomainVars, numRangeVars, localVarCst);
+ rel = localVarCst; // Copies only the IntegerRelation part of `localVarCst`.
return success();
}
LogicalResult mlir::affine::getRelationFromMap(const AffineValueMap &map,
- FlatAffineRelation &rel) {
+ IntegerRelation &rel) {
AffineMap affineMap = map.getAffineMap();
if (failed(getRelationFromMap(affineMap, rel)))
return failure();
- // Set symbol values for domain dimensions and symbols.
- for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
- rel.setValue(i, map.getOperand(i));
- for (unsigned i = rel.getNumDimVars(), e = rel.getNumDimAndSymbolVars();
- i < e; ++i)
- rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
+  // Set identifiers for domain and symbol variables. The operands of `map`
+  // are its dim operands followed by its symbol operands, so symbols are
+  // indexed from the map's dim count rather than from `rel`'s layout.
+  const unsigned numMapDims = affineMap.getNumDims();
+  for (unsigned i = 0; i < numMapDims; ++i)
+    rel.setId(VarKind::SetDim, i, Identifier(map.getOperand(i)));
+
+  for (unsigned i = 0, e = rel.getNumSymbolVars(); i < e; ++i)
+    rel.setId(VarKind::Symbol, i, Identifier(map.getOperand(numMapDims + i)));
return success();
}
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 945b3d502f6973..a1bee404bf6664 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -542,8 +542,8 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
// Attach identifiers.
char identifiers[5] = {'x', 'y', 'z', 'N', 'M'};
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
- space.getId(VarKind::Range, 1) = Identifier(&identifiers[1]);
- space.getId(VarKind::Range, 2) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[2]);
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
rel.setSpace(space);
More information about the Mlir-commits
mailing list