[Mlir-commits] [mlir] Fix support for complex types in `transform.structured.pad` (PR #139841)
Christopher Bate
llvmlistbot at llvm.org
Tue May 13 21:47:38 PDT 2025
https://github.com/christopherbate created https://github.com/llvm/llvm-project/pull/139841
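Fixes verification of the `padding_values` attribute when the operand(s) have element type `complex<...>`. For complex operands, each padding value is given as an array attribute (e.g. `[0.1 : f32, 0.2 : f32]`) rather than as a single typed attribute.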
>From da5bcb522520497ab2f98c0c2c6fd17a1f56fb8a Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Wed, 14 May 2025 04:35:18 +0000
Subject: [PATCH] Fix support for complex types in `transform.structured.pad`
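Fixes verification of the `padding_values` attribute when the operand(s)
have element type `complex<...>`. For complex operands, each padding
value is given as an array attribute rather than a single typed attribute.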
---
.../TransformOps/LinalgTransformOps.cpp | 16 ++++++--
.../test/Dialect/Linalg/transform-op-pad.mlir | 37 +++++++++++++++++++
2 files changed, 49 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..ea02886c1b65a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Attribute> paddingValues;
for (auto const &it :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
- if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ Attribute attr = std::get<0>(it);
+ if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) {
+ emitOpError("expects padding values to be typed attributes or array "
+ "attributes (for complex numbers)");
return DiagnosedSilenceableFailure::definiteFailure();
}
Type elementType = getElementTypeOrSelf(std::get<1>(it));
@@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
continue;
}
// Otherwise, add the attribute directly.
- if (attr.getType() != elementType) {
+ if (isa<TypedAttr>(attr) &&
+ cast<TypedAttr>(attr).getType() != elementType) {
+ auto diag = this->emitOpError("expects a padding value of type ")
+ << elementType << ", got " << attr;
+ diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) {
auto diag = this->emitOpError("expects a padding value of type ")
<< elementType << ", got " << attr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..c838713f368a3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!type = tensor<10x10xcomplex<f32>>
+// CHECK-LABEL: @pad_matmul
+func.func @pad_matmul(%arg0: !type,
+ %arg1: !type,
+ %arg2: !type
+ ) -> !type {
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: linalg.matmul
+ %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type
+ func.return %0 : !type
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
+ padding_values=[
+ [0.1 : f32, 0.2 : f32],
+ [0.3 : f32, 0.4 : f32],
+ [0.5 : f32, 0.6 : f32]
+ ],
+ padding_dimensions = [0, 1, 2]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list