[Mlir-commits] [mlir] [mlir][affine] Wrap SimplifyAffineMinMax in a pass (PR #145741)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 26 06:20:31 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145741

>From 8d710669c4809e1b7f71c50a9e1ea9f954f4727b Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 25 Jun 2025 18:38:27 +0200
Subject: [PATCH] [mlir][affine] Wrap SimplifyAffineMinMax in a pass

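This revision adds a pass operating on FunctionOpInterface ops that applies the recently introduced AffineMin/Max simplification patterns (together with AffineApply composition and the AffineMin/Max canonicalization patterns) until a fixed point is reached.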
---
 mlir/include/mlir/Dialect/Affine/Passes.td    | 17 ++++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 11 +++
 .../Transforms/SimplifyAffineMinMax.cpp       | 94 ++++++++++++++++++-
 .../Dialect/Affine/simplify-min-max-ops.mlir  | 60 ++++++++++++
 .../transform-op-simplify-min-max-ops.mlir    |  2 +-
 5 files changed, 181 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/Affine/simplify-min-max-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 67f9138589c47..bc7ffe23e7f52 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -414,6 +414,23 @@ def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"
   let constructor = "mlir::affine::createSimplifyAffineStructuresPass()";
 }

+def SimplifyAffineMinMax : InterfacePass<"affine-simplify-min-max", "FunctionOpInterface"> {
+  let summary = "Simplify affine min/max/apply";
+  let description = [{
+    Apply the SimplifyAffineMaxOp, SimplifyAffineMinOp and SimplifyAffineApplyOp
+    patterns in addition to AffineMin/Max canonicalization patterns until a
+    fixed point is reached.
+    These patterns use the ValueBoundsOpInterface on AffineMin/Max ops to drop
+    redundant results, enabling additional simplifications such as:
+    ```
+       min(x, y, cst) / cst -> 1
+    ```
+    when x and y are both known to be >= cst, with cst > 0.
+    This is typically useful to extract more static information from IR after
+    tiling, but can also be costly due to the Presburger-style analysis.
+  }];
+}
+
 def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
   let summary = "Lower affine operations operating on indices into more fundamental operations";
   let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f577883085608..bd4b2e56808b6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -42,6 +42,7 @@ using llvm::divideFloorSigned;
 using llvm::mod;

 #define DEBUG_TYPE "affine-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"

@@ -1065,6 +1066,10 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
                                                            ValueRange syms) {
   AffineMap affineMinMap = minOp.getAffineMap();

+  LLVM_DEBUG({
+    DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
+  });
+
   // Check the value is positive.
   for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
     // Compare each expression in the minimum against 0.
@@ -1263,6 +1268,12 @@ void mlir::affine::fullyComposeAffineMapAndOperands(
   })) {
     composeAffineMapAndOperands(map, operands, composeAffineMin);
   }
+  // Final step: compose AffineMinOp operands, even without AffineApply chains.
+  if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
+        return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
+      })) {
+    composeAffineMapAndOperands(map, operands, composeAffineMin);
+  }
 }

 AffineApplyOp
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index c992badcfa493..21443a55a35ad 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -10,13 +10,18 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/Passes.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/IntEqClasses.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/InterleavedRange.h"
 
 #define DEBUG_TYPE "affine-min-max"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -44,6 +49,12 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
       [&](unsigned i) {
         return Variable(affineMap.getSliceMap(i, 1), operands);
       });
+  LLVM_DEBUG({
+    DBGS() << "- constructed variables are: "
+           << llvm::interleaved_array(llvm::map_range(
+                  variables, [](const Variable &v) { return v.getMap(); }))
+           << "\n";
+  });

   // Get the comparison operation.
   ComparisonOperator cmpOp =
@@ -125,8 +136,17 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
   for (auto [k, bound] : bounds)
     results.push_back(bound->getMap().getResult(0));

-  affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
-                             results, rewriter.getContext());
+  // NOTE(review): the bound maps collected above are assumed to be expressed
+  // over symbols only (original dims first, then original symbols); the new
+  // map therefore has no dims and the unchanged operand list still lines up.
+  unsigned numSymbols = affineMap.getNumDims() + affineMap.getNumSymbols();
+  LLVM_DEBUG({
+    DBGS() << "- starting from map: " << affineMap << "\n";
+    DBGS() << "- creating new map with 0 dims, " << numSymbols
+           << " syms: " << llvm::interleaved_array(results) << "\n";
+  });
+  affineMap = AffineMap::get(/*dimCount=*/0, numSymbols, results,
+                             rewriter.getContext());

   // Update the affine op.
   rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
@@ -172,3 +192,73 @@ LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
     *modified = changed;
   return success();
 }
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAX
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+namespace {
+
+struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
+  using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMaxOp affineOp,
+                                PatternRewriter &rewriter) const override {
+    return success(simplifyAffineMaxOp(rewriter, affineOp));
+  }
+};
+
+struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
+  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMinOp affineOp,
+                                PatternRewriter &rewriter) const override {
+    return success(simplifyAffineMinOp(rewriter, affineOp));
+  }
+};
+
+/// Fully composes an AffineApplyOp's map/operands (incl. affine.min ones).
+struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
+  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineApplyOp affineOp,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = affineOp.getAffineMap();
+    SmallVector<Value> operands = llvm::to_vector(affineOp->getOperands());
+    fullyComposeAffineMapAndOperands(&map, &operands,
+                                     /*composeAffineMin=*/true);
+
+    // Neither the map nor the operands changed => failure to apply.
+    if (map == affineOp.getAffineMap() &&
+        llvm::equal(operands, affineOp->getOperands()))
+      return failure();
+
+    rewriter.modifyOpInPlace(affineOp, [&]() {
+      affineOp.setMap(map);
+      affineOp->setOperands(operands);
+    });
+    return success();
+  }
+};
+
+/// Pass applying the affine min/max/apply simplification patterns greedily.
+struct SimplifyAffineMinMaxPass
+    : public affine::impl::SimplifyAffineMinMaxBase<SimplifyAffineMinMaxPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void SimplifyAffineMinMaxPass::runOnOperation() {
+  FunctionOpInterface func = getOperation();
+  RewritePatternSet patterns(func.getContext());
+  AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
+  AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
+  patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
+      func.getContext());
+  if (failed(applyPatternsGreedily(func, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/Affine/simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..1d0cb14618d16
--- /dev/null
+++ b/mlir/test/Dialect/Affine/simplify-min-max-ops.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt  %s  --affine-simplify-min-max --split-input-file | FileCheck %s
+// Operands have known [min, max] bounds; check which min/max entries remain.
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK-NOT: affine.min
+  // CHECK-NOT: affine.max
+  // CHECK: return %[[V0]], %[[V1]]
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+  // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @overlapping_constraints() -> (index, index) {
+  %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
+  %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
+  // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
+  %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+  %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+  return %r0, %r1 : index, index
+}
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
index 948f434f3fa5e..aa9bdf2b34eb9 100644
--- a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt  %s  --transform-interpreter | FileCheck %s
+// RUN: mlir-opt  %s  --transform-interpreter --split-input-file | FileCheck %s
 
 // CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
 // CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>



More information about the Mlir-commits mailing list