[Mlir-commits] [mlir] 92744f6 - [MLIR] Add flat affine constraints method to round trip integer set

Uday Bondhugula llvmlistbot at llvm.org
Wed Mar 25 23:40:12 PDT 2020


Author: Uday Bondhugula
Date: 2020-03-26T12:07:13+05:30
New Revision: 92744f624783d92a07db25bc76e181b879f17e5b

URL: https://github.com/llvm/llvm-project/commit/92744f624783d92a07db25bc76e181b879f17e5b
DIFF: https://github.com/llvm/llvm-project/commit/92744f624783d92a07db25bc76e181b879f17e5b.diff

LOG: [MLIR] Add flat affine constraints method to round trip integer set

- add method to get back an integer set from flat affine constraints;
  this allows a round trip
- use this to complete the simplification of integer sets in
  -simplify-affine-structures
- update FlatAffineConstraints::removeTrivialRedundancy to also do GCD
  tightening and normalize by GCD (while still keeping it linear time).

Signed-off-by: Uday Bondhugula <uday at polymagelabs.com>

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Analysis/Utils.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/test/Dialect/Affine/simplify-affine-structures.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 5d99320c2860..91435d69147d 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -210,6 +210,12 @@ class FlatAffineConstraints {
                                      ValueRange operands, bool eq,
                                      bool lower = true);
 
+  /// Returns the constraint system as an integer set. A system with no
+  /// constraints yields the universal set (0 == 0). Returns a null integer set
+  /// if an integer set couldn't be constructed because a local variable whose
+  /// explicit representation is not known appears in any of the constraints.
+  IntegerSet getAsIntegerSet(MLIRContext *context) const;
+
   /// Computes the lower and upper bounds of the first 'num' dimensional
   /// identifiers (starting at 'offset') as an affine map of the remaining
   /// identifiers (dimensional and symbolic). This method is able to detect
@@ -484,7 +490,8 @@ class FlatAffineConstraints {
   /// that can be detected as redundant as a result of differing only in their
   /// constant term part. A constraint of the form <non-negative constant> >= 0
   /// is considered trivially true. This method is a linear time method on the
-  /// constraints, does a single scan, and updates in place.
+  /// constraints, does a single scan, and updates in place. It also normalizes
+  /// constraints by their GCD and performs GCD tightening on inequalities.
   void removeTrivialRedundancy();
 
   /// A more expensive check to detect redundant inequalities than

diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 3a40b85cea05..9b69c6f61a4b 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -290,6 +290,12 @@ Optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
 /// Returns true if `forOp' is a parallel loop.
 bool isLoopParallel(AffineForOp forOp);
 
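+/// Simplify the integer set by simplifying the underlying affine expressions by
+/// flattening and some simple inference. Also, drop any duplicate constraints.
+/// Returns the simplified integer set, or the canonical empty set if the
+/// constraints are detected to be infeasible. The redundancy-removal step runs
+/// in time linear in the number of constraints; the emptiness check performed
+/// first is not covered by that bound.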
+IntegerSet simplifyIntegerSet(IntegerSet set);
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_UTILS_H

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 6ebc673c3100..c652aad19035 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -726,15 +726,14 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
 // equality (isEq=true) or inequality (isEq=false) constraints.
 // Returns true and sets row found in search in 'rowIdx'.
 // Returns false otherwise.
-static bool
-findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints,
-                            unsigned colIdx, bool isEq, unsigned *rowIdx) {
+static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
+                                        unsigned colIdx, bool isEq,
+                                        unsigned *rowIdx) {
+  assert(colIdx < cst.getNumCols() && "position out of bounds");
   auto at = [&](unsigned rowIdx) -> int64_t {
-    return isEq ? constraints.atEq(rowIdx, colIdx)
-                : constraints.atIneq(rowIdx, colIdx);
+    return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
   };
-  unsigned e =
-      isEq ? constraints.getNumEqualities() : constraints.getNumInequalities();
+  unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
     if (at(*rowIdx) != 0) {
       return true;
@@ -2191,13 +2190,16 @@ void FlatAffineConstraints::dump() const { print(llvm::errs()); }
 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
 //  remove duplicates in place.
 void FlatAffineConstraints::removeTrivialRedundancy() {
-  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
+  GCDTightenInequalities();
+  normalizeConstraintsByGCD();
 
   // A map used to detect redundancy stemming from constraints that only differ
   // in their constant term. The value stored is <row position, const term>
   // for a given row.
   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
       rowsWithoutConstTerm;
+  // To unique rows.
+  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
 
   // Check if constraint is of the form <non-negative-constant> >= 0.
   auto isTriviallyValid = [&](unsigned r) -> bool {
@@ -2690,3 +2692,89 @@ FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
 
   return success();
 }
+
+/// Computes an explicit representation for the local identifiers of `cst`.
+/// `memo` must already hold one entry per identifier of `cst`. On return, the
+/// first numDims entries hold the dim exprs, the next numSyms entries hold the
+/// symbol exprs, and each local entry holds the floordiv expression detected
+/// for it (local entries that are already non-null are left untouched; ones
+/// that cannot be detected stay null). Succeeds iff every local entry ends up
+/// non-null. For all systems coming from MLIR integer sets, maps, or
+/// expressions where local vars were introduced to model floordivs and mods,
+/// this always succeeds.
+static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
+                                      SmallVectorImpl<AffineExpr> &memo,
+                                      MLIRContext *context) {
+  const unsigned numDims = cst.getNumDimIds();
+  const unsigned numSyms = cst.getNumSymbolIds();
+  const unsigned firstLocal = numDims + numSyms;
+  const unsigned numLocals = cst.getNumLocalIds();
+
+  // Dimensional and symbolic identifiers are represented by themselves.
+  for (unsigned d = 0; d < numDims; ++d)
+    memo[d] = getAffineDimExpr(d, context);
+  for (unsigned s = 0; s < numSyms; ++s)
+    memo[numDims + s] = getAffineSymbolExpr(s, context);
+
+  // Sweep the still-unresolved locals until a sweep resolves nothing.
+  // detectAsFloorDiv is handed the partially filled memo, so resolving one
+  // local may let a later sweep resolve another. Every productive sweep fills
+  // at least one null entry, so the loop terminates.
+  for (bool progress = true; progress;) {
+    progress = false;
+    for (unsigned l = 0; l < numLocals; ++l) {
+      unsigned pos = firstLocal + l;
+      if (memo[pos])
+        continue;
+      if (detectAsFloorDiv(cst, pos, context, memo))
+        progress = true;
+    }
+  }
+
+  // Success iff no local was left without an explicit expression.
+  ArrayRef<AffineExpr> locals = ArrayRef<AffineExpr>(memo).take_back(numLocals);
+  bool allResolved = llvm::all_of(
+      locals, [](AffineExpr expr) { return static_cast<bool>(expr); });
+  return success(allResolved);
+}
+
+/// Returns true if the pos^th column is all zero for both inequalities and
+/// equalities, i.e. no constraint has a non-zero coefficient for identifier
+/// `pos`.
+static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
+  unsigned unusedRowPos;
+  // Inequalities are scanned first; a hit there makes the equality scan moot.
+  if (findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &unusedRowPos))
+    return false;
+  return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &unusedRowPos);
+}
+
+/// Converts this constraint system into an IntegerSet over the same dims and
+/// symbols. Local identifiers are replaced by the explicit expressions that
+/// computeLocalVars recovers for them. A system with no constraints yields the
+/// universal set (0 == 0). Returns a null IntegerSet if some local identifier
+/// without an explicit representation has a non-zero coefficient in any
+/// equality or inequality.
+IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
+  unsigned numDims = getNumDimIds();
+  unsigned numSyms = getNumSymbolIds();
+
+  if (getNumConstraints() == 0)
+    // Return universal set (always true): 0 == 0.
+    return IntegerSet::get(numDims, numSyms,
+                           getAffineConstantExpr(/*constant=*/0, context),
+                           true);
+
+  // memo[i] is the expression for identifier i: dims and symbols map to
+  // themselves; locals get their detected expressions, or remain null.
+  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
+
+  if (failed(computeLocalVars(*this, memo, context))) {
+    // An unresolved local is tolerable only if its coefficients are zero in
+    // every constraint, since it then never appears in the output.
+    for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
+      if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
+        LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
+                                   "explicit representation\n");
+        return IntegerSet();
+      }
+    }
+  }
+
+  ArrayRef<AffineExpr> localExprs =
+      ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
+
+  // Equalities are emitted first, followed by inequalities; the flags mirror
+  // that order.
+  SmallVector<bool, 16> eqFlags(getNumConstraints(), false);
+  std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
+
+  SmallVector<AffineExpr, 8> exprs;
+  exprs.reserve(getNumConstraints());
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
+    exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
+                                              localExprs, context));
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
+    exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
+                                              numSyms, localExprs, context));
+  return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
+}

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 9e400e4b6a3c..2b14b6a11086 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1007,3 +1008,17 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
   }
   return true;
 }
+
+IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
+  MLIRContext *ctx = set.getContext();
+  FlatAffineConstraints constraints(set);
+
+  // An infeasible system collapses to the canonical empty set.
+  if (constraints.isEmpty())
+    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(), ctx);
+
+  // Drop trivially redundant rows (this also GCD-normalizes and tightens).
+  constraints.removeTrivialRedundancy();
+
+  // Convert back; locals originating from an IntegerSet are always expressible.
+  IntegerSet result = constraints.getAsIntegerSet(ctx);
+  assert(result && "guaranteed to succeed while roundtripping");
+  return result;
+}

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 60ad1545d350..2ba0814fec3d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -11,12 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/IntegerSet.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Utils.h"
 
 #define DEBUG_TYPE "simplify-affine-structure"
@@ -58,15 +59,7 @@ struct SimplifyAffineStructures
     op->setAttr(name, simplified);
   }
 
-  /// Performs basic integer set simplifications. Checks if it's empty, and
-  /// replaces it with the canonical empty set if it is.
-  IntegerSet simplify(IntegerSet set) {
-    FlatAffineConstraints fac(set);
-    if (fac.isEmpty())
-      return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
-                                     &getContext());
-    return set;
-  }
+  /// Performs integer set simplification by delegating to the shared
+  /// mlir::simplifyIntegerSet utility.
+  IntegerSet simplify(IntegerSet set) {
+    IntegerSet simplified = simplifyIntegerSet(set);
+    return simplified;
+  }
 
   /// Performs basic affine map simplifications.
   AffineMap simplify(AffineMap map) {

diff  --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 89f37d0b6c92..2f3ea34c0ad1 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -simplify-affine-structures | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -simplify-affine-structures | FileCheck %s
 
 // CHECK-DAG: [[SET_EMPTY_2D:#set[0-9]+]] = affine_set<(d0, d1) : (1 == 0)>
-// CHECK-DAG: #set1 = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)>
+// CHECK-DAG: #set1 = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)>
 // CHECK-DAG: #set2 = affine_set<(d0, d1)[s0, s1] : (1 == 0)>
 // CHECK-DAG: #set3 = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)>
 // CHECK-DAG: [[SET_EMPTY_1D:#set[0-9]+]] = affine_set<(d0) : (1 == 0)>
@@ -236,3 +236,23 @@ func @test_empty_set(%N : index) {
 
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[SET1:.*]] = affine_set<(d0, d1) : (d0 >= 0, -d0 + 50 >= 0)
+// CHECK-DAG: #[[SET2:.*]] = affine_set<(d0, d1) : (1 == 0)
+// CHECK-DAG: #[[SET3:.*]] = affine_set<(d0, d1) : (0 == 0)
+
+// CHECK-LABEL: func @simplify_set
+func @simplify_set(%a : index, %b : index) {
+  // Duplicate, trivially true, and constant-only-differing constraints drop out.
+  // CHECK: affine.if #[[SET1]]
+  affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) {
+  }
+  // Contradictory equalities (d0 odd and d0 even) fold to the empty set.
+  // CHECK: affine.if #[[SET2]]
+  affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) {
+  }
+  // Only trivially true constraints: folds to the universal set.
+  // CHECK: affine.if #[[SET3]]
+  affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) {
+  }
+  return
+}


        


More information about the Mlir-commits mailing list