[Mlir-commits] [mlir] 5b2d850 - [mlir][Linalg] NFC - Expose helper function `substituteMin`.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Mar 19 09:31:41 PDT 2021


Author: Nicolas Vasilache
Date: 2021-03-19T16:26:52Z
New Revision: 5b2d8503d1d4b925e30fd2b91f97bfd625f03157

URL: https://github.com/llvm/llvm-project/commit/5b2d8503d1d4b925e30fd2b91f97bfd625f03157
DIFF: https://github.com/llvm/llvm-project/commit/5b2d8503d1d4b925e30fd2b91f97bfd625f03157.diff

LOG: [mlir][Linalg] NFC - Expose helper function `substituteMin`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 54a4aec9f867..6d428384080b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -893,6 +893,30 @@ struct AffineMinSCFCanonicalizationPattern
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Helper struct to return the results of `substituteMin`.
+struct AffineMapAndOperands {
+  AffineMap map;
+  SmallVector<Value> dims;
+  SmallVector<Value> symbols;
+};
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables by new expressions involving the lower or upper bound:
+///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+///     replaced by the loop upper bound.
+///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+///     replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed to more finely determine which
+/// operations are substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to have
+/// positive values (positive orthant assumptions).
+/// Return a new AffineMap, dims and symbols that have been canonicalized and
+/// simplified.
+AffineMapAndOperands substituteMin(
+  AffineMinOp affineMinOp,
+  llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+
 /// Converts Convolution op into vector contraction.
 ///
 /// Conversion expects ConvOp to have dimensions marked in the *mask* as

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c2e52c63eabd..fef6dd8f996f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -536,8 +536,10 @@ static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
 
 /// Traverse the `dims` and substitute known min or max expressions in place of
 /// induction variables in `exprs`.
-static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
-                            SmallVectorImpl<Value> &symbols) {
+static AffineMap substitute(
+    AffineMap map, SmallVectorImpl<Value> &dims,
+    SmallVectorImpl<Value> &symbols,
+    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
   auto exprs = llvm::to_vector<4>(map.getResults());
   for (AffineExpr &expr : exprs) {
     bool substituted = true;
@@ -549,17 +551,19 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
         LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
         AffineExpr substitutedExpr;
         if (auto forOp = scf::getForInductionVarOwner(dim))
-          substitutedExpr = substituteLoopInExpr(
-              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
-              forOp.step(), dims, symbols);
+          if (!substituteOperation || substituteOperation(forOp))
+            substitutedExpr = substituteLoopInExpr(
+                expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
+                forOp.step(), dims, symbols);
 
         if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
-          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
-               ++idx)
-            substitutedExpr = substituteLoopInExpr(
-                expr, dimExpr, parallelForOp.lowerBound()[idx],
-                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
-                dims, symbols);
+          if (!substituteOperation || substituteOperation(parallelForOp))
+            for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
+                 ++idx)
+              substitutedExpr = substituteLoopInExpr(
+                  expr, dimExpr, parallelForOp.lowerBound()[idx],
+                  parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
+                  dims, symbols);
 
         if (!substitutedExpr)
           continue;
@@ -578,6 +582,9 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                               exprs.front().getContext());
 
     LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
+    LLVM_DEBUG(DBGS() << "Operands:\n");
+    for (Value v : operands)
+      LLVM_DEBUG(DBGS() << v << "\n");
 
     // Pull in affine.apply operations and compose them fully into the
     // result.
@@ -596,14 +603,38 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
   return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
 }
 
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables by new expressions involving the lower or upper bound:
+///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+///     replaced by the loop upper bound.
+///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+///     replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed to more finely determine which
+/// operations are substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to have
+/// positive values (positive orthant assumptions).
+/// Return a new AffineMap, dims and symbols that have been canonicalized and
+/// simplified.
+AffineMapAndOperands mlir::linalg::substituteMin(
+    AffineMinOp affineMinOp,
+    llvm::function_ref<bool(Operation *)> substituteOperation) {
+  AffineMapAndOperands res{affineMinOp.getAffineMap(),
+                           SmallVector<Value>(affineMinOp.getDimOperands()),
+                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
+  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
+                       substituteOperation);
+  return res;
+}
+
 LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
     AffineMinOp minOp, PatternRewriter &rewriter) const {
   LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                     << "\n");
 
-  SmallVector<Value, 4> dims(minOp.getDimOperands()),
-      symbols(minOp.getSymbolOperands());
-  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
+  auto affineMapAndOperands = substituteMin(minOp);
+  AffineMap map = affineMapAndOperands.map;
 
   LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
 
@@ -638,8 +669,8 @@ LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
       rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
     } else {
       auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
-      SmallVector<Value, 4> resultOperands = dims;
-      resultOperands.append(symbols.begin(), symbols.end());
+      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
+      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
       canonicalizeMapAndOperands(&resultMap, &resultOperands);
       resultMap = simplifyAffineMap(resultMap);
       rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,


        


More information about the Mlir-commits mailing list