[Mlir-commits] [mlir] b574bcf - [mlir][TD] Support padding with poison (#152003)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 8 09:09:06 PDT 2025
Author: James Newling
Date: 2025-08-08T09:09:03-07:00
New Revision: b574bcf0361de60ef8c183c583a9b59a0f5cccca
URL: https://github.com/llvm/llvm-project/commit/b574bcf0361de60ef8c183c583a9b59a0f5cccca
DIFF: https://github.com/llvm/llvm-project/commit/b574bcf0361de60ef8c183c583a9b59a0f5cccca.diff
LOG: [mlir][TD] Support padding with poison (#152003)
Signed-off-by: James Newling <james.newling at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 87547436eb474..639e0feabc9bd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the padding values to attributes.
SmallVector<Attribute> paddingValues;
- for (auto const &it :
+ for (auto const &[untypedAttr, elementOrTensorType] :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
+ auto attr = dyn_cast<TypedAttr>(untypedAttr);
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
+ Type elementType = getElementTypeOrSelf(elementOrTensorType);
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
- << elementType << ", got " << std::get<0>(it);
+ << elementType << ", got " << untypedAttr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
auto attr = dyn_cast<TypedAttr>(untypedAttr);
Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
// Try to parse string attributes to obtain an attribute of element type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..9a3dcf0b485d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -14,11 +14,11 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
// Tile to 5 then pad to 8
- %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
+ %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
- padding_values=[0.0 : f32, 0.0 : f32]
+ padding_values= [#ub.poison, 0.0 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -33,9 +33,9 @@ func.func @pad_lhs(
-> tensor<24x25xf32>
{
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
// CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
// CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -92,7 +92,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
@@ -147,7 +147,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
More information about the Mlir-commits mailing list