[Mlir-commits] [mlir] [MLIR][Presburger] Use Identifiers outside Presburger library (PR #77316)

Bharathi Ramana Joshi llvmlistbot at llvm.org
Wed Jan 24 09:15:04 PST 2024


https://github.com/iambrj updated https://github.com/llvm/llvm-project/pull/77316

>From 5a61c363c0b619df2bbd06080fa9e51f3ee838da Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Sun, 21 Jan 2024 01:39:04 +0530
Subject: [PATCH 1/2] [MLIR][Presburger] Implement preserve identifiers in
 IntegerRelation::convertVarKind

---
 .../Analysis/Presburger/IntegerRelation.cpp   |  26 ++--
 .../Presburger/IntegerRelationTest.cpp        | 136 ++++++++++++++++++
 2 files changed, 146 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 7d2a63d17676f57..a07a1fb07d4ea30 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1452,28 +1452,22 @@ void IntegerRelation::removeRedundantLocalVars() {
 void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
                                      unsigned varLimit, VarKind dstKind,
                                      unsigned pos) {
-  assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range");
+  assert(varLimit <= getNumVarKind(srcKind) && "invalid id range");
 
   if (varStart >= varLimit)
     return;
 
-  // Append new local variables corresponding to the dimensions to be converted.
+  unsigned srcOffset = getVarKindOffset(srcKind);
+  unsigned dstOffset = getVarKindOffset(dstKind);
   unsigned convertCount = varLimit - varStart;
-  unsigned newVarsBegin = insertVar(dstKind, pos, convertCount);
+  int forwardMoveOffset = dstOffset > srcOffset ? -convertCount : 0;
 
-  // Swap the new local variables with dimensions.
-  //
-  // Essentially, this moves the information corresponding to the specified ids
-  // of kind `srcKind` to the `convertCount` newly created ids of kind
-  // `dstKind`. In particular, this moves the columns in the constraint
-  // matrices, and zeros out the initially occupied columns (because the newly
-  // created ids we're swapping with were zero-initialized).
-  unsigned offset = getVarKindOffset(srcKind);
-  for (unsigned i = 0; i < convertCount; ++i)
-    swapVar(offset + varStart + i, newVarsBegin + i);
-
-  // Complete the move by deleting the initially occupied columns.
-  removeVarRange(srcKind, varStart, varLimit);
+  equalities.moveColumns(srcOffset + varStart, convertCount,
+                         dstOffset + pos + forwardMoveOffset);
+  inequalities.moveColumns(srcOffset + varStart, convertCount,
+                           dstOffset + pos + forwardMoveOffset);
+
+  space.convertVarKind(srcKind, varStart, varLimit - varStart, dstKind, pos);
 }
 
 void IntegerRelation::addBound(BoundType type, unsigned pos,
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 00d2204c9c8ef18..945b3d502f6973b 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -485,3 +485,139 @@ TEST(IntegerRelationTest, setId) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
 }
+
+TEST(IntegerRelationTest, convertVarKind) {
+  PresburgerSpace space = PresburgerSpace::getSetSpace(3, 3, 0);
+  space.resetIds();
+
+  // Attach identifiers.
+  int identifiers[6] = {0, 1, 2, 3, 4, 5};
+  space.getId(VarKind::SetDim, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::SetDim, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::SetDim, 2) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+
+  // Cannot call parseIntegerRelation to test convertVarKind as
+  // parseIntegerRelation uses convertVarKind.
+  IntegerRelation rel = parseIntegerPolyhedron(
+      // 0  1  2  3  4  5
+      "(x, y, a)[U, V, W] : (x - U == 0, y + a - W == 0, U - V >= 0,"
+      "y - a >= 0)");
+  rel.setSpace(space);
+
+  // Make a few kind conversions.
+  rel.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 0);
+  rel.convertVarKind(VarKind::Range, 2, 3, VarKind::Domain, 0);
+  rel.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 1);
+  rel.convertVarKind(VarKind::Domain, 1, 2, VarKind::Range, 0);
+  rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 1);
+
+  space = rel.getSpace();
+
+  // Expected rel.
+  IntegerRelation expectedRel = parseIntegerPolyhedron(
+      "(V, a)[U, x, y, W] : (x - U == 0, y + a - W == 0, U - V >= 0,"
+      "y - a >= 0)");
+  expectedRel.setSpace(space);
+
+  EXPECT_TRUE(rel.isEqual(expectedRel));
+
+  EXPECT_EQ(space.getId(VarKind::SetDim, 0), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::SetDim, 1), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[5]));
+}
+
+TEST(IntegerRelationTest, convertVarKindToLocal) {
+  // Convert all range variables to local variables.
+  IntegerRelation rel = parseRelationFromSet(
+      "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+      1);
+  PresburgerSpace space = rel.getSpace();
+  space.resetIds();
+  // Attach identifiers.
+  char identifiers[5] = {'x', 'y', 'z', 'N', 'M'};
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  rel.setSpace(space);
+  rel.convertToLocal(VarKind::Range, 0, rel.getNumRangeVars());
+  IntegerRelation expectedRel =
+      parseRelationFromSet("(x)[N, M] : (x - N >= 0, 2 * M - 5 >= 0)", 1);
+  EXPECT_TRUE(rel.isEqual(expectedRel));
+  space = rel.getSpace();
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+  // Convert all domain variables to local variables.
+  IntegerRelation rel2 = parseRelationFromSet(
+      "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+      2);
+  space = rel2.getSpace();
+  space.resetIds();
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  rel2.setSpace(space);
+  rel2.convertToLocal(VarKind::Domain, 0, rel2.getNumDomainVars());
+  expectedRel =
+      parseIntegerPolyhedron("(z)[N, M] : (3 - z >= 0, 2 * M - 5 >= 0)");
+  EXPECT_TRUE(rel2.isEqual(expectedRel));
+  space = rel2.getSpace();
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+  // Convert a prefix of range variables to local variables.
+  IntegerRelation rel3 = parseRelationFromSet(
+      "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+      1);
+  space = rel3.getSpace();
+  space.resetIds();
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  rel3.setSpace(space);
+  rel3.convertToLocal(VarKind::Range, 0, 1);
+  expectedRel = parseRelationFromSet(
+      "(x, z)[N, M] : (x - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)", 1);
+  EXPECT_TRUE(rel3.isEqual(expectedRel));
+  space = rel3.getSpace();
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+
+  // Convert a suffix of domain variables to local variables.
+  IntegerRelation rel4 = parseRelationFromSet(
+      "(x, y, z)[N, M] : (x - y >= 0, y - N >= 0, 3 - z >= 0, 2 * M - 5 >= 0)",
+      2);
+  space = rel4.getSpace();
+  space.resetIds();
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  rel4.setSpace(space);
+  rel4.convertToLocal(VarKind::Domain, rel4.getNumDomainVars() - 1,
+                      rel4.getNumDomainVars());
+  // expectedRel is the same as in the previous (range-prefix) case.
+  EXPECT_TRUE(rel4.isEqual(expectedRel));
+  space = rel4.getSpace();
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+}

>From fd0a4926ac75b4b4b2e1d20cf2554c29841a442c Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Wed, 24 Jan 2024 22:40:37 +0530
Subject: [PATCH 2/2] [MLIR][Presburger] Use Identifiers outside Presburger
 library

---
 .../Analysis/FlatLinearValueConstraints.h     |  94 ++++++++------
 .../Dialect/Affine/Analysis/AffineAnalysis.h  |   3 +-
 .../Affine/Analysis/AffineStructures.h        |   4 +-
 .../Analysis/FlatLinearValueConstraints.cpp   | 117 ++++++------------
 .../Affine/Analysis/AffineAnalysis.cpp        |  52 +++++---
 .../Affine/Analysis/AffineStructures.cpp      |  33 +++--
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |   2 +
 7 files changed, 156 insertions(+), 149 deletions(-)

diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index e4de5b0661571c8..8929b5ee1aa5bb0 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -205,6 +205,8 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
 /// where each non-local variable can have an SSA Value attached to it.
 class FlatLinearValueConstraints : public FlatLinearConstraints {
 public:
+  using Identifier = presburger::Identifier;
+
   /// Constructs a constraint system reserving memory for the specified number
   /// of constraints and variables. `valArgs` are the optional SSA values
   /// associated with each dimension/symbol. These must either be empty or match
@@ -217,11 +219,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
       : FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
                               numReservedCols, numDims, numSymbols, numLocals) {
     assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
-    values.reserve(numReservedCols);
-    if (valArgs.empty())
-      values.resize(getNumDimAndSymbolVars(), std::nullopt);
-    else
-      values.append(valArgs.begin(), valArgs.end());
+    // Use values in space for FlatLinearValueConstraints.
+    space.resetIds();
+    // Set the values for the non-local variables.
+    for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
+      if (valArgs[i])
+        setValue(i, *valArgs[i]);
   }
 
   /// Constructs a constraint system reserving memory for the specified number
@@ -236,11 +239,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
       : FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
                               numReservedCols, numDims, numSymbols, numLocals) {
     assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
-    values.reserve(numReservedCols);
-    if (valArgs.empty())
-      values.resize(getNumDimAndSymbolVars(), std::nullopt);
-    else
-      values.append(valArgs.begin(), valArgs.end());
+    // Use values in space for FlatLinearValueConstraints.
+    space.resetIds();
+    // Set the values for the non-local variables.
+    for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
+      if (valArgs[i])
+        setValue(i, valArgs[i]);
   }
 
   /// Constructs a constraint system with the specified number of dimensions
@@ -273,10 +277,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
                              ArrayRef<std::optional<Value>> valArgs = {})
       : FlatLinearConstraints(fac) {
     assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
-    if (valArgs.empty())
-      values.resize(getNumDimAndSymbolVars(), std::nullopt);
-    else
-      values.append(valArgs.begin(), valArgs.end());
+    // Use values in space for FlatLinearValueConstraints.
+    space.resetIds();
+    // Set the values for the non-local variables.
+    for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
+      if (valArgs[i])
+        setValue(i, *valArgs[i]);
   }
 
   /// Creates an affine constraint system from an IntegerSet.
@@ -302,7 +308,9 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
   inline Value getValue(unsigned pos) const {
     assert(pos < getNumDimAndSymbolVars() && "Invalid position");
     assert(hasValue(pos) && "variable's Value not set");
-    return *values[pos];
+    VarKind kind = getVarKindAt(pos);
+    unsigned relativePos = pos - getVarKindOffset(kind);
+    return space.getId(kind, relativePos).getValue<Value>();
   }
 
   /// Returns the Values associated with variables in range [start, end).
@@ -317,21 +325,44 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
       values->push_back(getValue(i));
   }
 
-  inline ArrayRef<std::optional<Value>> getMaybeValues() const {
-    return {values.data(), values.size()};
+  inline SmallVector<std::optional<Value>> getMaybeValues() const {
+    SmallVector<std::optional<Value>> maybeValues;
+    maybeValues.reserve(getNumDimAndSymbolVars());
+    for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
+      if (hasValue(i))
+        maybeValues.push_back(getValue(i));
+      else
+        maybeValues.push_back(std::nullopt);
+    }
+    return maybeValues;
   }
 
-  inline ArrayRef<std::optional<Value>>
-  getMaybeValues(presburger::VarKind kind) const {
-    assert(kind != VarKind::Local &&
-           "Local variables do not have any value attached to them.");
-    return {values.data() + getVarKindOffset(kind), getNumVarKind(kind)};
-  }
+  inline SmallVector<std::optional<Value>>
+    getMaybeValues(presburger::VarKind kind) const {
+      assert(kind != VarKind::Local &&
+          "Local variables do not have any value attached to them.");
+      SmallVector<std::optional<Value>> maybeValues;
+      maybeValues.reserve(getNumVarKind(kind));
+      for (unsigned i = 0, e = getNumVarKind(kind); i < e; i++) {
+        Identifier id = space.getId(kind, i);
+        if (id.hasValue())
+          maybeValues.push_back(space.getId(kind, i).getValue<Value>());
+        else
+          maybeValues.push_back(std::nullopt);
+      }
+      return maybeValues;
+    }
 
   /// Returns true if the pos^th variable has an associated Value.
   inline bool hasValue(unsigned pos) const {
     assert(pos < getNumDimAndSymbolVars() && "Invalid position");
-    return values[pos].has_value();
+    VarKind kind = getVarKindAt(pos);
+    unsigned relativePos = pos - getVarKindOffset(kind);
+    return space.getId(kind, relativePos).hasValue();
+  }
+
+  void resetValues() {
+    space.resetIds();
   }
 
   unsigned appendDimVar(ValueRange vals);
@@ -360,7 +391,9 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
   /// Sets the Value associated with the pos^th variable.
   inline void setValue(unsigned pos, Value val) {
     assert(pos < getNumDimAndSymbolVars() && "invalid var position");
-    values[pos] = val;
+    VarKind kind = getVarKindAt(pos);
+    unsigned relativePos = pos - getVarKindOffset(kind);
+    space.getId(kind, relativePos) = presburger::Identifier(val);
   }
 
   /// Sets the Values associated with the variables in the range [start, end).
@@ -455,17 +488,6 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
   // See implementation comments for more details.
   void fourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
                                bool *isResultIntegerExact = nullptr) override;
-
-  /// Returns false if the fields corresponding to various variable counts, or
-  /// equality/inequality buffer sizes aren't consistent; true otherwise. This
-  /// is meant to be used within an assert internally.
-  bool hasConsistentState() const override;
-
-  /// Values corresponding to the (column) non-local variables of this
-  /// constraint system appearing in the order the variables correspond to
-  /// columns. Variables that aren't associated with any Value are set to
-  /// std::nullopt.
-  SmallVector<std::optional<Value>, 8> values;
 };
 
 /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
index a27583877b603ce..4134aef8174bc16 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
 #define MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
 
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/SmallVector.h"
@@ -115,7 +116,7 @@ struct MemRefAccess {
   ///
   /// Returns failure for yet unimplemented/unsupported cases (see docs of
   /// mlir::getIndexSet and mlir::getRelationFromMap for these cases).
-  LogicalResult getAccessRelation(FlatAffineRelation &accessRel) const;
+  LogicalResult getAccessRelation(presburger::IntegerRelation &accessRel) const;
 
   /// Populates 'accessMap' with composition of AffineApplyOps reachable from
   /// 'indices'.
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
index 7c500f13895af14..efd28a88e93752c 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
@@ -251,9 +251,9 @@ class FlatAffineRelation : public FlatAffineValueConstraints {
 /// For AffineValueMap, the domain and symbols have Value set corresponding to
 /// the Value in `map`. Returns failure if the AffineMap could not be flattened
 /// (i.e., semi-affine is not yet handled).
-LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
+LogicalResult getRelationFromMap(AffineMap &map, presburger::IntegerRelation &rel);
 LogicalResult getRelationFromMap(const AffineValueMap &map,
-                                 FlatAffineRelation &rel);
+                                 presburger::IntegerRelation &rel);
 
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 69846a356e0cc42..7f4be6475d51cc7 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -817,13 +817,12 @@ FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
                             set.getNumDims() + set.getNumSymbols() + 1,
                             set.getNumDims(), set.getNumSymbols(),
                             /*numLocals=*/0) {
-  // Populate values.
-  if (operands.empty()) {
-    values.resize(getNumDimAndSymbolVars(), std::nullopt);
-  } else {
-    assert(set.getNumInputs() == operands.size() && "operand count mismatch");
-    values.assign(operands.begin(), operands.end());
-  }
+  // Use values in space for FlatLinearValueConstraints.
+  space.resetIds();
+  // Set the values for the non-local variables.
+  for (unsigned i = 0, e = operands.size(); i < e; ++i)
+    setValue(i, operands[i]);
+
 
   // Flatten expressions and add them to the constraint system.
   std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -873,11 +872,6 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
                                                unsigned num) {
   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
 
-  if (kind != VarKind::Local) {
-    values.insert(values.begin() + absolutePos, num, std::nullopt);
-    assert(values.size() == getNumDimAndSymbolVars());
-  }
-
   return absolutePos;
 }
 
@@ -890,11 +884,9 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
 
   // If a Value is provided, insert it; otherwise use std::nullopt.
-  for (unsigned i = 0; i < num; ++i)
-    values.insert(values.begin() + absolutePos + i,
-                  vals[i] ? std::optional<Value>(vals[i]) : std::nullopt);
+  for (unsigned i = 0, e = vals.size(); i < e; ++i)
+    setValue(absolutePos + i, vals[i]);
 
-  assert(values.size() == getNumDimAndSymbolVars());
   return absolutePos;
 }
 
@@ -902,10 +894,13 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
 /// associated with the same set of variables, appearing in the same order.
 static bool areVarsAligned(const FlatLinearValueConstraints &a,
                            const FlatLinearValueConstraints &b) {
-  return a.getNumDimVars() == b.getNumDimVars() &&
-         a.getNumSymbolVars() == b.getNumSymbolVars() &&
-         a.getNumVars() == b.getNumVars() &&
-         a.getMaybeValues().equals(b.getMaybeValues());
+  if(a.getNumDomainVars() != b.getNumDomainVars() ||
+     a.getNumRangeVars() != b.getNumRangeVars() ||
+     a.getNumSymbolVars() != b.getNumSymbolVars())
+    return false;
+  SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(), bMaybeValues = b.getMaybeValues();
+  return std::equal(aMaybeValues.begin(), aMaybeValues.end(),
+      bMaybeValues.begin(), bMaybeValues.end());
 }
 
 /// Calls areVarsAligned to check if two constraint systems have the same set
@@ -928,12 +923,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
     return true;
 
   SmallPtrSet<Value, 8> uniqueVars;
-  ArrayRef<std::optional<Value>> maybeValues =
-      cst.getMaybeValues().slice(start, end - start);
-  for (std::optional<Value> val : maybeValues) {
+  SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
+  ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
+                                                maybeValuesAll.data() + end};
+
+  for (std::optional<Value> val : maybeValues)
     if (val && !uniqueVars.insert(*val).second)
       return false;
-  }
+
   return true;
 }
 
@@ -1058,20 +1055,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
          "expected same number of symbols");
 }
 
-bool FlatLinearValueConstraints::hasConsistentState() const {
-  return IntegerPolyhedron::hasConsistentState() &&
-         values.size() == getNumDimAndSymbolVars();
-}
-
 void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
                                                 unsigned varLimit) {
   IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
-  unsigned offset = getVarKindOffset(kind);
-
-  if (kind != VarKind::Local) {
-    values.erase(values.begin() + varStart + offset,
-                 values.begin() + varLimit + offset);
-  }
 }
 
 AffineMap
@@ -1089,14 +1075,15 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
 
   dims.reserve(getNumDimVars());
   syms.reserve(getNumSymbolVars());
-  for (unsigned i = getVarKindOffset(VarKind::SetDim),
-                e = getVarKindEnd(VarKind::SetDim);
-       i < e; ++i)
-    dims.push_back(values[i] ? *values[i] : Value());
-  for (unsigned i = getVarKindOffset(VarKind::Symbol),
-                e = getVarKindEnd(VarKind::Symbol);
-       i < e; ++i)
-    syms.push_back(values[i] ? *values[i] : Value());
+  for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) {
+    Identifier id = space.getId(VarKind::SetDim, i);
+    dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
+  }
+  for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) {
+    Identifier id = space.getId(VarKind::Symbol, i);
+    syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
+  }
+
 
   AffineMap alignedMap =
       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
@@ -1110,8 +1097,7 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
 bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
                                          unsigned offset) const {
   unsigned i = offset;
-  for (const auto &mayBeVar :
-       ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
+  for (const auto &mayBeVar : getMaybeValues()) {
     if (mayBeVar && *mayBeVar == val) {
       *pos = i;
       return true;
@@ -1122,25 +1108,12 @@ bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
 }
 
 bool FlatLinearValueConstraints::containsVar(Value val) const {
-  return llvm::any_of(values, [&](const std::optional<Value> &mayBeVar) {
-    return mayBeVar && *mayBeVar == val;
-  });
+  unsigned pos;
+  return findVar(val, &pos, 0);
 }
 
 void FlatLinearValueConstraints::swapVar(unsigned posA, unsigned posB) {
   IntegerPolyhedron::swapVar(posA, posB);
-
-  if (getVarKindAt(posA) == VarKind::Local &&
-      getVarKindAt(posB) == VarKind::Local)
-    return;
-
-  // Treat value of a local variable as std::nullopt.
-  if (getVarKindAt(posA) == VarKind::Local)
-    values[posB] = std::nullopt;
-  else if (getVarKindAt(posB) == VarKind::Local)
-    values[posA] = std::nullopt;
-  else
-    std::swap(values[posA], values[posB]);
 }
 
 void FlatLinearValueConstraints::addBound(BoundType type, Value val,
@@ -1182,27 +1155,13 @@ void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
 
 void FlatLinearValueConstraints::clearAndCopyFrom(
     const IntegerRelation &other) {
-
-  if (auto *otherValueSet =
-          dyn_cast<const FlatLinearValueConstraints>(&other)) {
-    *this = *otherValueSet;
-  } else {
-    *static_cast<IntegerRelation *>(this) = other;
-    values.clear();
-    values.resize(getNumDimAndSymbolVars(), std::nullopt);
-  }
+  IntegerPolyhedron::clearAndCopyFrom(other);
 }
 
 void FlatLinearValueConstraints::fourierMotzkinEliminate(
     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
-  SmallVector<std::optional<Value>, 8> newVals = values;
-  if (getVarKindAt(pos) != VarKind::Local)
-    newVals.erase(newVals.begin() + pos);
-  // Note: Base implementation discards all associated Values.
   IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
                                              isResultIntegerExact);
-  values = newVals;
-  assert(values.size() == getNumDimAndSymbolVars());
 }
 
 void FlatLinearValueConstraints::projectOut(Value val) {
@@ -1216,10 +1175,10 @@ void FlatLinearValueConstraints::projectOut(Value val) {
 LogicalResult FlatLinearValueConstraints::unionBoundingBox(
     const FlatLinearValueConstraints &otherCst) {
   assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
-  assert(otherCst.getMaybeValues()
-             .slice(0, getNumDimVars())
-             .equals(getMaybeValues().slice(0, getNumDimVars())) &&
-         "dim values mismatch");
+  SmallVector<std::optional<Value>> maybeValues = getMaybeValues(), otherMaybeValues = otherCst.getMaybeValues();
+  assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
+                    otherMaybeValues.begin(), otherMaybeValues.begin() + getNumDimVars()) &&
+                    "dim values mismatch");
   assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
   assert(getNumLocalVars() == 0 && "local vars not supported yet here");
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 1ba0bc8b6bfbe5e..ca418e7dcbaef49 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -12,7 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -385,7 +388,7 @@ static void
 addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
                        const FlatAffineValueConstraints &dstDomain,
                        unsigned loopDepth,
-                       FlatAffineValueConstraints *dependenceDomain) {
+                       IntegerRelation *dependenceDomain) {
   unsigned numCols = dependenceDomain->getNumCols();
   SmallVector<int64_t, 4> eq(numCols);
   unsigned numSrcDims = srcDomain.getNumDimVars();
@@ -411,7 +414,7 @@ addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
 static void computeDirectionVector(
     const FlatAffineValueConstraints &srcDomain,
     const FlatAffineValueConstraints &dstDomain, unsigned loopDepth,
-    FlatAffineValueConstraints *dependenceDomain,
+    IntegerRelation *dependenceDomain,
     SmallVector<DependenceComponent, 2> *dependenceComponents) {
   // Find the number of common loops shared by src and dst accesses.
   SmallVector<AffineForOp, 4> commonLoops;
@@ -423,7 +426,7 @@ static void computeDirectionVector(
   unsigned numIdsToEliminate = dependenceDomain->getNumVars();
   // Add new variables to 'dependenceDomain' to represent the direction
   // constraints for each shared loop.
-  dependenceDomain->insertDimVar(/*pos=*/0, /*num=*/numCommonLoops);
+  dependenceDomain->insertVar(VarKind::Domain, /*pos=*/0, /*num=*/numCommonLoops);
 
   // Add equality constraints for each common loop, setting newly introduced
   // variable at column 'j' to the 'dst' IV minus the 'src IV.
@@ -457,7 +460,7 @@ static void computeDirectionVector(
   }
 }
 
-LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
+LogicalResult MemRefAccess::getAccessRelation(IntegerRelation &rel) const {
   // Create set corresponding to domain of access.
   FlatAffineValueConstraints domain;
   if (failed(getOpIndexSet(opInst, &domain)))
@@ -469,25 +472,29 @@ LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
   if (failed(getRelationFromMap(accessValueMap, rel)))
     return failure();
 
-  FlatAffineRelation domainRel(rel.getNumDomainDims(), /*numRangeDims=*/0,
-                               domain);
+  IntegerRelation domainRel = domain;
+  domainRel.convertVarKind(VarKind::SetDim, 0, rel.getNumDomainVars(), VarKind::Domain);
 
-  // Merge and align domain ids of `ret` and ids of `domain`. Since the domain
+  // Merge and align domain ids of `rel` and ids of `domain`. Since the domain
   // of the access map is a subset of the domain of access, the domain ids of
-  // `ret` are guranteed to be a subset of ids of `domain`.
+  // `rel` are guaranteed to be a subset of ids of `domain`.
   for (unsigned i = 0, e = domain.getNumDimVars(); i < e; ++i) {
-    unsigned loc;
-    if (rel.findVar(domain.getValue(i), &loc)) {
-      rel.swapVar(i, loc);
+    const PresburgerSpace &relSpace = rel.getSpace();
+    const Identifier *findBegin = relSpace.getIds(VarKind::Domain).begin() + i;
+    const Identifier *findEnd = relSpace.getIds(VarKind::Domain).end();
+    const Identifier domainIdi = Identifier(domain.getValue(i));
+    const Identifier *itr = std::find(findBegin, findEnd, domainIdi);
+    if (itr != findEnd) {
+      rel.swapVar(i, i + std::distance(findBegin, itr));
     } else {
-      rel.insertDomainVar(i);
-      rel.setValue(i, domain.getValue(i));
+      rel.insertVar(VarKind::Domain, i);
+      rel.setId(VarKind::Domain, i, domainIdi);
     }
   }
 
   // Append domain constraints to `rel`.
-  domainRel.appendRangeVar(rel.getNumRangeDims());
-  domainRel.mergeSymbolVars(rel);
+  domainRel.appendVar(VarKind::Range, rel.getNumRangeVars());
+  domainRel.mergeAndAlignSymbols(rel);
   domainRel.mergeLocalVars(rel);
   rel.append(domainRel);
 
@@ -624,14 +631,21 @@ DependenceResult mlir::affine::checkMemrefAccessDependence(
     return DependenceResult::Failure;
 
   // Create access relation from each MemRefAccess.
-  FlatAffineRelation srcRel, dstRel;
+  PresburgerSpace space = PresburgerSpace::getRelationSpace();
+  IntegerRelation srcRel(space), dstRel(space);
   if (failed(srcAccess.getAccessRelation(srcRel)))
     return DependenceResult::Failure;
   if (failed(dstAccess.getAccessRelation(dstRel)))
     return DependenceResult::Failure;
 
-  FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
-  FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();
+  // TODO: find a cleaner way to downcast here. Assignment alone does not carry
+  // over the space's identifiers, so the space is copied explicitly below.
+  IntegerPolyhedron srcDomainPoly = srcRel.getDomainSet();
+  IntegerPolyhedron dstDomainPoly = dstRel.getDomainSet();
+  FlatAffineValueConstraints srcDomain = srcDomainPoly;
+  srcDomain.setSpace(srcDomainPoly.getSpace());
+  FlatAffineValueConstraints dstDomain = dstDomainPoly;
+  dstDomain.setSpace(dstDomainPoly.getSpace());
 
   // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
   // operation of 'srcAccess' does not properly dominate the ancestor
@@ -668,7 +682,7 @@ DependenceResult mlir::affine::checkMemrefAccessDependence(
   LLVM_DEBUG(dstRel.dump());
 
   if (dependenceConstraints)
-    *dependenceConstraints = dstRel;
+    *dependenceConstraints = cast<FlatAffineValueConstraints>(dstRel);
   return DependenceResult::HasDependence;
 }
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 469298d3e8f43ff..b488ae3c0270049 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -11,12 +11,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/LinearTransform.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/LLVM.h"
@@ -372,9 +374,8 @@ FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
 void FlatAffineRelation::compose(const FlatAffineRelation &other) {
   assert(getNumDomainDims() == other.getNumRangeDims() &&
          "Domain of this and range of other do not match");
-  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
-                    other.values.begin() + other.getNumDomainDims()) &&
-         "Domain of this and range of other do not match");
+  assert(space.getDomainSpace().isAligned(other.getSpace().getRangeSpace()) &&
+      "Values of domain of this and range of other do not match");
 
   FlatAffineRelation rel = other;
 
@@ -491,12 +492,15 @@ void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart,
 }
 
 LogicalResult mlir::affine::getRelationFromMap(AffineMap &map,
-                                               FlatAffineRelation &rel) {
+                                               IntegerRelation &rel) {
   // Get flattened affine expressions.
   std::vector<SmallVector<int64_t, 8>> flatExprs;
   FlatAffineValueConstraints localVarCst;
   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
     return failure();
+  // Add identifiers to the local constraints. We need to do this since
+  // getFlattenedAffineExprs creates a FlatLinearConstraints with no
+  // identifiers.
 
   unsigned oldDimNum = localVarCst.getNumDimVars();
   unsigned oldCols = localVarCst.getNumCols();
@@ -506,6 +510,10 @@ LogicalResult mlir::affine::getRelationFromMap(AffineMap &map,
   // Add range as the new expressions.
   localVarCst.appendDimVar(numRangeVars);
 
+  localVarCst.resetValues();
+  for(unsigned i = 0, e = localVarCst.getNumDimAndSymbolVars(); i < e; ++i)
+    localVarCst.setValue(i, Value());
+
   // Add equalities between source and range.
   SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -522,23 +530,24 @@ LogicalResult mlir::affine::getRelationFromMap(AffineMap &map,
   }
 
   // Create relation and return success.
-  rel = FlatAffineRelation(numDomainVars, numRangeVars, localVarCst);
+  rel = localVarCst;
+  rel.convertVarKind(VarKind::SetDim, 0, numDomainVars, VarKind::Domain);
   return success();
 }
 
 LogicalResult mlir::affine::getRelationFromMap(const AffineValueMap &map,
-                                               FlatAffineRelation &rel) {
+                                               IntegerRelation &rel) {
 
   AffineMap affineMap = map.getAffineMap();
   if (failed(getRelationFromMap(affineMap, rel)))
     return failure();
 
-  // Set symbol values for domain dimensions and symbols.
-  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
-    rel.setValue(i, map.getOperand(i));
-  for (unsigned i = rel.getNumDimVars(), e = rel.getNumDimAndSymbolVars();
-       i < e; ++i)
-    rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
+  // Set identifiers for domain and symbol variables.
+  for (unsigned i = 0, e = rel.getNumDomainVars(); i < e; ++i)
+    rel.setId(VarKind::Domain, i, Identifier(map.getOperand(i)));
+
+  for(unsigned i = 0, e = rel.getNumSymbolVars(); i < e; ++i)
+    rel.setId(VarKind::Symbol, i, Identifier(map.getOperand(rel.getNumDomainVars() + i)));
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 578d03c629285a8..2a35b02185190cd 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -13,6 +13,8 @@
 
 #include "mlir/Dialect/Affine/Utils.h"
 
+#include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"



More information about the Mlir-commits mailing list