[Mlir-commits] [mlir] [mlir][transform] Plumb a simplified form of AffineMin folding into t… (PR #145170)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 23 03:01:45 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145170

>From d2f87802abea08fc26e3907a8f09220b8ff50482 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 18 Jun 2025 19:14:31 +0200
Subject: [PATCH 1/2] [mlir][transform] Plumb a simplified form of AffineMin
 folding into transform.pad-tiling-interface

This revision introduces a simple variant of AffineMin folding in makeComposedFoldedAffineApply
and makes use of it in transform.pad-tiling-interface.
Since this version explicitly calls ValueBoundsInterface, it may be too expensive and is
only activated behind a flag.
It results in better foldings when mixing tiling and padding, including with dynamic shapes.
---
 .../mlir/Dialect/Affine/IR/AffineOps.h        |  18 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 134 ++++++++++----
 .../Linalg/Transforms/PadTilingInterface.cpp  |   5 +-
 ...m-op-pad-tiling-interface-multiple-of.mlir | 164 ++++++++++++++++--
 4 files changed, 271 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 6fdb72c370e6d..2091faa6b0b02 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -410,9 +410,11 @@ void canonicalizeSetAndOperands(IntegerSet *set,
 /// other AffineApplyOps supplying those operands. The operands of the resulting
 /// AffineApplyOp do not change the length of  AffineApplyOp chains.
 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
-                                      ArrayRef<OpFoldResult> operands);
+                                      ArrayRef<OpFoldResult> operands,
+                                      bool composeAffineMin = false);
 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
-                                      ArrayRef<OpFoldResult> operands);
+                                      ArrayRef<OpFoldResult> operands,
+                                      bool composeAffineMin = false);
 
 /// Constructs an AffineApplyOp that applies `map` to `operands` after composing
 /// the map with the maps of any other AffineApplyOp supplying the operands,
@@ -421,16 +423,19 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
 /// map.
 OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
-                                           ArrayRef<OpFoldResult> operands);
+                                           ArrayRef<OpFoldResult> operands,
+                                           bool composeAffineMin = false);
 /// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
 OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                            AffineExpr expr,
-                                           ArrayRef<OpFoldResult> operands);
+                                           ArrayRef<OpFoldResult> operands,
+                                           bool composeAffineMin = false);
 /// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
 /// Note that this may create as many affine.apply operations as the map has
 /// results given that affine.apply must be single-result.
 SmallVector<OpFoldResult> makeComposedFoldedMultiResultAffineApply(
-    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands);
+    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
+    bool composeAffineMin = false);
 
 /// Returns an AffineMinOp obtained by composing `map` and `operands` with
 /// AffineApplyOps supplying those operands.
@@ -459,7 +464,8 @@ OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
 /// terminal symbol, i.e., a symbol defined at the top level or a block/function
 /// argument.
 void fullyComposeAffineMapAndOperands(AffineMap *map,
-                                      SmallVectorImpl<Value> *operands);
+                                      SmallVectorImpl<Value> *operands,
+                                      bool composeAffineMin = false);
 
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 48770d4f4ff7b..5c06159121b81 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -11,12 +11,14 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -26,7 +28,9 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
+#include <limits>
 #include <numeric>
 #include <optional>
 
@@ -1042,6 +1046,59 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
                        map.getContext());
 }
 
+/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`,
+/// replaces the patterns:
+/// ```
+///   dimOrSym.ceildiv(cst) * cst
+///   (dimOrSym + cst - 1).floordiv(cst) * cst
+/// ```
+/// by `cst` in `map`.
+/// This simplification is valid iff `minOp` is guaranteed to be nonnegative.
+/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
+/// inject static information that may not be statically discoverable.
+/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
+/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
+static LogicalResult replaceAffineMinBoundingBoxExpression(
+    AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
+    bool affineMinKnownToBeNonNegative = false) {
+  auto affineMinMap = minOp.getAffineMap();
+  if (!affineMinKnownToBeNonNegative) {
+    ValueRange values = minOp->getOperands();
+    for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
+      AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
+      FailureOr<int64_t> lowerBound =
+          ValueBoundsConstraintSet::computeConstantBound(
+              presburger::BoundType::LB, {row, values},
+              /*stopCondition=*/nullptr,
+              /*closedUB=*/true);
+      if (failed(lowerBound) || lowerBound.value() < 0)
+        return failure();
+    }
+  }
+
+  AffineMap initialMap = *map;
+  for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
+    auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
+    // TODO: this should also work with nonnegative symbolic divisors.
+    if (!m.isSingleConstant())
+      continue;
+
+    auto cst = m.getSingleConstantResult();
+    DenseMap<AffineExpr, AffineExpr> repl;
+    // dimOrSym.ceilDiv(cst) * cst -> cst
+    repl[dimOrSym.ceilDiv(cst) * cst] =
+        getAffineConstantExpr(cst, minOp.getContext());
+    // (dimOrSym + cst - 1).floorDiv(cst) * cst -> cst
+    repl[(dimOrSym + cst - 1).floorDiv(cst) * cst] =
+        getAffineConstantExpr(cst, minOp.getContext());
+    auto newMap = map->replace(repl);
+    if (newMap == *map)
+      continue;
+    *map = newMap;
+  }
+  return success(*map != initialMap);
+}
+
 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
 /// defining AffineApplyOp expression and operands.
 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1052,10 +1109,13 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
 ///   2. `map` dim and symbols are gradually shifted to higher positions.
 ///   3. Old `dim` and `sym` entries are replaced by nullptr
 /// This avoids the need for any bookkeeping.
+/// If `replaceAffineMin` is set to true, additionally triggers more expensive
+/// replacements involving affine_min operations.
 static LogicalResult replaceDimOrSym(AffineMap *map,
                                      unsigned dimOrSymbolPosition,
                                      SmallVectorImpl<Value> &dims,
-                                     SmallVectorImpl<Value> &syms) {
+                                     SmallVectorImpl<Value> &syms,
+                                     bool replaceAffineMin) {
   MLIRContext *ctx = map->getContext();
   bool isDimReplacement = (dimOrSymbolPosition < dims.size());
   unsigned pos = isDimReplacement ? dimOrSymbolPosition
@@ -1064,6 +1124,13 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
   if (!v)
     return failure();
 
+  auto minOp = v.getDefiningOp<AffineMinOp>();
+  if (minOp && replaceAffineMin) {
+    AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
+                                           : getAffineSymbolExpr(pos, ctx);
+    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
+  }
+
   auto affineApply = v.getDefiningOp<AffineApplyOp>();
   if (!affineApply)
     return failure();
@@ -1101,7 +1168,8 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
 /// iteratively. Perform canonicalization of map and operands as well as
 /// AffineMap simplification. `map` and `operands` are mutated in place.
 static void composeAffineMapAndOperands(AffineMap *map,
-                                        SmallVectorImpl<Value> *operands) {
+                                        SmallVectorImpl<Value> *operands,
+                                        bool composeAffineMin = false) {
   if (map->getNumResults() == 0) {
     canonicalizeMapAndOperands(map, operands);
     *map = simplifyAffineMap(*map);
@@ -1122,7 +1190,8 @@ static void composeAffineMapAndOperands(AffineMap *map,
   while (true) {
     bool changed = false;
     for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
-      if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
+      if ((changed |=
+           succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
         break;
     if (!changed)
       break;
@@ -1163,38 +1232,41 @@ static void composeAffineMapAndOperands(AffineMap *map,
 }
 
 void mlir::affine::fullyComposeAffineMapAndOperands(
-    AffineMap *map, SmallVectorImpl<Value> *operands) {
+    AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
   while (llvm::any_of(*operands, [](Value v) {
     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
   })) {
-    composeAffineMapAndOperands(map, operands);
+    composeAffineMapAndOperands(map, operands, composeAffineMin);
   }
 }
 
 AffineApplyOp
 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
-                                      ArrayRef<OpFoldResult> operands) {
+                                      ArrayRef<OpFoldResult> operands,
+                                      bool composeAffineMin) {
   SmallVector<Value> valueOperands;
   map = foldAttributesIntoMap(b, map, operands, valueOperands);
-  composeAffineMapAndOperands(&map, &valueOperands);
+  composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
   assert(map);
   return b.create<AffineApplyOp>(loc, map, valueOperands);
 }
 
 AffineApplyOp
 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
-                                      ArrayRef<OpFoldResult> operands) {
+                                      ArrayRef<OpFoldResult> operands,
+                                      bool composeAffineMin) {
   return makeComposedAffineApply(
       b, loc,
       AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
           .front(),
-      operands);
+      operands, composeAffineMin);
 }
 
 /// Composes the given affine map with the given list of operands, pulling in
 /// the maps from any affine.apply operations that supply the operands.
 static void composeMultiResultAffineMap(AffineMap &map,
-                                        SmallVectorImpl<Value> &operands) {
+                                        SmallVectorImpl<Value> &operands,
+                                        bool composeAffineMin = false) {
   // Compose and canonicalize each expression in the map individually because
   // composition only applies to single-result maps, collecting potentially
   // duplicate operands in a single list with shifted dimensions and symbols.
@@ -1203,7 +1275,8 @@ static void composeMultiResultAffineMap(AffineMap &map,
   for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
     SmallVector<Value> submapOperands(operands.begin(), operands.end());
     AffineMap submap = map.getSubMap({i});
-    fullyComposeAffineMapAndOperands(&submap, &submapOperands);
+    fullyComposeAffineMapAndOperands(&submap, &submapOperands,
+                                     composeAffineMin);
     canonicalizeMapAndOperands(&submap, &submapOperands);
     unsigned numNewDims = submap.getNumDims();
     submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
@@ -1221,10 +1294,9 @@ static void composeMultiResultAffineMap(AffineMap &map,
   canonicalizeMapAndOperands(&map, &operands);
 }
 
-OpFoldResult
-mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
-                                            AffineMap map,
-                                            ArrayRef<OpFoldResult> operands) {
+OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
+    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
+    bool composeAffineMin) {
   assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
 
   // Create new builder without a listener, so that no notification is
@@ -1236,7 +1308,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
 
   // Create op.
   AffineApplyOp applyOp =
-      makeComposedAffineApply(newBuilder, loc, map, operands);
+      makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
 
   // Get constant operands.
   SmallVector<Attribute> constOperands(applyOp->getNumOperands());
@@ -1256,26 +1328,25 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
   return llvm::getSingleElement(foldResults);
 }
 
-OpFoldResult
-mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
-                                            AffineExpr expr,
-                                            ArrayRef<OpFoldResult> operands) {
+OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
+    OpBuilder &b, Location loc, AffineExpr expr,
+    ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
   return makeComposedFoldedAffineApply(
       b, loc,
       AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
           .front(),
-      operands);
+      operands, composeAffineMin);
 }
 
 SmallVector<OpFoldResult>
 mlir::affine::makeComposedFoldedMultiResultAffineApply(
-    OpBuilder &b, Location loc, AffineMap map,
-    ArrayRef<OpFoldResult> operands) {
-  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
-                             [&](unsigned i) {
-                               return makeComposedFoldedAffineApply(
-                                   b, loc, map.getSubMap({i}), operands);
-                             });
+    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
+    bool composeAffineMin) {
+  return llvm::map_to_vector(
+      llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
+        return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
+                                             operands, composeAffineMin);
+      });
 }
 
 template <typename OpTy>
@@ -3024,7 +3095,8 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result,
 /// `set` by composing the maps of such affine.apply ops with the integer
 /// set constraints.
 static void composeSetAndOperands(IntegerSet &set,
-                                  SmallVectorImpl<Value> &operands) {
+                                  SmallVectorImpl<Value> &operands,
+                                  bool composeAffineMin) {
   // We will simply reuse the API of the map composition by viewing the LHSs of
   // the equalities and inequalities of `set` as the affine exprs of an affine
   // map. Convert to equivalent map, compose, and convert back to set.
@@ -3035,7 +3107,7 @@ static void composeSetAndOperands(IntegerSet &set,
                     [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
     return;
 
-  composeAffineMapAndOperands(&map, &operands);
+  composeAffineMapAndOperands(&map, &operands, composeAffineMin);
   set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
                         set.getEqFlags());
 }
@@ -3044,7 +3116,7 @@ static void composeSetAndOperands(IntegerSet &set,
 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
-  composeSetAndOperands(set, operands);
+  composeSetAndOperands(set, operands, /*composeAffineMin=*/false);
   canonicalizeSetAndOperands(&set, &operands);
 
   // Check if the canonicalization or composition led to any change.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 5383ae48aeb3a..42dac0776bace 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -84,7 +84,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
       getDimsToSize(rewriter, indexingSizes, options);
 
   // For each dimension in the operand's shape, iterate over indexingSizes and
-  // add
+  // add the various term contributions.
   for (const auto &enResults : enumerate(indexingMap.getResults())) {
     int64_t resultIndex = enResults.index();
     AffineMap partialIndexingMap = indexingMap.getSubMap(
@@ -122,7 +122,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
         AffineMap composedMap = projectedMap.compose(ceilMap);
         OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, composedMap,
-            {indexingSizes[paddingDim], paddingSize});
+            {indexingSizes[paddingDim], paddingSize},
+            /*composeAffineMin=*/true);
         terms.push_back(paddingDimOfr);
       } else {
         // Otherwise just set to paddingSize.
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 5ac35c14be3fb..845fe25193019 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -19,7 +19,7 @@ func.func @pad_lhs(
   //      CHECK:     : tensor<?x25xf32> to tensor<?x25xf32>
 
   //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<?x25xf32>) -> tensor<?x25xf32>
-  
+
   //      CHECK:   tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
   //      CHECK:     : tensor<?x25xf32> to tensor<?x25xf32>
   //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
@@ -29,8 +29,8 @@ func.func @pad_lhs(
 }
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
       : (!transform.any_op) -> !transform.any_op
 
     // Tile to 5 then pad to 8 (supposedly to better hit vector ops).
@@ -71,13 +71,13 @@ module {
     return %0 : tensor<7x11x12xf32>
   }
   module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of {
-        padding_dimensions = [0, 2], 
+        padding_dimensions = [0, 2],
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }
@@ -126,13 +126,155 @@ module {
     return %0 : tensor<?x11x?xf32>
   }
   module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of {
-        padding_dimensions = [0, 2], 
+        padding_dimensions = [0, 2],
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> ((s0 ceildiv 16) * 16)>
+//     CHECK-LABEL: pad_lhs
+func.func @pad_lhs(
+  %arg0: tensor<24x?xf32>, %arg1: tensor<?x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  //      CHECK: %[[D0_0:.*]] = tensor.dim
+  //      CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0] high[0, %[[H0]]]
+  //      CHECK:   : tensor<24x?xf32> to tensor<24x?xf32>
+
+  //      CHECK: %[[D0_2:.*]] = tensor.dim
+  //      CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D0_2]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0] high[%[[H1]], 0]
+  //      CHECK:   : tensor<?x25xf32> to tensor<?x25xf32>
+  //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
+
+  //      CHECK:    linalg.matmul ins(%{{.*}}, %{{.*}}: tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+
+  //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [8, 25] [1, 1]
+  // CHECK-SAME:     : tensor<8x25xf32> into tensor<24x25xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x?xf32>, tensor<?x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    // Pad then tile should produce static shapes.
+    %matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul to padding_sizes [8, 16] pad_to_multiple_of {
+      padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0, 2]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %m, %l0, %l1 = transform.structured.tile_using_for %matmul_padded tile_sizes [8, 0, 16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    %func2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %func
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func2 {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    %minmax = transform.structured.match ops{["affine.min", "affine.max"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %minmax : !transform.any_op
+    transform.apply_patterns to %func2 {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (-d0 + 16)>
+
+//     CHECK-LABEL: pad_lhs
+func.func @pad_lhs(
+  %arg0: tensor<24x?xf32>, %arg1: tensor<?x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
+  //      CHECK:   %[[MIN:.*]] = affine.min #[[$MAP0]](%{{.*}})
+  //      CHECK:   %[[H0:.*]] = affine.apply #[[$MAP1]](%[[MIN]])
+  //      CHECK:   tensor.pad %{{.*}} low[0, 0] high[0, %[[H0]]]
+  //      CHECK:     : tensor<8x?xf32> to tensor<8x16xf32>
+
+  //      CHECK:   %[[H1:.*]] = affine.apply #[[$MAP1]](%[[MIN]])
+  //      CHECK:   tensor.pad %{{.*}} low[0, 0] high[%[[H1]], 0]
+  //      CHECK:     : tensor<?x25xf32> to tensor<16x25xf32>
+
+  //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+
+  //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [8, 25] [1, 1]
+  // CHECK-SAME:     : tensor<8x25xf32> into tensor<24x25xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x?xf32>, tensor<?x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    // Tile then pad should produce static shapes.
+    %m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    %matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of {
+      padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0, 2]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (-d0 + 20, 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (-d0 + 8)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (-d0 + 16)>
+
+//     CHECK-LABEL: pad_lhs
+func.func @pad_lhs(
+  %arg0: tensor<20x?xf32>, %arg1: tensor<?x25xf32>, %arg2: tensor<20x25xf32>)
+     -> tensor<20x25xf32>
+{
+  //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<20x?xf32>, tensor<?x25xf32>) outs(%arg2 : tensor<20x25xf32>) -> tensor<20x25xf32>
+  func.return %0 : tensor<20x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    // Tile then pad should produce static shapes.
+    %m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    %matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of {
+      padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0, 2]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.yield
+  }
+}
+

>From 7959b1635e7332e9f760b8fcfb3f6490a18b6caf Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 12:01:18 +0200
Subject: [PATCH 2/2] Better description of simplification, extend to symbolic
 constant cases

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 35 +++++++++++++-----------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5c06159121b81..bd9488f1ac18f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1046,16 +1046,20 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
                        map.getContext());
 }
 
-/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`,
-/// replaces the patterns:
+/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`.
+/// Assuming that the quantity is of the form:
+///   `affine_min(f(x, y), symbolic_cst)`.
+/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and
+/// proceeds with replacing the patterns:
 /// ```
-///   dimOrSym.ceildiv(cst) * cst
-///   (dimOrSym + cst - 1).floordiv(cst) * cst
+///   dimOrSym.ceildiv(symbolic_cst)
+///   (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst)
 /// ```
-/// by `cst` in `map`.
-/// This simplification is valid iff `minOp` is guaranteed to be nonnegative.
+/// by `1`.
+///
 /// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
 /// inject static information that may not be statically discoverable.
+///
 /// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
 /// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
 static LogicalResult replaceAffineMinBoundingBoxExpression(
@@ -1071,7 +1075,7 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(
               presburger::BoundType::LB, {row, values},
               /*stopCondition=*/nullptr,
               /*closedUB=*/true);
-      if (failed(lowerBound) || lowerBound.value() < 0)
+      if (failed(lowerBound) || lowerBound.value() <= 0)
         return failure();
     }
   }
@@ -1079,23 +1083,22 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(
   AffineMap initialMap = *map;
   for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
     auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
-    // TODO: this should also work with nonnegative symbolic divisors.
-    if (!m.isSingleConstant())
+    AffineExpr expr = m.getResult(0);
+    if (!expr.isSymbolicOrConstant())
       continue;
 
-    auto cst = m.getSingleConstantResult();
     DenseMap<AffineExpr, AffineExpr> repl;
-    // dimOrSym.ceilDiv(cst) * cst -> cst
-    repl[dimOrSym.ceilDiv(cst) * cst] =
-        getAffineConstantExpr(cst, minOp.getContext());
-    // (dimOrSym + cst - 1).floorDiv(cst) * cst -> cst
-    repl[(dimOrSym + cst - 1).floorDiv(cst) * cst] =
-        getAffineConstantExpr(cst, minOp.getContext());
+    // dimOrSym.ceilDiv(expr) -> 1
+    repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext());
+    // (dimOrSym + expr - 1).floorDiv(expr) -> 1
+    repl[(dimOrSym + expr - 1).floorDiv(expr)] =
+        getAffineConstantExpr(1, minOp.getContext());
     auto newMap = map->replace(repl);
     if (newMap == *map)
       continue;
     *map = newMap;
   }
+
   return success(*map != initialMap);
 }
 



More information about the Mlir-commits mailing list