[Mlir-commits] [mlir] 1e84e91 - [mlir][Linalg] NFC - Improve some transform op builders
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jul 11 06:35:55 PDT 2023
Author: Nicolas Vasilache
Date: 2023-07-11T15:35:43+02:00
New Revision: 1e84e91efa533b7dee7c77105dc9a3f9ae740da9
URL: https://github.com/llvm/llvm-project/commit/1e84e91efa533b7dee7c77105dc9a3f9ae740da9
DIFF: https://github.com/llvm/llvm-project/commit/1e84e91efa533b7dee7c77105dc9a3f9ae740da9.diff
LOG: [mlir][Linalg] NFC - Improve some transform op builders
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 82a1fcc5c8f4a4..2979a8018cdf3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -128,6 +128,11 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
OptionalAttr<AnyAttr>:$memory_space);
let results = (outs Transform_AnyValue:$allocated_buffer);
let assemblyFormat = "$target attr-dict `:` type($target)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
+ OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
+ ];
}
//===----------------------------------------------------------------------===//
@@ -929,6 +934,20 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
"$target attr-dict `:` functional-type(operands, results)";
let hasVerifier = 1;
+ let builders = [
+ // Builder for a transform::PadOp with automatic inference of padding
+ // value. Warning: this will set the value 0 for the inferred element
+ // type without taking the op into account and thus will only work for
+ // the add/mul ring at the moment.
+ // TODO: support other operations (e.g. min, max, etc.).
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$paddingDimensions,
+ CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
+ CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
+ CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
+ CArg<"bool", "false">:$copyBack)>
+ ];
+
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
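A minimal usage sketch (not part of this diff) for the BufferizeToAllocationOp
builders declared above; b, loc and target are assumed to be an OpBuilder, a
Location and a transform handle Value:

  // Memory space given directly as an Attribute.
  auto alloc0 = b.create<transform::BufferizeToAllocationOp>(
      loc, target, /*memorySpace=*/b.getI64IntegerAttr(3));
  // Memory space given as a plain integer; the builder wraps it into an
  // I64IntegerAttr (see LinalgTransformOps.cpp below).
  auto alloc1 = b.create<transform::BufferizeToAllocationOp>(
      loc, target, /*memorySpace=*/static_cast<int64_t>(3));

Both overloads fill in the single transform.any_value result type for the
allocated buffer, so no result types need to be passed explicitly.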
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e86bb7c545109c..7ca6b272103b7a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -173,6 +173,26 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
+void transform::BufferizeToAllocationOp::build(OpBuilder &b,
+ OperationState &result,
+ Value target,
+ Attribute memorySpace) {
+ return build(b, result,
+ /*resultTypes=*/b.getType<transform::AnyValueType>(),
+ /*target=*/target,
+ /*memorySpace=*/memorySpace);
+}
+
+void transform::BufferizeToAllocationOp::build(OpBuilder &b,
+ OperationState &result,
+ Value target,
+ int64_t memorySpace) {
+ return build(b, result,
+ /*resultTypes=*/b.getType<transform::AnyValueType>(),
+ /*target=*/target,
+ /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
+}
+
DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
@@ -1448,6 +1468,27 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
// PadOp
//===---------------------------------------------------------------------===//
+void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
+ ArrayRef<int64_t> paddingDimensions,
+ ArrayRef<int64_t> padToMultipleOf,
+ ArrayRef<int64_t> packPaddings,
+ ArrayRef<Attribute> transposePaddings,
+ bool copyBack) {
+ auto resultType = transform::AnyOpType::get(b.getContext());
+ return build(/*builder=*/b,
+ /*result=*/result,
+ /*types=*/TypeRange{resultType, resultType},
+ /*target=*/target,
+ /*paddingValues=*/ArrayAttr(), // let inference handle this
+ /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+ /*padToMultipleOf=*/
+ (padToMultipleOf.empty() ? ArrayAttr()
+ : b.getI64ArrayAttr(padToMultipleOf)),
+ /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
+ /*transposePaddings=*/b.getArrayAttr(transposePaddings),
+ /*copyBack=*/b.getBoolAttr(copyBack));
+}
+
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
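A minimal usage sketch (not part of this diff) for the new PadOp builder;
b, loc and target are assumed as in the previous sketch. Only the padding
dimensions are required: the padding values are left unset so that they are
inferred as zeros of the element types (see Padding.cpp below), and the
remaining arguments fall back to their CArg defaults:

  SmallVector<int64_t> paddingDims = {0, 1, 2};
  auto padOp = b.create<transform::PadOp>(loc, target, paddingDims);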
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 98e9c81e4a3974..fe720aa24cdd77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -139,12 +139,24 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
- const LinalgPaddingOptions &options,
+ const LinalgPaddingOptions &constOptions,
LinalgOp &paddedOp, SmallVector<Value> &replacements,
SmallVector<tensor::PadOp> &padOps, bool copyBack) {
LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
Location loc = opToPad->getLoc();
+ LinalgPaddingOptions options(constOptions);
+ // Allow inference of pad values if they are not explicitly specified.
+ // TODO: be mindful about the value depending on the actual operation.
+ if (options.paddingValues.empty()) {
+ SmallVector<Type> types(opToPad->getOperandTypes());
+ llvm::append_range(types, opToPad->getResultTypes());
+ for (Type t : types) {
+ options.paddingValues.push_back(
+ rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ }
+ }
+
// TODO: there are cases where we may still want to pad to larger sizes.
if (!opToPad.hasTensorSemantics())
return rewriter.notifyMatchFailure(opToPad,
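For illustration only (not part of this patch): the inference above produces
the zero of each operand/result element type, e.g. with rewriter being the
RewriterBase passed to rewriteAsPaddedOp:

  Attribute zf = rewriter.getZeroAttr(rewriter.getF32Type()); // 0.0 : f32
  Attribute zi = rewriter.getZeroAttr(rewriter.getI32Type()); // 0 : i32

Padding with 0 is harmless for ops built on the add/mul ring (padded entries
contribute 0 to each accumulation) but not for min/max-style reductions,
hence the TODO above.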