[Mlir-commits] [mlir] a5cee3e - [mlir][linalg] Add a padding case for `ComplexType`
Rob Suderman
llvmlistbot at llvm.org
Mon Jul 17 17:22:23 PDT 2023
Author: Robert Suderman
Date: 2023-07-17T17:20:38-07:00
New Revision: a5cee3e386bde28ce21ff2ead3fc420f018604ca
URL: https://github.com/llvm/llvm-project/commit/a5cee3e386bde28ce21ff2ead3fc420f018604ca
DIFF: https://github.com/llvm/llvm-project/commit/a5cee3e386bde28ce21ff2ead3fc420f018604ca.diff
LOG: [mlir][linalg] Add a padding case for `ComplexType`
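If the element type of the operand being padded is a `ComplexType`,
the paddingAttr is expected to be an ArrayAttr with two values (the
real and imaginary parts), and the padding value is materialized with
`complex.constant` instead of `arith.constant`.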
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D154908
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index fe720aa24cdd77..f87fbbe412cfb1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
@@ -125,8 +126,17 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
}
Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
- Value paddingValue = rewriter.create<arith::ConstantOp>(
- opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+
+ Value paddingValue;
+ if (auto complexTy = dyn_cast<ComplexType>(
+ getElementTypeOrSelf(opOperand->get().getType()))) {
+ auto complexAttr = cast<ArrayAttr>(paddingAttr);
+ paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
+ complexTy, complexAttr);
+ } else {
+ paddingValue = rewriter.create<arith::ConstantOp>(
+ opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+ }
// Pad the operand to the bounding box defined by `paddedShape`.
auto paddedTensorType = RankedTensorType::get(
More information about the Mlir-commits
mailing list