[Mlir-commits] [mlir] [mlir][affine] Wrap SimplifyAffineMinMax in a pass (PR #145741)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 26 05:05:40 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145741
>From 80bb38cab64a8770759528e452b2c0eff2966722 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 25 Jun 2025 18:38:27 +0200
Subject: [PATCH] [mlir][affine] Wrap SimplifyAffineMinMax in a pass
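This revision adds a pass, operating on FunctionOpInterface ops, that applies the recently introduced AffineMin/Max simplification patterns together with the AffineMin/Max canonicalization patterns and a new AffineApply composition pattern.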
---
mlir/include/mlir/Dialect/Affine/Passes.h | 4 +
mlir/include/mlir/Dialect/Affine/Passes.td | 17 ++++
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 11 +++
.../Transforms/SimplifyAffineMinMax.cpp | 94 ++++++++++++++++++-
.../transform-op-simplify-min-max-ops.mlir | 71 +++++++++++++-
5 files changed, 194 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 2f70f24dd3ef2..907a61170ae9f 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -43,6 +43,10 @@ enum FusionMode { Greedy, ProducerConsumer, Sibling };
std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyAffineStructuresPass();
+/// Creates the `affine-simplify-min-max` pass, which greedily simplifies
+/// affine.min, affine.max and affine.apply ops nested in a FunctionOpInterface
+/// op until a fixed point is reached.
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createSimplifyAffineMinMaxPass();
+
/// Creates a loop invariant code motion pass that hoists loop invariant
/// operations out of affine loops.
std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 67f9138589c47..bc7ffe23e7f52 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -414,6 +414,23 @@ def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"
let constructor = "mlir::affine::createSimplifyAffineStructuresPass()";
}
+def SimplifyAffineMinMax : InterfacePass<"affine-simplify-min-max", "FunctionOpInterface"> {
+  let summary = "Simplify affine min/max/apply";
+  let description = [{
+    Apply the SimplifyAffineMaxOp, SimplifyAffineMinOp and SimplifyAffineApplyOp
+    patterns in addition to AffineMin/Max canonicalization patterns until a
+    fixed point is reached.
+    These patterns use the ValueBoundsOpInterface to simplify AffineMin/Max ops
+    and perform additional simplifications such as:
+    ```
+    ceildiv(min(x, y, cst), cst) -> 1
+    ```
+    when x, y and cst are all strictly positive (so that
+    0 < min(x, y, cst) <= cst).
+    This is typically useful to extract more static information from IR after
+    tiling but can also come at a cost due to Presburger-style analysis.
+  }];
+}
+
def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
let summary = "Lower affine operations operating on indices into more fundamental operations";
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f577883085608..bd4b2e56808b6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -42,6 +42,7 @@ using llvm::divideFloorSigned;
using llvm::mod;
#define DEBUG_TYPE "affine-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1065,6 +1066,10 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
ValueRange syms) {
AffineMap affineMinMap = minOp.getAffineMap();
+ LLVM_DEBUG({
+ DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
+ });
+
// Check the value is positive.
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
// Compare each expression in the minimum against 0.
@@ -1263,6 +1268,12 @@ void mlir::affine::fullyComposeAffineMapAndOperands(
})) {
composeAffineMapAndOperands(map, operands, composeAffineMin);
}
+ // Additional trailing step for AffineMinOps in case no chains of AffineApply.
+ if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
+ return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
+ })) {
+ composeAffineMapAndOperands(map, operands, composeAffineMin);
+ }
}
AffineApplyOp
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index c992badcfa493..21443a55a35ad 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -10,13 +10,18 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Passes.h"
+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -44,6 +49,12 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
+ LLVM_DEBUG({
+ DBGS() << "- constructed variables are: "
+ << llvm::interleaved_array(llvm::map_range(
+ variables, [](const Variable &v) { return v.getMap(); }))
+ << "`\n";
+ });
// Get the comparison operation.
ComparisonOperator cmpOp =
@@ -125,8 +136,17 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
- affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
- results, rewriter.getContext());
+ LLVM_DEBUG({
+ DBGS() << "- starting from map: " << affineMap << "\n";
+ DBGS() << "- creating new map with: \n";
+ DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
+ DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
+ DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
+ });
+
+ affineMap =
+ AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
+ results, rewriter.getContext());
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
@@ -172,3 +192,73 @@ LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
*modified = changed;
return success();
}
+
+namespace {
+
+/// Forwards every AffineMaxOp to `simplifyAffineMaxOp`; the match succeeds
+/// exactly when that helper reports that it rewrote the op.
+struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMaxOp op,
+                                PatternRewriter &rewriter) const override {
+    const bool changed = simplifyAffineMaxOp(rewriter, op);
+    return changed ? success() : failure();
+  }
+};
+
+/// Forwards every AffineMinOp to `simplifyAffineMinOp`; the match succeeds
+/// exactly when that helper reports that it rewrote the op.
+struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMinOp op,
+                                PatternRewriter &rewriter) const override {
+    const bool changed = simplifyAffineMinOp(rewriter, op);
+    return changed ? success() : failure();
+  }
+};
+
+/// Fully composes an AffineApplyOp's map with its producers (including
+/// AffineMinOp producers) and updates the op in place when the map changes.
+struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineApplyOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap originalMap = op.getAffineMap();
+    AffineMap composedMap = originalMap;
+    SmallVector<Value> composedOperands(op->operand_begin(),
+                                        op->operand_end());
+    fullyComposeAffineMapAndOperands(&composedMap, &composedOperands,
+                                     /*composeAffineMin=*/true);
+
+    // An unchanged map means nothing was composed: report failure so the
+    // greedy driver does not keep revisiting this op.
+    if (composedMap == originalMap)
+      return failure();
+
+    rewriter.modifyOpInPlace(op, [&]() {
+      op.setMap(composedMap);
+      op->setOperands(composedOperands);
+    });
+    return success();
+  }
+};
+
+} // namespace
+
+// Pull in the tablegen-generated base class for the pass
+// (affine::impl::SimplifyAffineMinMaxBase).
+namespace mlir::affine {
+#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAX
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace mlir::affine
+
+namespace {
+/// Pass that greedily applies the affine min/max/apply simplification
+/// patterns to a FunctionOpInterface op (see runOnOperation). Kept in an
+/// anonymous namespace so the class does not have external linkage.
+struct SimplifyAffineMinMaxPass
+    : public affine::impl::SimplifyAffineMinMaxBase<SimplifyAffineMinMaxPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Collects the AffineMin/Max canonicalization patterns plus the
+/// SimplifyAffine{Max,Min,Apply}Op patterns and applies them greedily to the
+/// function. A failure result from the greedy driver is reported as a pass
+/// failure.
+void SimplifyAffineMinMaxPass::runOnOperation() {
+  FunctionOpInterface func = getOperation();
+  MLIRContext *ctx = func.getContext();
+  RewritePatternSet patterns(ctx);
+  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
+  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
+  patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
+      ctx);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  if (failed(applyPatternsGreedily(func, frozenPatterns)))
+    return signalPassFailure();
+}
+
+/// Definition of the factory declared in mlir/Dialect/Affine/Passes.h; without
+/// it any caller of createSimplifyAffineMinMaxPass() fails to link.
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+mlir::affine::createSimplifyAffineMinMaxPass() {
+  return std::make_unique<SimplifyAffineMinMaxPass>();
+}
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
index 948f434f3fa5e..d72be27218a72 100644
--- a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
@@ -66,3 +66,72 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// %0 is in [0, 128] and %1 is in [256, 512], so %0 <= 192 <= %1 always holds:
+// the affine.min folds to %0 and the affine.max folds to %1.
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK-NOT: affine.min
+ // CHECK-NOT: affine.max
+ // CHECK: return %[[V0]], %[[V1]]
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// %0 is in [0, 512] and %1 is in [256, 512]. Since 32 < 256 <= %1, the min
+// drops %1 and the max drops 32; the remaining pairs overlap and are kept.
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// %0 is in [0, 128] and %1 is in [0, 512]. Since %0 <= 128 < 256, the min
+// drops 256 and the max drops %0; the remaining pairs overlap and are kept.
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// %0 is in [0, 192], %1 in [128, 384] and %2 in [256, 512]. %0 <= 192 < 256 <= %2,
+// so the min drops %2 and the max drops %0; %0/%1 and %1/%2 overlap and stay.
+// CHECK: @overlapping_constraints
+func.func @overlapping_constraints() -> (index, index) {
+ %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
+ %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ return %r0, %r1 : index, index
+}
+
+// Runs the pass by the name it is registered under ("affine-simplify-min-max");
+// an unregistered name makes apply_registered_pass fail.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %ff = transform.apply_registered_pass "affine-simplify-min-max" to %func : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list