[Mlir-commits] [mlir] 3110e7b - [mlir] Introduce AffineMinSCF folding as a pattern

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 7 11:33:15 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-07T14:30:38-04:00
New Revision: 3110e7b077d0031e8743614f742a500ccc522c77

URL: https://github.com/llvm/llvm-project/commit/3110e7b077d0031e8743614f742a500ccc522c77
DIFF: https://github.com/llvm/llvm-project/commit/3110e7b077d0031e8743614f742a500ccc522c77.diff

LOG: [mlir] Introduce AffineMinSCF folding as a pattern

This revision adds a folding pattern to replace affine.min ops by the actual min value, when it can be determined statically from the strides and bounds of the enclosing scf loop.

This matches the type of expressions that Linalg produces during tiling and simplifies boundary checks. For now Linalg depends both on Affine and SCF but they do not depend on each other, so the pattern is added there.
In the future this will move to a more appropriate place when it is determined.

The canonicalization of AffineMinOp operations in the context of enclosing scf.for and scf.parallel proceeds by:
  1. building an affine map where uses of the induction variable of a loop
  are replaced by `%lb + %step * floordiv(%iv - %lb, %step)` expressions.
  2. checking if any of the results of this affine map divides all the other
  results (in which case it is also guaranteed to be the min).
  3. replacing the AffineMinOp by the result of (2).

The algorithm is functional in simple parametric tiling cases by using semi-affine maps. However simplifications of such semi-affine maps are not yet available and the canonicalization does not succeed yet.

Differential Revision: https://reviews.llvm.org/D82009

Added: 
    mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/IR/AffineExpr.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 65f1bbb833ba..8fce95781c4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -502,6 +502,26 @@ struct LinalgCopyVTWForwardingPattern
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
+/// scf.parallel by:
+///   1. building an affine map where uses of the induction variable of a loop
+///   are replaced by either the min (i.e. `%lb`) or the max
+///   (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`) expression, depending
+///   on whether the induction variable is used with a positive or negative
+///   coefficient.
+///   2. checking whether any of the results of this affine map is known to be
+///   greater than all other results.
+///   3. replacing the AffineMinOp by the result of (2).
+// TODO: move to a more appropriate place when it is determined. For now Linalg
+// depends both on Affine and SCF but they do not depend on each other.
+struct AffineMinSCFCanonicalizationPattern
+    : public OpRewritePattern<AffineMinOp> {
+  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMinOp minOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
@@ -519,6 +539,7 @@ LogicalResult applyStagedPatterns(
     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
     const OwningRewritePatternList &stage2Patterns,
     function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 2df16ee2bfc9..2b9153e0d0fc 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -115,9 +115,20 @@ class AffineExpr {
 
   /// This method substitutes any uses of dimensions and symbols (e.g.
   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
+  /// This is a dense replacement method: a replacement must be specified for
+  /// every single dim and symbol.
   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                    ArrayRef<AffineExpr> symReplacements) const;
 
+  /// Sparse replace method. Replace `expr` by `replacement` and return the
+  /// modified expression tree.
+  AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
+
+  /// Sparse replace method. If `*this` appears in `map`, replace it by
+  /// `map[*this]` and return the modified expression tree. Otherwise, traverse
+  /// `*this` and apply `replace` with `map` on its subexpressions.
+  AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
+
   /// Replace symbols[0 .. numDims - 1] by
   /// symbols[shift .. shift + numDims - 1].
   AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6af4de7e9d83..afac3d5f5f9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -36,6 +36,7 @@ using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
@@ -235,3 +236,177 @@ LogicalResult mlir::linalg::applyStagedPatterns(
   }
   return success();
 }
+
+/// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
+/// been replaced by either:
+///  - `min` if `positivePath` is true when we reach an occurrence of `dim`
+///  - `max` if `positivePath` is false when we reach an occurrence of `dim`
+/// `positivePath` is negated each time we hit a multiplicative or divisive
+/// binary op with a constant negative coefficient.
+static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
+                               AffineExpr max, bool positivePath = true) {
+  if (e == dim)
+    return positivePath ? min : max;
+  if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
+    AffineExpr lhs = bin.getLHS();
+    AffineExpr rhs = bin.getRHS();
+    if (bin.getKind() == mlir::AffineExprKind::Add)
+      return substWithMin(lhs, dim, min, max, positivePath) +
+             substWithMin(rhs, dim, min, max, positivePath);
+
+    auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
+    auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
+    if (c1 && c1.getValue() < 0)
+      return getAffineBinaryOpExpr(
+          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
+    if (c2 && c2.getValue() < 0)
+      return getAffineBinaryOpExpr(
+          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
+    return getAffineBinaryOpExpr(
+        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
+        substWithMin(rhs, dim, min, max, positivePath));
+  }
+  return e;
+}
+
+/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
+/// `ubVal` to `dims` and `stepVal` to `symbols`.
+/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
+/// with positions matching the newly appended values. Substitute occurrences of
+/// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
+/// (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`), depending on whether
+/// the induction variable is used with a positive or negative coefficient.
+static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
+                                       Value lbVal, Value ubVal, Value stepVal,
+                                       SmallVectorImpl<Value> &dims,
+                                       SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = lbVal.getContext();
+  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(lbVal);
+  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(ubVal);
+  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
+  symbols.push_back(stepVal);
+  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
+  AffineExpr ee = substWithMin(expr, dimExpr, lb,
+                               lb + step * ((ub - 1) - lb).floorDiv(step));
+  LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
+  return ee;
+}
+
+/// Traverse the `dims` and substitute known min or max expressions in place of
+/// induction variables in `exprs`.
+static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
+                            SmallVectorImpl<Value> &symbols) {
+  auto exprs = llvm::to_vector<4>(map.getResults());
+  for (AffineExpr &expr : exprs) {
+    bool substituted = true;
+    while (substituted) {
+      substituted = false;
+      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
+        Value dim = dims[dimIdx];
+        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
+        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
+        AffineExpr substitutedExpr;
+        if (auto forOp = scf::getForInductionVarOwner(dim))
+          substitutedExpr = substituteLoopInExpr(
+              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
+              forOp.step(), dims, symbols);
+
+        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
+          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
+               ++idx)
+            substitutedExpr = substituteLoopInExpr(
+                expr, dimExpr, parallelForOp.lowerBound()[idx],
+                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
+                dims, symbols);
+
+        if (!substitutedExpr)
+          continue;
+
+        substituted = (substitutedExpr != expr);
+        expr = substitutedExpr;
+      }
+    }
+
+    // Cleanup and simplify the results.
+    // This needs to happen outside of the loop iterating on dims.size() since
+    // it modifies dims.
+    SmallVector<Value, 4> operands(dims.begin(), dims.end());
+    operands.append(symbols.begin(), symbols.end());
+    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
+                              exprs.front().getContext());
+
+    LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
+
+    // Pull in affine.apply operations and compose them fully into the
+    // result.
+    fullyComposeAffineMapAndOperands(&map, &operands);
+    canonicalizeMapAndOperands(&map, &operands);
+    map = simplifyAffineMap(map);
+    // Assign the results.
+    exprs.assign(map.getResults().begin(), map.getResults().end());
+    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
+    symbols.assign(operands.begin() + map.getNumDims(), operands.end());
+
+    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
+  }
+
+  assert(!exprs.empty() && "Unexpected empty exprs");
+  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
+}
+
+LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
+    AffineMinOp minOp, PatternRewriter &rewriter) const {
+  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
+                    << "\n");
+
+  SmallVector<Value, 4> dims(minOp.getDimOperands()),
+      symbols(minOp.getSymbolOperands());
+  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
+
+  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
+
+  // Check whether any of the expressions, when subtracted from all other
+  // expressions, produces only >= 0 constants. If so, it is the min.
+  for (auto e : minOp.getAffineMap().getResults()) {
+    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
+    if (!e.isSymbolicOrConstant())
+      continue;
+
+    auto isNonPositive = [](AffineExpr e) {
+      if (auto cst = e.dyn_cast<AffineConstantExpr>())
+        return cst.getValue() < 0;
+      return true;
+    };
+
+    // Build the subMap and check everything is statically known to be
+    // positive.
+    SmallVector<AffineExpr, 4> subExprs;
+    subExprs.reserve(map.getNumResults());
+    for (auto ee : map.getResults())
+      subExprs.push_back(ee - e);
+    MLIRContext *ctx = minOp.getContext();
+    AffineMap subMap = simplifyAffineMap(
+        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
+    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
+    if (llvm::any_of(subMap.getResults(), isNonPositive))
+      continue;
+
+    // Static min found.
+    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
+    } else {
+      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
+      SmallVector<Value, 4> resultOperands = dims;
+      resultOperands.append(symbols.begin(), symbols.end());
+      canonicalizeMapAndOperands(&resultMap, &resultOperands);
+      resultMap = simplifyAffineMap(resultMap);
+      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
+                                                 resultOperands);
+    }
+    return success();
+  }
+
+  return failure();
+}

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0d4d9d08c935..c78e7e1eac57 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -101,6 +101,37 @@ AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
   return replaceDimsAndSymbols({}, symbols);
 }
 
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr
+AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
+  auto it = map.find(*this);
+  if (it != map.end())
+    return it->second;
+  switch (getKind()) {
+  default:
+    return *this;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod:
+    auto binOp = cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    auto newLHS = lhs.replace(map);
+    auto newRHS = rhs.replace(map);
+    if (newLHS == lhs && newRHS == rhs)
+      return *this;
+    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
+  DenseMap<AffineExpr, AffineExpr> map;
+  map.insert(std::make_pair(expr, replacement));
+  return replace(map);
+}
 /// Returns true if this expression is made out of only symbols and
 /// constants (no dimensional identifiers).
 bool AffineExpr::isSymbolicOrConstant() const {

diff  --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
new file mode 100644
index 000000000000..84c56ee3d840
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
+
+// CHECK-LABEL: scf_for
+func @scf_for(%A : memref<i64>, %step : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c4 = constant 4 : index
+  %c16 = constant 16 : index
+  %c1024 = constant 1024 : index
+
+  //      CHECK: scf.for
+  // CHECK-NEXT:   %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT:   %[[C2I64:.*]] = index_cast %[[C2:.*]]
+  // CHECK-NEXT:   store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.for %i = %c0 to %c4 step %c2 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  //      CHECK: scf.for
+  // CHECK-NEXT:   %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT:   %[[C2I64:.*]] = index_cast %[[C2:.*]]
+  // CHECK-NEXT:   store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.for %i = %c1 to %c7 step %c2 {
+    %1 = affine.min affine_map<(d0)[s0] -> (s0 - d0, 2)> (%i)[%c7]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This should not canonicalize because: 4 - %i may take the value 1 < 2.
+  //     CHECK:   scf.for
+  //     CHECK:     affine.min
+  //     CHECK:     index_cast
+  scf.for %i = %c1 to %c4 step %c2 {
+    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c4]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This should not canonicalize because: 16 - %i may take the value 15 < 1024.
+  //     CHECK:   scf.for
+  //     CHECK:     affine.min
+  //     CHECK:     index_cast
+  scf.for %i = %c1 to %c16 step %c1024 {
+    %1 = affine.min affine_map<(d0) -> (1024, 16 - d0)> (%i)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0`
+  // should evaluate to 41 * s0.
+  // Note that this may require positivity assumptions on `s0`.
+  // Revisit when support is added.
+  // CHECK: scf.for
+  // CHECK:   affine.min
+  // CHECK:   index_cast
+  %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
+  scf.for %i = %c0 to %ub step %step {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)> (%step, %ub, %i)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations: ` -(((s0 * s0 - 1) floordiv s0) * s0)`
+  // should evaluate to (s0 - 1) * s0.
+  // Note that this may require positivity assumptions on `s0`.
+  // Revisit when support is added.
+  // CHECK: scf.for
+  // CHECK:   affine.min
+  // CHECK:   index_cast
+  %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
+  scf.for %i = %c0 to %ub2 step %step {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  return
+}
+
+// CHECK-LABEL: scf_parallel
+func @scf_parallel(%A : memref<i64>, %step : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c4 = constant 4 : index
+
+  // CHECK: scf.parallel
+  // CHECK-NEXT:   %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT:   %[[C2I64:.*]] = index_cast %[[C2:.*]]
+  // CHECK-NEXT:   store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.parallel (%i) = (%c0) to (%c4) step (%c2) {
+    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // CHECK: scf.parallel
+  // CHECK-NEXT:   %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT:   %[[C2I64:.*]] = index_cast %[[C2:.*]]
+  // CHECK-NEXT:   store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.parallel (%i) = (%c1) to (%c7) step (%c2) {
+    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations.
+  // This affine map does not currently evaluate to (0, 0):
+  //   (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * 42) mod s0)
+  // TODO: Revisit when support is added.
+  // CHECK: scf.parallel
+  // CHECK:   affine.min
+  // CHECK:   index_cast
+  %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
+  scf.parallel (%i) = (%c0) to (%ub) step (%step) {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations.
+  // This affine map does not currently evaluate to (0, 0):
+  //   (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * s0) mod s0)
+  // TODO: Revisit when support is added.
+  // CHECK: scf.parallel
+  // CHECK:   affine.min
+  // CHECK:   index_cast
+  %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
+  scf.parallel (%i) = (%c0) to (%ub2) step (%step) {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index e356eb72fa42..ff37110f093a 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -59,6 +59,10 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
+  Option<bool> testAffineMinSCFCanonicalizationPatterns{
+      *this, "test-affine-min-scf-canonicalization-patterns",
+      llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -316,6 +320,15 @@ static void applyContractionToVectorPatterns(FuncOp funcOp) {
   applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
+  OwningRewritePatternList foldPattern;
+  foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
+  // Explicitly walk and apply the pattern locally to avoid more general folding
+  // on the rest of the IR.
+  funcOp.walk([&foldPattern](AffineMinOp minOp) {
+    applyOpPatternsAndFold(minOp, foldPattern);
+  });
+}
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -341,6 +354,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
     return applyContractionToVectorPatterns(getFunction());
+  if (testAffineMinSCFCanonicalizationPatterns)
+    return applyAffineMinSCFCanonicalizationPatterns(getFunction());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list