[Mlir-commits] [mlir] [mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface … (PR #145732)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 09:06:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

<details>
<summary>Changes</summary>

Make tensor::PadOp implement ReifyRankedShapedTypeOpInterface and add a FoldReifiedShape canonicalization pattern for tensor.pad.

---
Full diff: https://github.com/llvm/llvm-project/pull/145732.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1) 
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+67-1) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..821384eb7d15a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..c66110f6915e9 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
 /// Convert `arrayAttr` to a vector of OpFoldResult.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
+// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
+// values);
 
 /// Convert int64_t to integer attributes of index type and return them as
 /// OpFoldResult.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 72144ec71c5d2..95468caa87f18 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3791,13 +3792,78 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
   }
 };
 
+/// Canonicalizes a foldable `tensor.pad` by refining its result type with the
+/// constant sizes found by reification, then casting back to the old type.
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold())
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+    // NOTE(review): reification may create IR even when we bail out below.
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return failure();
+
+    SmallVector<int64_t> newShape;
+    for (const auto &[s, ofr] : llvm::zip_equal(
+             padOp.getResultType().getShape(), reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      int64_t cst = *maybeCst;
+      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+      newShape.push_back(cst);
+    }
+    if (newShape == padOp.getResultType().getShape())
+      return failure();
+
+    // Clone the pad with the refined type; the cast keeps existing users valid.
+    Type oldType = padOp.getResultType();
+    Type newType =
+        RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
+    Operation *newPad = rewriter.clone(*padOp);
+    newPad->getResult(0).setType(newType);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
+/// Reifies the result shape: static dims become index attributes; a dynamic
+/// dim `i` is `dim(source, i) + low[i] + high[i]`, folded where possible.
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  SmallVector<OpFoldResult> lp = getMixedLowPad();
+  SmallVector<OpFoldResult> hp = getMixedHighPad();
+  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+      continue;
+    }
+    Location loc = getLoc();
+    Value dim = b.createOrFold<tensor::DimOp>(
+        loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+    AffineExpr d0, d1, d2;
+    bindDims(b.getContext(), d0, d1, d2);
+    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+  }
+  return success();
+}
+
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
               FoldOrthogonalPaddings, FoldStaticPadding,
-              FoldConsecutiveConstantPadding>(context);
+              FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..358c1c214a3b1 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2543,3 +2543,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
 //       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
 //  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
 //       CHECK:   return %[[RES]]
+
+// -----
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+// Dim 1 pads a size-%idx slice by (256 - %idx), so the result size is 256.
+//       CHECK: tensor.pad
+//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/145732


More information about the Mlir-commits mailing list