[Mlir-commits] [mlir] a5cee3e - [mlir][linalg] Add a padding case for `ComplexType`

Rob Suderman llvmlistbot at llvm.org
Mon Jul 17 17:22:23 PDT 2023


Author: Robert Suderman
Date: 2023-07-17T17:20:38-07:00
New Revision: a5cee3e386bde28ce21ff2ead3fc420f018604ca

URL: https://github.com/llvm/llvm-project/commit/a5cee3e386bde28ce21ff2ead3fc420f018604ca
DIFF: https://github.com/llvm/llvm-project/commit/a5cee3e386bde28ce21ff2ead3fc420f018604ca.diff

LOG: [mlir][linalg] Add a padding case for `ComplexType`

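If the operand's element type is a `ComplexType`, the `paddingAttr` is
expected to be an `ArrayAttr` holding two values (the real and imaginary
parts), and the padding value is materialized with `complex.constant`
instead of `arith.constant`.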

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D154908

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index fe720aa24cdd77..f87fbbe412cfb1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
@@ -125,8 +126,17 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
   }
   Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
-  Value paddingValue = rewriter.create<arith::ConstantOp>(
-      opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+
+  Value paddingValue;
+  if (auto complexTy = dyn_cast<ComplexType>(
+          getElementTypeOrSelf(opOperand->get().getType()))) {
+    auto complexAttr = cast<ArrayAttr>(paddingAttr);
+    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
+                                                        complexTy, complexAttr);
+  } else {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+  }
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   auto paddedTensorType = RankedTensorType::get(
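
The dispatch added to `padOperandToSmallestStaticBoundingBox` can be modeled outside MLIR as a minimal sketch. Everything below is hypothetical: `ScalarAttrModel`, `ArrayAttrModel`, `ElementKind`, `PaddingConstant` and `makePaddingValue` are illustrative stand-ins, not MLIR types or APIs. The sketch only shows the decision the diff makes: a complex element type takes a two-element array attribute and yields a `complex.constant`, while every other element type takes a typed scalar attribute and yields an `arith.constant`. Unlike the real code, where `cast<ArrayAttr>` asserts on a mismatched attribute, the sketch reports the mismatch by throwing.

```cpp
#include <cassert>
#include <complex>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR attributes: a scalar typed attribute, or an
// array attribute holding [real, imag] for complex element types.
struct ScalarAttrModel { double value; };
struct ArrayAttrModel { std::vector<double> elements; };
using PaddingAttr = std::variant<ScalarAttrModel, ArrayAttrModel>;

// Hypothetical classification of the operand's element type.
enum class ElementKind { Float, Complex };

// Hypothetical record of which constant op would be created, and its value.
struct PaddingConstant {
  std::string opName;           // "arith.constant" or "complex.constant"
  std::complex<double> value;   // imaginary part is 0 for scalar padding
};

// Mirrors the if/else in the diff: complex element types use the array
// attribute, all other element types use the scalar attribute.
PaddingConstant makePaddingValue(ElementKind kind, const PaddingAttr &attr) {
  if (kind == ElementKind::Complex) {
    const auto *arr = std::get_if<ArrayAttrModel>(&attr);
    if (!arr || arr->elements.size() != 2)
      throw std::invalid_argument("complex padding needs [real, imag]");
    return {"complex.constant", {arr->elements[0], arr->elements[1]}};
  }
  const auto *scalar = std::get_if<ScalarAttrModel>(&attr);
  if (!scalar)
    throw std::invalid_argument("non-complex padding needs a scalar attribute");
  return {"arith.constant", {scalar->value, 0.0}};
}
```

The choice is keyed on the operand's element type rather than on the shape of the attribute, so a mismatched padding attribute is treated as an error instead of silently selecting the other branch.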
