[Mlir-commits] [mlir] [mlir][tensor] Support padding with poison (PR #152003)

James Newling llvmlistbot at llvm.org
Mon Aug 4 10:03:22 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/152003

>From 009bc910a2851b2c02281d3fe9d16994ecdb4ec2 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 4 Aug 2025 09:25:18 -0700
Subject: [PATCH] ability to use poison as padding value

---
 .../Linalg/Transforms/PadTilingInterface.cpp   | 18 ++++++++++++------
 .../transform-op-pad-tiling-interface.mlir     |  4 +++-
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   Value paddingValue;
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
-    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
-    paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                               complexTy, complexAttr);
-  } else {
-    paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                             cast<TypedAttr>(paddingValueAttr));
+    if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+                                                 complexTy, complexAttr);
+    }
+  } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+                                        getElementTypeOrSelf(v.getType()));
+  } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+    paddingValue =
+        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
   }
+  assert(paddingValue && "failed to create value from padding attribute");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..2857b53103779 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,6 +4,7 @@
 //           CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
 func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
 {
+  // %goo = ub.poison : f32
   %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      padding_values=[0.0 : f32, 0.0 : f32]
+      // padding_values= [poison, 0.0 : f32]
+      padding_values= [0.0 : f32, 0.0 : f32]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     transform.yield



More information about the Mlir-commits mailing list