[Mlir-commits] [mlir] [mlir][affine] Fix min simplification in makeComposedAffineApply (PR #145376)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 24 03:02:09 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145376

>From 2bfda64889211fe6edfe02e1ddef8c4ddf26d095 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 23 Jun 2025 17:58:31 +0000
Subject: [PATCH 1/2] [mlir][affine] Fix min simplification in
 makeComposedAffineApply

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 116 ++++++++++++++---------
 1 file changed, 71 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3b4d51d914d86..f577883085608 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1046,59 +1046,85 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
                        map.getContext());
 }
 
-/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`.
-/// Assuming that the quantity is of the form:
-///   `affine_min(f(x, y), symbolic_cst)`.
-/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and
-/// proceeds with replacing the patterns:
+/// Assuming `dimOrSym` is a quantity in the apply op map `map` defined by
+/// `minOp = affine_min(x_1, ..., x_n)`, this function checks that every `x_k`
+/// is positive (so `0 < dimOrSym`) and then replaces the patterns:
 /// ```
-///   dimOrSym.ceildiv(symbolic_cst)
-///   (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst)
+///   dimOrSym.ceildiv(x_k)
+///   (dimOrSym + x_k - 1).floordiv(x_k)
 /// ```
-/// by `1`.
+/// by `1` for all `k` in `1, ..., n`. This holds since `0 < dimOrSym <= x_k`.
 ///
-/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
-/// inject static information that may not be statically discoverable.
-///
-/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
-/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
-static LogicalResult replaceAffineMinBoundingBoxExpression(
-    AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
-    bool affineMinKnownToBeNonNegative = false) {
-  auto affineMinMap = minOp.getAffineMap();
-  if (!affineMinKnownToBeNonNegative) {
-    ValueRange values = minOp->getOperands();
-    for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
-      AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
-      FailureOr<int64_t> lowerBound =
-          ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::LB, {row, values},
-              /*stopCondition=*/nullptr,
-              /*closedUB=*/true);
-      if (failed(lowerBound) || lowerBound.value() <= 0)
-        return failure();
+/// `dims` and `syms` are expected to be the operands bound to the dimensions
+/// and symbols of `map`. A result of `minOp` is used only if every dimension
+/// (resp. symbol) it references is bound to a value in `dims` (resp. `syms`);
+/// other results are skipped.
+///
+/// Positivity is proven with ValueBoundsConstraintSet::compare. Returns
+/// failure if it cannot be proven or if `map` is left unchanged.
+static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
+                                                           AffineExpr dimOrSym,
+                                                           AffineMap *map,
+                                                           ValueRange dims,
+                                                           ValueRange syms) {
+  AffineMap affineMinMap = minOp.getAffineMap();
+
+  // Check that every result of the min, and hence the min itself, is positive.
+  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
+    // Compare each expression in the minimum against 0.
+    if (!ValueBoundsConstraintSet::compare(
+            getAsIndexOpFoldResult(minOp.getContext(), 0),
+            ValueBoundsConstraintSet::ComparisonOperator::LT,
+            ValueBoundsConstraintSet::Variable(affineMinMap.getSliceMap(i, 1),
+                                               minOp.getOperands())))
+      return failure();
+  }
+
+  // Convert affine symbols and dimensions in minOp to symbols or dimensions in
+  // the apply op affine map.
+  DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
+  SmallVector<unsigned> unmappedDims, unmappedSyms;
+  for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
+    auto it = llvm::find(dims, dim);
+    if (it == dims.end()) {
+      unmappedDims.push_back(i);
+      continue;
     }
+    dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
+        getAffineDimExpr(it.getIndex(), minOp.getContext());
+  }
+  for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
+    auto it = llvm::find(syms, sym);
+    if (it == syms.end()) {
+      unmappedSyms.push_back(i);
+      continue;
+    }
+    dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
+        getAffineSymbolExpr(it.getIndex(), minOp.getContext());
   }
 
-  AffineMap initialMap = *map;
-  for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
-    auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
-    AffineExpr expr = m.getResult(0);
-    if (!expr.isSymbolicOrConstant())
+  // Create the replacement map.
+  DenseMap<AffineExpr, AffineExpr> repl;
+  AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
+  for (AffineExpr expr : affineMinMap.getResults()) {
+    // If we cannot express the result in terms of the apply map dimensions and
+    // symbols then continue.
+    if (llvm::any_of(unmappedDims,
+                     [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
+        llvm::any_of(unmappedSyms,
+                     [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
       continue;
 
-    DenseMap<AffineExpr, AffineExpr> repl;
+    AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
+
     // dimOrSym.ceilDiv(expr) -> 1
-    repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext());
+    repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
     // (dimOrSym + expr - 1).floorDiv(expr) -> 1
-    repl[(dimOrSym + expr - 1).floorDiv(expr)] =
-        getAffineConstantExpr(1, minOp.getContext());
-    auto newMap = map->replace(repl);
-    if (newMap == *map)
-      continue;
-    *map = newMap;
+    repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
   }
-
+  AffineMap initialMap = *map;
+  *map = initialMap.replace(repl, initialMap.getNumDims(),
+                            initialMap.getNumSymbols());
   return success(*map != initialMap);
 }

@@ -1127,11 +1153,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
   if (!v)
     return failure();
 
-  auto minOp = v.getDefiningOp<AffineMinOp>();
-  if (minOp && replaceAffineMin) {
+  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
     AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                            : getAffineSymbolExpr(pos, ctx);
-    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
+    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
+                                                 syms);
   }
 
   auto affineApply = v.getDefiningOp<AffineApplyOp>();

>From fdd870d2cfabea1018d6d414b384e44c9df4a30d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Tue, 24 Jun 2025 12:01:54 +0200
Subject: [PATCH 2/2] Add negative test that should not fold

---
 .../make-composed-folded-affine-apply.mlir    | 36 +++++++++++++++++++
 .../test/lib/Transforms/TestTransformsOps.cpp | 30 ++++++++++++++++
 mlir/test/lib/Transforms/TestTransformsOps.td | 33 +++++++++++++++++
 3 files changed, 99 insertions(+)
 create mode 100644 mlir/test/Transforms/make-composed-folded-affine-apply.mlir

diff --git a/mlir/test/Transforms/make-composed-folded-affine-apply.mlir b/mlir/test/Transforms/make-composed-folded-affine-apply.mlir
new file mode 100644
index 0000000000000..ce93a1550949b
--- /dev/null
+++ b/mlir/test/Transforms/make-composed-folded-affine-apply.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<()[s0, s1] -> (s1)>
+#map1 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map2 = affine_map<() -> (1)>
+module {
+  func.func @min_max_full_simplify() -> (index, index) {
+    %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+    %1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+    %2 = affine.min #map()[%0, %1]
+    // Negative case: %2 = min(%1) and %0 are independent values in [32, 64],
+    // so %2 ceildiv %0 can be 2 (e.g. 64 ceildiv 32) and must NOT be folded to
+    // 1 when the affine.min is composed.
+    %3 = affine.apply #map1()[%2, %0]
+    // Positive case: a constant map folds to `arith.constant 1`.
+    %4 = affine.apply #map2()
+    return %3, %4 : index, index
+  }
+}
+
+// The first folded result must still be an affine.apply, not a constant.
+// CHECK-LABEL: IR printer: folded:
+// CHECK-NOT:     arith.constant
+// CHECK:         affine.apply
+// CHECK:         arith.constant 1 : index
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["affine.apply"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %folded = transform.test.make_composed_folded_affine_apply %op
+        : (!transform.any_op) -> !transform.any_op
+    transform.print %folded {name = "folded: " } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index c05b32bed9b94..6e851b9a1aac6 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -11,9 +11,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/RegionUtils.h"
 
 #define GET_OP_CLASSES
@@ -56,6 +60,32 @@ transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// Test affine functionality.
+//===----------------------------------------------------------------------===//
+/// Composes and folds the targeted affine.apply op with
+/// affine::makeComposedFoldedAffineApply, enabling `composeAffineMin`. The
+/// handle result is the op defining the returned value or, when the result
+/// folds to a constant, a newly created arith.constant. The targeted op is
+/// left in place.
+DiagnosedSilenceableFailure
+transform::TestMakeComposedFoldedAffineApply::applyToOne(
+    TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
+    ApplyToEachResultList &results, TransformState &state) {
+  Location loc = affineApplyOp.getLoc();
+  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, affineApplyOp.getAffineMap(),
+      getAsOpFoldResult(affineApplyOp.getOperands()),
+      /*composeAffineMin=*/true);
+  if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+    results.push_back(v.getDefiningOp());
+  } else {
+    results.push_back(rewriter.create<arith::ConstantIndexOp>(
+        loc, getConstantIntValue(ofr).value()));
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 
 class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index 495579b452dfc..efb85f7e2906e 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -59,5 +59,38 @@ def TestMoveValueDefns :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test affine functionality.
+//===----------------------------------------------------------------------===//
+def TestMakeComposedFoldedAffineApply :
+    Op<Transform_Dialect, "test.make_composed_folded_affine_apply",
+        [FunctionalStyleTransformOpTrait,
+         MemoryEffectsOpInterface,
+         TransformOpInterface,
+         TransformEachOpTrait,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Calls `makeComposedFoldedAffineApply` (with `composeAffineMin` enabled) on
+    each targeted `affine.apply` op. The returned handle points to the op
+    defining the resulting value, or to a newly created `arith.constant` when
+    the result folds to a constant. The targeted op is not replaced.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$op);
+
+  let results = (outs TransformHandleTypeInterface:$composed);
+
+  let assemblyFormat = [{
+    $op attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::affine::AffineApplyOp affineApplyOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
 
 #endif // TEST_TRANSFORM_OPS



More information about the Mlir-commits mailing list