[Mlir-commits] [mlir] [mlir][tensor] Support padding with poison (PR #152003)
James Newling
llvmlistbot at llvm.org
Wed Aug 6 10:52:43 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/152003
From 6ca5fa64d35d8f49b2329b536ea565a1d7311848 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 4 Aug 2025 09:25:18 -0700
Subject: [PATCH 1/2] ability to use poison as padding value
---
.../Linalg/Transforms/PadTilingInterface.cpp | 18 ++++++++++++------
.../transform-op-pad-tiling-interface.mlir | 4 +++-
2 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..2857b53103779 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,6 +4,7 @@
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
{
+ // %goo = ub.poison : f32
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
- padding_values=[0.0 : f32, 0.0 : f32]
+ // padding_values= [poison, 0.0 : f32]
+ padding_values= [0.0 : f32, 0.0 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
From 46b37d0b300cdd6d86dd38442a471b3570576a14 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 5 Aug 2025 19:59:22 -0700
Subject: [PATCH 2/2] add test with poison
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../TransformOps/LinalgTransformOps.cpp | 24 ++++++++++++++-----
.../transform-op-pad-tiling-interface.mlir | 14 +++++------
2 files changed, 24 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d020e58f..4c2686aea0794 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the padding values to attributes.
SmallVector<Attribute> paddingValues;
- for (auto const &it :
+ for (auto const &[untypedAttr, elementOrTensorType] :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
+ auto attr = dyn_cast<TypedAttr>(untypedAttr);
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
+ Type elementType = getElementTypeOrSelf(elementOrTensorType);
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
- << elementType << ", got " << std::get<0>(it);
+ << elementType << ", got " << untypedAttr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
auto attr = dyn_cast<TypedAttr>(untypedAttr);
Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
// Try to parse string attributes to obtain an attribute of element type.
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 2857b53103779..9a3dcf0b485d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,7 +4,6 @@
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
{
- // %goo = ub.poison : f32
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@@ -15,12 +14,11 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
// Tile to 5 then pad to 8
- %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
+ %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
- // padding_values= [poison, 0.0 : f32]
- padding_values= [0.0 : f32, 0.0 : f32]
+ padding_values= [#ub.poison, 0.0 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -35,9 +33,9 @@ func.func @pad_lhs(
-> tensor<24x25xf32>
{
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
// CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
// CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -94,7 +92,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
@@ -149,7 +147,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
More information about the Mlir-commits mailing list