[Mlir-commits] [mlir] 93d640f - [mlir][SCF][Utils][NFC] Make some utils public for better reuse

Matthias Springer llvmlistbot at llvm.org
Wed Feb 22 01:40:19 PST 2023


Author: Matthias Springer
Date: 2023-02-22T10:35:48+01:00
New Revision: 93d640f3922b2a15501101b229f8be40e8528a63

URL: https://github.com/llvm/llvm-project/commit/93d640f3922b2a15501101b229f8be40e8528a63
DIFF: https://github.com/llvm/llvm-project/commit/93d640f3922b2a15501101b229f8be40e8528a63.diff

LOG: [mlir][SCF][Utils][NFC] Make some utils public for better reuse

These functions will be used in a subsequent change. This commit also includes some minor refactoring.

Differential Revision: https://reviews.llvm.org/D143909

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
index fea870455f2a3..88c93dbfd7f18 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
@@ -39,6 +39,16 @@ class IfOp;
 using LoopMatcherFn = function_ref<LogicalResult(
     Value, OpFoldResult &, OpFoldResult &, OpFoldResult &)>;
 
+/// Match "for loop"-like operations from the SCF dialect.
+LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub,
+                               OpFoldResult &step);
+
+/// Populate the given constraint set with induction variable constraints of a
+/// "for" loop with the given range and step.
+LogicalResult addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
+                                      Value iv, OpFoldResult lb,
+                                      OpFoldResult ub, OpFoldResult step);
+
 /// Try to canonicalize the given affine.min/max operation in the context of
 /// for `loops` with a known range.
 ///

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index aee10633c59f8..8cbca1b6da914 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -158,40 +158,7 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
-                          OpFoldResult &step) {
-      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
-        lb = forOp.getLowerBound();
-        ub = forOp.getUpperBound();
-        step = forOp.getStep();
-        return success();
-      }
-      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
-        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
-          if (parOp.getInductionVars()[idx] == iv) {
-            lb = parOp.getLowerBound()[idx];
-            ub = parOp.getUpperBound()[idx];
-            step = parOp.getStep()[idx];
-            return success();
-          }
-        }
-        return failure();
-      }
-      if (scf::ForallOp forallOp = scf::getForallOpThreadIndexOwner(iv)) {
-        for (int64_t idx = 0; idx < forallOp.getRank(); ++idx) {
-          if (forallOp.getInductionVar(idx) == iv) {
-            lb = forallOp.getMixedLowerBound()[idx];
-            ub = forallOp.getMixedUpperBound()[idx];
-            step = forallOp.getMixedStep()[idx];
-            return success();
-          }
-        }
-        return failure();
-      }
-      return failure();
-    };
-
-    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher);
+    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop);
   }
 };
 

diff  --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 4ee27e4d00343..7799fa9231c62 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -12,12 +12,12 @@
 
 #include <utility>
 
-#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
@@ -29,6 +29,39 @@
 using namespace mlir;
 using namespace presburger;
 
+LogicalResult scf::matchForLikeLoop(Value iv, OpFoldResult &lb,
+                                    OpFoldResult &ub, OpFoldResult &step) {
+  if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
+    lb = forOp.getLowerBound();
+    ub = forOp.getUpperBound();
+    step = forOp.getStep();
+    return success();
+  }
+  if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
+    for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
+      if (parOp.getInductionVars()[idx] == iv) {
+        lb = parOp.getLowerBound()[idx];
+        ub = parOp.getUpperBound()[idx];
+        step = parOp.getStep()[idx];
+        return success();
+      }
+    }
+    return failure();
+  }
+  if (scf::ForallOp forallOp = scf::getForallOpThreadIndexOwner(iv)) {
+    for (int64_t idx = 0; idx < forallOp.getRank(); ++idx) {
+      if (forallOp.getInductionVar(idx) == iv) {
+        lb = forallOp.getMixedLowerBound()[idx];
+        ub = forallOp.getMixedUpperBound()[idx];
+        step = forallOp.getMixedStep()[idx];
+        return success();
+      }
+    }
+    return failure();
+  }
+  return failure();
+}
+
 static FailureOr<AffineApplyOp>
 canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op,
                      FlatAffineValueConstraints constraints) {
@@ -42,37 +75,38 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op,
       op, simplified->getAffineMap(), simplified->getOperands());
 }
 
-static LogicalResult
-addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
-                        OpFoldResult lb, OpFoldResult ub, OpFoldResult step,
-                        RewriterBase &rewriter) {
+LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
+                                           Value iv, OpFoldResult lb,
+                                           OpFoldResult ub, OpFoldResult step) {
+  Builder b(iv.getContext());
+
   // IntegerPolyhedron does not support semi-affine expressions.
   // Therefore, only constant step values are supported.
   auto stepInt = getConstantIntValue(step);
   if (!stepInt)
     return failure();
 
-  unsigned dimIv = constraints.appendDimVar(iv);
+  unsigned dimIv = cstr.appendDimVar(iv);
   auto lbv = lb.dyn_cast<Value>();
-  unsigned symLb = lbv ? constraints.appendSymbolVar(lbv)
-                       : constraints.appendSymbolVar(/*num=*/1);
+  unsigned symLb =
+      lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
   auto ubv = ub.dyn_cast<Value>();
-  unsigned symUb = ubv ? constraints.appendSymbolVar(ubv)
-                       : constraints.appendSymbolVar(/*num=*/1);
+  unsigned symUb =
+      ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
 
   // If loop lower/upper bounds are constant: Add EQ constraint.
   std::optional<int64_t> lbInt = getConstantIntValue(lb);
   std::optional<int64_t> ubInt = getConstantIntValue(ub);
   if (lbInt)
-    constraints.addBound(IntegerPolyhedron::EQ, symLb, *lbInt);
+    cstr.addBound(IntegerPolyhedron::EQ, symLb, *lbInt);
   if (ubInt)
-    constraints.addBound(IntegerPolyhedron::EQ, symUb, *ubInt);
+    cstr.addBound(IntegerPolyhedron::EQ, symUb, *ubInt);
 
   // Lower bound: iv >= lb (equiv.: iv - lb >= 0)
-  SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
+  SmallVector<int64_t> ineqLb(cstr.getNumCols(), 0);
   ineqLb[dimIv] = 1;
   ineqLb[symLb] = -1;
-  constraints.addInequality(ineqLb);
+  cstr.addInequality(ineqLb);
 
   // Upper bound
   AffineExpr ivUb;
@@ -81,26 +115,23 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
     // iv < lb + 1
     // TODO: Try to derive this constraint by simplifying the expression in
     // the else-branch.
-    ivUb =
-        rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars()) + 1;
+    ivUb = b.getAffineSymbolExpr(symLb - cstr.getNumDimVars()) + 1;
   } else {
     // The loop may have more than one iteration.
     // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
     AffineExpr exprLb =
-        lbInt
-            ? rewriter.getAffineConstantExpr(*lbInt)
-            : rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars());
+        lbInt ? b.getAffineConstantExpr(*lbInt)
+              : b.getAffineSymbolExpr(symLb - cstr.getNumDimVars());
     AffineExpr exprUb =
-        ubInt
-            ? rewriter.getAffineConstantExpr(*ubInt)
-            : rewriter.getAffineSymbolExpr(symUb - constraints.getNumDimVars());
+        ubInt ? b.getAffineConstantExpr(*ubInt)
+              : b.getAffineSymbolExpr(symUb - cstr.getNumDimVars());
     ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
   }
   auto map = AffineMap::get(
-      /*dimCount=*/constraints.getNumDimVars(),
-      /*symbolCount=*/constraints.getNumSymbolVars(), /*result=*/ivUb);
+      /*dimCount=*/cstr.getNumDimVars(),
+      /*symbolCount=*/cstr.getNumSymbolVars(), /*result=*/ivUb);
 
-  return constraints.addBound(IntegerPolyhedron::UB, dimIv, map);
+  return cstr.addBound(IntegerPolyhedron::UB, dimIv, map);
 }
 
 /// Canonicalize min/max operations in the context of for loops with a known
@@ -132,8 +163,7 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
       continue;
     allIvs.insert(iv);
 
-    if (failed(
-            addLoopRangeConstraints(constraints, iv, lb, ub, step, rewriter)))
+    if (failed(addLoopRangeConstraints(constraints, iv, lb, ub, step)))
       return failure();
   }
 


        


More information about the Mlir-commits mailing list