[Mlir-commits] [mlir] [mlir][tensor] Support padding with poison (PR #152003)

James Newling llvmlistbot at llvm.org
Wed Aug 6 10:52:43 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/152003

>From 6ca5fa64d35d8f49b2329b536ea565a1d7311848 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 4 Aug 2025 09:25:18 -0700
Subject: [PATCH 1/2] Allow ub.poison as a padding value

---
 .../Linalg/Transforms/PadTilingInterface.cpp   | 18 ++++++++++++------
 .../transform-op-pad-tiling-interface.mlir     |  4 +++-
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   Value paddingValue;
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
-    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
-    paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                               complexTy, complexAttr);
-  } else {
-    paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                             cast<TypedAttr>(paddingValueAttr));
+    if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+                                                 complexTy, complexAttr);
+    }
+  } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+                                        getElementTypeOrSelf(v.getType()));
+  } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+    paddingValue =
+        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
   }
+  assert(paddingValue && "failed to create value from padding attribute");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..2857b53103779 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,6 +4,7 @@
 //           CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
 func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
 {
+  // %goo = ub.poison : f32
   %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      padding_values=[0.0 : f32, 0.0 : f32]
+      // padding_values= [poison, 0.0 : f32]
+      padding_values= [0.0 : f32, 0.0 : f32]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     transform.yield

>From 46b37d0b300cdd6d86dd38442a471b3570576a14 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 5 Aug 2025 19:59:22 -0700
Subject: [PATCH 2/2] Accept #ub.poison padding in transform ops and add a test

Signed-off-by: James Newling <james.newling at gmail.com>
---
 .../TransformOps/LinalgTransformOps.cpp       | 24 ++++++++++++++-----
 .../transform-op-pad-tiling-interface.mlir    | 14 +++++------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d020e58f..4c2686aea0794 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
     // Convert the padding values to attributes.
     SmallVector<Attribute> paddingValues;
-    for (auto const &it :
+    for (auto const &[untypedAttr, elementOrTensorType] :
          llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+      if (isa<ub::PoisonAttr>(untypedAttr)) {
+        paddingValues.push_back(untypedAttr);
+        continue;
+      }
+      auto attr = dyn_cast<TypedAttr>(untypedAttr);
       if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+        emitOpError("expects padding values to be typed attributes or poison");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
-      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      Type elementType = getElementTypeOrSelf(elementOrTensorType);
       // Try to parse string attributes to obtain an attribute of element type.
       if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
         auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
             /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
         if (!parsedAttr || parsedAttr.getType() != elementType) {
           auto diag = this->emitOpError("expects a padding that parses to ")
-                      << elementType << ", got " << std::get<0>(it);
+                      << elementType << ", got " << untypedAttr;
           diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
           return DiagnosedSilenceableFailure::definiteFailure();
         }
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
          llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
       auto attr = dyn_cast<TypedAttr>(untypedAttr);
       Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+      if (isa<ub::PoisonAttr>(untypedAttr)) {
+        paddingValues.push_back(untypedAttr);
+        continue;
+      }
       if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+        emitOpError("expects padding values to be typed attributes or poison");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
       // Try to parse string attributes to obtain an attribute of element type.
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 2857b53103779..9a3dcf0b485d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,7 +4,6 @@
 //           CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
 func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
 {
-  // %goo = ub.poison : f32
   %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -15,12 +14,11 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> !transform.any_op
 
     // Tile to 5 then pad to 8
-    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5] 
+    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      // padding_values= [poison, 0.0 : f32]
-      padding_values= [0.0 : f32, 0.0 : f32]
+      padding_values= [#ub.poison, 0.0 : f32]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     transform.yield
@@ -35,9 +33,9 @@ func.func @pad_lhs(
      -> tensor<24x25xf32>
 {
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
-  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x12xf32> to tensor<8x12xf32>
-  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x25xf32> to tensor<8x25xf32>
   //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
   //      CHECK:   tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -94,7 +92,7 @@ module {
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }
@@ -149,7 +147,7 @@ module {
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }



More information about the Mlir-commits mailing list