[Mlir-commits] [mlir] ded1870 - [mlir][NFC] Refactor linalg substituteMin and AffineMinSCF canonicalizations

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 21 07:19:45 PDT 2021


Author: thomasraoux
Date: 2021-04-21T07:19:36-07:00
New Revision: ded18708f91fd32fd799ae3af3b86c285d6dd0e5

URL: https://github.com/llvm/llvm-project/commit/ded18708f91fd32fd799ae3af3b86c285d6dd0e5
DIFF: https://github.com/llvm/llvm-project/commit/ded18708f91fd32fd799ae3af3b86c285d6dd0e5.diff

LOG: [mlir][NFC] Refactor linalg substituteMin and AffineMinSCF canonicalizations

Break up the dependency between SCF ops and the substituteMin helper and make a
more generic version of AffineMinSCFCanonicalization. This reduces dependencies
between linalg and SCF and will allow the logic to be used with other kinds of
ops (like ID ops).

Differential Revision: https://reviews.llvm.org/D100321

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/SCF/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/SCF/Transforms/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48b1eb8cdf50..884db2e93966 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
@@ -934,24 +935,47 @@ struct LinalgCopyVTWForwardingPattern
                                  PatternRewriter &rewriter) const override;
 };
 
-/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
-/// scf.parallel by:
-///   1. building an affine map where uses of the induction variable of a loop
-///   are replaced by either the min (i.e. `%lb`) of the max
-///   (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`) expression, depending
-///   on whether the induction variable is used with a positive or negative
-///   coefficient.
+/// Callback returning the (min, max) expression pair bounding `value`, or None
+/// when no range is known for `value`. Implementations append the dims and
+/// symbols referenced by the returned expressions to `dims` and `symbols`.
+using GetMinMaxExprFn =
+    std::function<Optional<std::pair<AffineExpr, AffineExpr>>(
+        Value value, SmallVectorImpl<Value> &dims,
+        SmallVectorImpl<Value> &symbols)>;
+
+/// Canonicalize AffineMinOp operations in the context of values with a known
+/// range by:
+///   1. building an affine map where uses of the known values are replaced by
+///   their min and max expressions returned by the callback `getMinMaxFn`.
 ///   2. checking whether any of the results of this affine map is known to be
 ///   greater than all other results.
 ///   3. replacing the AffineMinOp by the result of (2).
-// TODO: move to a more appropriate place when it is determined. For now Linalg
-// depends both on Affine and SCF but they do not depend on each other.
-struct AffineMinSCFCanonicalizationPattern
+struct AffineMinRangeCanonicalizationPattern
     : public OpRewritePattern<AffineMinOp> {
-  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
-
+  AffineMinRangeCanonicalizationPattern(MLIRContext *ctx, GetMinMaxExprFn fn)
+      : OpRewritePattern<AffineMinOp>(ctx), getMinMaxFn(std::move(fn)) {}
+
   LogicalResult matchAndRewrite(AffineMinOp minOp,
                                 PatternRewriter &rewriter) const override;
+
+protected:
+  GetMinMaxExprFn getMinMaxFn;
+};
+
+/// Specialized version of `AffineMinRangeCanonicalizationPattern` that uses
+/// `getSCFMinMaxExpr` (without a loop filter) to bound induction variables of
+/// scf.for and scf.parallel loops.
+// TODO: move to a more appropriate place when it is determined. For now Linalg
+// depends both on Affine and SCF but they do not depend on each other.
+struct AffineMinSCFCanonicalizationPattern
+    : public AffineMinRangeCanonicalizationPattern {
+  static Optional<std::pair<AffineExpr, AffineExpr>>
+  getMinMax(Value value, SmallVectorImpl<Value> &dims,
+            SmallVectorImpl<Value> &symbols) {
+    return getSCFMinMaxExpr(value, dims, symbols);
+  }
+  explicit AffineMinSCFCanonicalizationPattern(MLIRContext *context)
+      : AffineMinRangeCanonicalizationPattern(context, getMinMax) {}
 };
 
 /// Helper struct to return the results of `substituteMin`.
@@ -960,23 +984,22 @@ struct AffineMapAndOperands {
   SmallVector<Value> dims;
   SmallVector<Value> symbols;
 };
-/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
-/// induction variables by new expressions involving the lower or upper bound:
-///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
-///     replaced by the loop upper bound.
-///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
-///     replaced by the loop lower bound.
-/// All loop induction variables are iteratively replaced, unless a
-/// `substituteOperation` hook is passed to more finely determine which
-/// operations are substituted.
+
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
+/// dimensions with known range by new expressions involving the min or max
+/// expression:
+///   - If the AffineDimExpr mapped to a known value has a positive sign, it
+///     is replaced by the min expression.
+///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
+///     replaced by the max expression.
+/// Values for which `getMinMaxExpr` returns a range are iteratively replaced.
 /// This is used as an intermediate step in computing bounding boxes and
 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
 /// positive values (positive orthant assumptions).
 /// Return a new AffineMap, dims and symbols that have been canonicalized and
 /// simplified.
-AffineMapAndOperands substituteMin(
-    AffineMinOp affineMinOp,
-    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+AffineMapAndOperands substituteMin(AffineMinOp affineMinOp,
+                                   GetMinMaxExprFn getMinMaxExpr);
 
 /// Converts Convolution op into vector contraction.
 ///

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
index 06c2528402b0..b4f11356b121 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -14,12 +14,15 @@
 #define MLIR_DIALECT_SCF_UTILS_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 class FuncOp;
 class Operation;
 class OpBuilder;
 class ValueRange;
+class Value;
+class AffineExpr;
 
 namespace scf {
 class IfOp;
@@ -64,5 +67,14 @@ void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
 bool getInnermostParallelLoops(Operation *rootOp,
                                SmallVectorImpl<scf::ParallelOp> &result);
 
+/// Return the min/max expressions for `value` if it is an induction variable
+/// of an scf.for or scf.parallel loop, appending the loop bounds to `dims` and
+/// the step to `symbols`. If `loopFilter` is passed, only loops it accepts are
+/// considered; None is returned for any other value.
+Optional<std::pair<AffineExpr, AffineExpr>>
+getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
+                 SmallVectorImpl<Value> &symbols,
+                 llvm::function_ref<bool(Operation *)> loopFilter = nullptr);
+
 } // end namespace mlir
 #endif // MLIR_DIALECT_SCF_UTILS_H_

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index c9cdcde382a5..5962a98cb609 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -574,9 +574,16 @@ static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer,
       continue;
     }
     auto sliceMinOp = cast<AffineMinOp>(op);
+    // Only substitute induction variables of `outer` and loops nested in it.
+    GetMinMaxExprFn getSCFMinMax = [&](Value value,
+                                       SmallVectorImpl<Value> &dims,
+                                       SmallVectorImpl<Value> &symbols) {
+      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *loop) {
+        return outer->isAncestor(loop);
+      });
+    };
     // Perform the substitution of the operands of AffineMinOp.
-    auto mapAndOperands = substituteMin(
-        sliceMinOp, [&](Operation *op) { return outer->isAncestor(op); });
+    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
     SmallVector<Value> resultOperands = mapAndOperands.dims;
     llvm::append_range(resultOperands, mapAndOperands.symbols);
     AffineMap map = mapAndOperands.map;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2e8b1580c5c7..acf460982784 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -507,37 +507,11 @@ LogicalResult mlir::linalg::applyStagedPatterns(
   return success();
 }
 
-/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
-/// `ubVal` to `dims` and `stepVal` to `symbols`.
-/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
-/// with positions matching the newly appended values. Substitute occurrences of
-/// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
-/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether
-/// the induction variable is used with a positive or negative  coefficient.
-static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
-                                       Value lbVal, Value ubVal, Value stepVal,
-                                       SmallVectorImpl<Value> &dims,
-                                       SmallVectorImpl<Value> &symbols) {
-  MLIRContext *ctx = lbVal.getContext();
-  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
-  dims.push_back(lbVal);
-  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
-  dims.push_back(ubVal);
-  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
-  symbols.push_back(stepVal);
-  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
-  AffineExpr ee = substWithMin(expr, dimExpr, lb,
-                               lb + step * ((ub - 1) - lb).floorDiv(step));
-  LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
-  return ee;
-}
-
-/// Traverse the `dims` and substitute known min or max expressions in place of
-/// induction variables in `exprs`.
-static AffineMap substitute(
-    AffineMap map, SmallVectorImpl<Value> &dims,
-    SmallVectorImpl<Value> &symbols,
-    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
+/// Traverse the `dims` and substitute known min or max expressions returned by
+/// `getMinMaxExpr` into the results of `map`.
+static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
+                            SmallVectorImpl<Value> &symbols,
+                            GetMinMaxExprFn getMinMaxExpr) {
   auto exprs = llvm::to_vector<4>(map.getResults());
   for (AffineExpr &expr : exprs) {
     bool substituted = true;
@@ -545,27 +519,18 @@ static AffineMap substitute(
       substituted = false;
       for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
         Value dim = dims[dimIdx];
+        auto minMax = getMinMaxExpr(dim, dims, symbols);
+        if (!minMax)
+          continue;
         AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
         LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
-        AffineExpr substitutedExpr;
-        if (auto forOp = scf::getForInductionVarOwner(dim))
-          if (!substituteOperation || substituteOperation(forOp))
-            substitutedExpr = substituteLoopInExpr(
-                expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
-                forOp.step(), dims, symbols);
-
-        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
-          if (!substituteOperation || substituteOperation(parallelForOp))
-            for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
-                 ++idx)
-              substitutedExpr = substituteLoopInExpr(
-                  expr, dimExpr, parallelForOp.lowerBound()[idx],
-                  parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
-                  dims, symbols);
-
-        if (!substitutedExpr)
-          continue;
-
+        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
+        // Substitute occurrences of `dimExpr` by either the min expression or
+        // the max expression depending on whether the value is used with a
+        // positive or negative coefficient.
+        AffineExpr substitutedExpr =
+            substWithMin(expr, dimExpr, minMax->first, minMax->second);
+        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
         substituted = (substitutedExpr != expr);
         expr = substitutedExpr;
       }
@@ -603,37 +568,36 @@ static AffineMap substitute(
   return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
 }
 
-/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
-/// induction variables by new expressions involving the lower or upper bound:
-///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
-///     replaced by the loop upper bound.
-///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
-///     replaced by the loop lower bound.
-/// All loop induction variables are iteratively replaced, unless a
-/// `substituteOperation` hook is passed to more finely determine which
-/// operations are substituted.
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
+/// dimensions with known range by new expressions involving the min or max
+/// expression:
+///   - If the AffineDimExpr mapped to a known value has a positive sign, it
+///     is replaced by the min expression.
+///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
+///     replaced by the max expression.
+/// Values for which `getMinMaxExpr` returns a range are iteratively replaced.
 /// This is used as an intermediate step in computing bounding boxes and
 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
 /// positive values (positive orthant assumptions).
 /// Return a new AffineMap, dims and symbols that have been canonicalized and
 /// simplified.
-AffineMapAndOperands mlir::linalg::substituteMin(
-    AffineMinOp affineMinOp,
-    llvm::function_ref<bool(Operation *)> substituteOperation) {
+AffineMapAndOperands
+mlir::linalg::substituteMin(AffineMinOp affineMinOp,
+                            GetMinMaxExprFn getMinMaxExpr) {
   AffineMapAndOperands res{affineMinOp.getAffineMap(),
                            SmallVector<Value>(affineMinOp.getDimOperands()),
                            SmallVector<Value>(affineMinOp.getSymbolOperands())};
   res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
-                       substituteOperation);
+                       getMinMaxExpr);
   return res;
 }
 
-LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
+LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
     AffineMinOp minOp, PatternRewriter &rewriter) const {
   LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                     << "\n");
 
-  auto affineMapAndOperands = substituteMin(minOp);
+  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
   AffineMap map = affineMapAndOperands.map;
 
   LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index 00fa22be754f..3bd37d7a219d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -145,3 +145,54 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
   }
   return rootEnclosesPloops;
 }
+
+/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
+/// `ubVal` to `dims` and `stepVal` to `symbols`.
+/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
+/// with positions matching the newly appended values. Then create a min
+/// expression (i.e. `%lb`) and a max expression
+/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`).
+static std::pair<AffineExpr, AffineExpr>
+getMinMaxLoopIndVar(Value lbVal, Value ubVal, Value stepVal,
+                    SmallVectorImpl<Value> &dims,
+                    SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = lbVal.getContext();
+  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(lbVal);
+  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(ubVal);
+  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
+  symbols.push_back(stepVal);
+  return std::make_pair(lb, lb + step * ((ub - 1) - lb).floorDiv(step));
+}
+
+/// Return the min/max expressions for `value` if it is an induction variable
+/// from an scf.for or scf.parallel loop, appending the loop bounds to `dims`
+/// and the step to `symbols`. If `loopFilter` is passed, only loops for which
+/// it returns true are considered; for any other value, None is returned.
+Optional<std::pair<AffineExpr, AffineExpr>> mlir::getSCFMinMaxExpr(
+    Value value, SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &symbols,
+    llvm::function_ref<bool(Operation *)> loopFilter) {
+  // Without a filter, every enclosing scf loop is considered.
+  auto isLoopAccepted = [&](Operation *loop) {
+    return !loopFilter || loopFilter(loop);
+  };
+  if (auto forOp = scf::getForInductionVarOwner(value)) {
+    if (!isLoopAccepted(forOp))
+      return llvm::None;
+    return getMinMaxLoopIndVar(forOp.lowerBound(), forOp.upperBound(),
+                               forOp.step(), dims, symbols);
+  }
+
+  if (auto parallelForOp = scf::getParallelForInductionVarOwner(value)) {
+    if (!isLoopAccepted(parallelForOp))
+      return llvm::None;
+    // Only the bounds of the dimension owning `value` are relevant.
+    for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx)
+      if (parallelForOp.getInductionVars()[idx] == value)
+        return getMinMaxLoopIndVar(parallelForOp.lowerBound()[idx],
+                                   parallelForOp.upperBound()[idx],
+                                   parallelForOp.step()[idx], dims, symbols);
+  }
+  return {};
+}


        


More information about the Mlir-commits mailing list