[Mlir-commits] [mlir] b574bcf - [mlir][TD] Support padding with poison (#152003)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 8 09:09:06 PDT 2025


Author: James Newling
Date: 2025-08-08T09:09:03-07:00
New Revision: b574bcf0361de60ef8c183c583a9b59a0f5cccca

URL: https://github.com/llvm/llvm-project/commit/b574bcf0361de60ef8c183c583a9b59a0f5cccca
DIFF: https://github.com/llvm/llvm-project/commit/b574bcf0361de60ef8c183c583a9b59a0f5cccca.diff

LOG: [mlir][TD] Support padding with poison (#152003)

Signed-off-by: James Newling <james.newling at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
    mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 87547436eb474..639e0feabc9bd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
     // Convert the padding values to attributes.
     SmallVector<Attribute> paddingValues;
-    for (auto const &it :
+    for (auto const &[untypedAttr, elementOrTensorType] :
          llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+      if (isa<ub::PoisonAttr>(untypedAttr)) {
+        paddingValues.push_back(untypedAttr);
+        continue;
+      }
+      auto attr = dyn_cast<TypedAttr>(untypedAttr);
       if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+        emitOpError("expects padding values to be typed attributes or poison");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
-      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      Type elementType = getElementTypeOrSelf(elementOrTensorType);
       // Try to parse string attributes to obtain an attribute of element type.
       if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
         auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
             /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
         if (!parsedAttr || parsedAttr.getType() != elementType) {
           auto diag = this->emitOpError("expects a padding that parses to ")
-                      << elementType << ", got " << std::get<0>(it);
+                      << elementType << ", got " << untypedAttr;
           diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
           return DiagnosedSilenceableFailure::definiteFailure();
         }
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
          llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
       auto attr = dyn_cast<TypedAttr>(untypedAttr);
       Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+      if (isa<ub::PoisonAttr>(untypedAttr)) {
+        paddingValues.push_back(untypedAttr);
+        continue;
+      }
       if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+        emitOpError("expects padding values to be typed attributes or poison");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
       // Try to parse string attributes to obtain an attribute of element type.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   Value paddingValue;
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
-    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
-    paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                               complexTy, complexAttr);
-  } else {
-    paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                             cast<TypedAttr>(paddingValueAttr));
+    if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+                                                 complexTy, complexAttr);
+    }
+  } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+                                        getElementTypeOrSelf(v.getType()));
+  } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+    paddingValue =
+        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
   }
+  assert(paddingValue && "failed to create value from padding attribute");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   SmallVector<int64_t> tensorShape;

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..9a3dcf0b485d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -14,11 +14,11 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> !transform.any_op
 
     // Tile to 5 then pad to 8
-    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5] 
+    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      padding_values=[0.0 : f32, 0.0 : f32]
+      padding_values= [#ub.poison, 0.0 : f32]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     transform.yield
@@ -33,9 +33,9 @@ func.func @pad_lhs(
      -> tensor<24x25xf32>
 {
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
-  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x12xf32> to tensor<8x12xf32>
-  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x25xf32> to tensor<8x25xf32>
   //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
   //      CHECK:   tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -92,7 +92,7 @@ module {
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }
@@ -147,7 +147,7 @@ module {
       %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
         padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
       } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
     }
   }
 }


        


More information about the Mlir-commits mailing list