[Mlir-commits] [mlir] Fix support for complex types in `transform.structured.pad` (PR #139841)

Christopher Bate llvmlistbot at llvm.org
Tue May 13 21:47:38 PDT 2025


https://github.com/christopherbate created https://github.com/llvm/llvm-project/pull/139841

Fixes verification of the pad element attribute when the operand(s) have element type `complex<...>`.

>From da5bcb522520497ab2f98c0c2c6fd17a1f56fb8a Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Wed, 14 May 2025 04:35:18 +0000
Subject: [PATCH] Fix support for complex types in `transform.structured.pad`

Fixes verification of the pad element attribute when the operand(s)
have element type `complex<...>`.
---
 .../TransformOps/LinalgTransformOps.cpp       | 16 ++++++--
 .../test/Dialect/Linalg/transform-op-pad.mlir | 37 +++++++++++++++++++
 2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..ea02886c1b65a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     SmallVector<Attribute> paddingValues;
     for (auto const &it :
          llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
-      if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+      Attribute attr = std::get<0>(it);
+      if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) {
+        emitOpError("expects padding values to be typed attributes or array "
+                    "attributes (for complex numbers)");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
       Type elementType = getElementTypeOrSelf(std::get<1>(it));
@@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
         continue;
       }
       // Otherwise, add the attribute directly.
-      if (attr.getType() != elementType) {
+      if (isa<TypedAttr>(attr) &&
+          cast<TypedAttr>(attr).getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) {
         auto diag = this->emitOpError("expects a padding value of type ")
                     << elementType << ", got " << attr;
         diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..c838713f368a3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!type = tensor<10x10xcomplex<f32>>
+// CHECK-LABEL: @pad_matmul
+func.func @pad_matmul(%arg0: !type,
+                           %arg1: !type,
+                           %arg2: !type
+                           ) -> !type {
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type
+  func.return %0 : !type
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
+      padding_values=[
+        [0.1 : f32, 0.2 : f32],
+        [0.3 : f32, 0.4 : f32],
+        [0.5 : f32, 0.6 : f32]
+      ],
+      padding_dimensions = [0, 1, 2]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list