[Mlir-commits] [mlir] [mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface (PR #145867)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 26 04:28:52 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145867

>From bfabf8d16c582cd47014032b2a8281724b3d972a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 25 Jun 2025 16:46:44 +0200
Subject: [PATCH 1/2] [mlir][tensor] Make tensor::PadOp a
 ReifyRankedShapedTypeOpInterface

---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  1 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 25 +++++++++++++++++++
 mlir/test/Dialect/Linalg/pad_fusion.mlir      |  6 ++---
 .../resolve-shaped-type-result-dims.mlir      |  4 +--
 .../transform-op-bufferize-to-allocation.mlir |  2 +-
 ...-rewrite-in-destination-passing-style.mlir |  6 ++---
 mlir/test/Dialect/Tensor/bufferize.mlir       |  2 +-
 .../value-bounds-op-interface-impl.mlir       | 17 +++++++++++++
 8 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..47962f75558ea 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 72144ec71c5d2..20f48436b1901 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3793,6 +3794,30 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
 
 } // namespace
 
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  SmallVector<OpFoldResult> lp = getMixedLowPad();
+  SmallVector<OpFoldResult> hp = getMixedHighPad();
+  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+      continue;
+    }
+    Location loc = getLoc();
+    Value dim = b.createOrFold<tensor::DimOp>(
+        loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+
+    affine::AffineBuilder ab(b, loc);
+    AffineExpr d0, d1, d2;
+    bindDims(b.getContext(), d0, d1, d2);
+    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+  }
+  return success();
+}
+
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir
index a0d9a6ded34c4..56951b365b299 100644
--- a/mlir/test/Dialect/Linalg/pad_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir
@@ -34,9 +34,9 @@ func.func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : in
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
 //  CHECK-DAG:   %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
-//  CHECK-DAG:   %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
+//  CHECK-DAG:   %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D0]], %[[ARG1]], %[[ARG3]]]
 //  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
-//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
+//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D1]], %[[ARG2]], %[[ARG4]]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[TARGET_D0]], %[[TARGET_D1]])
 //      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]]
 //  CHECK-DAG:   %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
@@ -80,7 +80,7 @@ func.func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : ind
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
 //  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
-//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
+//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D1]], %[[ARG1]], %[[ARG2]]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[TARGET_D1]])
 //      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]]
 //  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 3bc1f56d816d7..7fbb123d4e16b 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -268,9 +268,9 @@ func.func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index
 //  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //  CHECK-DAG:   %[[C12:.+]] = arith.constant 12 : index
 //      CHECK:   %[[IN_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//      CHECK:   %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
+//      CHECK:   %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[IN_DIM1]], %[[ARG1]]]
 //      CHECK:   %[[IN_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
+//      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[IN_DIM2]], %[[ARG2]]]
 //      CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 4d7ddc8a513c4..15a0a94079b2e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -9,7 +9,7 @@
 //   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[c50:.*]] = arith.constant 50 : index
 //   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
-//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
 //   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
 //       CHECK:   %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
 //       CHECK:   linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index fb37d78b50ce4..63c9f1f27517b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -119,7 +119,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
 //   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
-//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
 //   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
 //       CHECK:   %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
 //       CHECK:   %[[generic:.*]] = linalg.generic
@@ -162,7 +162,7 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[c50:.*]] = arith.constant 50 : index
 //   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
-//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
 //   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
 //       CHECK:   %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
 //       CHECK:   %[[filled:.*]] = linalg.fill ins(%[[c50]] : index) outs(%[[empty]] : tensor<?x?xindex>)
@@ -197,7 +197,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index, %[[padding:.*]]: index
 //   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
-//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
 //   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
 //       CHECK:   %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
 //       CHECK:   %[[filled:.*]] = linalg.fill ins(%[[padding]] : index) outs(%[[empty]] : tensor<?x?xindex>)
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e9c3ba7e3b970..c0adc8a49bf70 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -571,7 +571,7 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
   // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
-  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[h1]], %[[dim0]]]
+  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[dim0]], %[[h1]]]
   // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map_2]]()[%[[l2]], %[[h2]]]
   // CHECK:     %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
   // CHECK:     %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 6610d3180cf02..e526adb18cf1e 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -213,3 +213,20 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
   "test.compare"(%dim0, %dim1) : (index, index) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: arith.constant 256 : index
+  %1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
+  return
+}

>From da51d56a035e13fd578a4dd2de0d7dd54eee7747 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <Nico.Vasilache at amd.com>
Date: Thu, 26 Jun 2025 13:17:26 +0200
Subject: [PATCH 2/2] Update mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
---
 .../Bufferization/IR/BufferizableOpInterface.cpp       | 10 +++++++---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp               |  1 -
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..e5242436a197d 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -195,9 +196,12 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
         reifiedShapes = true;
         auto &shape =
             resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
-        for (const auto &dim : enumerate(tensorType.getShape()))
-          if (ShapedType::isDynamic(dim.value()))
-            dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
+        for (const auto &dim : enumerate(tensorType.getShape())) {
+          if (ShapedType::isDynamic(dim.value())) {
+            dynamicSizes.push_back(
+                getValueOrCreateConstantIndexOp(b, loc, shape[dim.index()]));
+          }
+        }
       }
     }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 20f48436b1901..92da0cf426a47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3809,7 +3809,6 @@ PadOp::reifyResultShapes(OpBuilder &b,
     Value dim = b.createOrFold<tensor::DimOp>(
         loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
 
-    affine::AffineBuilder ab(b, loc);
     AffineExpr d0, d1, d2;
     bindDims(b.getContext(), d0, d1, d2);
     reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(



More information about the Mlir-commits mailing list