[Mlir-commits] [mlir] Add support for transform.param values in `PadOp`s pad_to_multiple_of (PR #90755)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 1 10:55:34 PDT 2024
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/90755
None
>From 6aa9f0db3a39a877a87245206c7562ed27f77657 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 12:20:02 -0500
Subject: [PATCH] Add support for transform.param values in `PadOp`s
pad_to_multiple_of
---
.../Linalg/TransformOps/LinalgTransformOps.td | 17 ++-
.../TransformOps/LinalgTransformOps.cpp | 112 ++++++++++++++++--
.../mlir/dialects/transform/structured.py | 20 +++-
.../test/Dialect/Linalg/transform-op-pad.mlir | 3 +-
.../dialects/transform_structured_ext.py | 5 +-
5 files changed, 140 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d0ad4ccdf031d9..16fb8f4fcc9466 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1011,7 +1011,9 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
- OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
+ // Dynamic multiples (transform params or handles). They are combined with
+ // static_pad_to_multiple_of by getMixedPadToMultipleOf().
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$pad_to_multiple_of,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+ $static_pad_to_multiple_of,
DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
DefaultValuedAttr<
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
@@ -1021,8 +1023,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
TransformHandleTypeInterface:$pad,
TransformHandleTypeInterface:$copy);
- let assemblyFormat =
- "$target attr-dict `:` functional-type(operands, results)";
+ let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let builders = [
@@ -1033,7 +1034,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
// TODO: support other operations (e.g. min, max etc).
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$paddingDimensions,
- CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
+ CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
+ CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
+ CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
+ CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+ // Same as above, but the multiples may mix static integers and dynamic
+ // values (as OpFoldResults).
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$paddingDimensions,
+ "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
@@ -1043,6 +1050,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
/// copy_back_op attribute value indicating that no copy back is desired.
static constexpr StringRef kCopyOpNone = "none";
+ /// Returns pad_to_multiple_of as OpFoldResults, built from the
+ /// static_pad_to_multiple_of attribute and the dynamic operands.
+ SmallVector<OpFoldResult> getMixedPadToMultipleOf();
+
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 156784f0e67402..dc060f4c0641cb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1664,6 +1664,10 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
 // PadOp
 //===---------------------------------------------------------------------===//
 
+/// Keyword that introduces the inline pad_to_multiple_of list in the custom
+/// assembly format of transform.structured.pad.
+static constexpr StringLiteral kPadToMultipleOfKeyword = "pad_to_multiple_of";
+
 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> paddingDimensions,
                              ArrayRef<int64_t> padToMultipleOf,
@@ -1677,14 +1681,128 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                /*target=*/target,
                /*paddingValues=*/ArrayAttr(), // let inference handle this
                /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
-               /*padToMultipleOf=*/
-               (padToMultipleOf.empty() ? ArrayAttr()
-                                        : b.getI64ArrayAttr(padToMultipleOf)),
+               /*padToMultipleOf=*/ValueRange{},
+               /*staticPadToMultipleOf=*/
+               (padToMultipleOf.empty()
+                    ? DenseI64ArrayAttr()
+                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
+               /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
+               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
+               /*copyBackOp=*/b.getStringAttr(copyBackOp));
+}
+
+/// Builds a PadOp whose pad_to_multiple_of entries may mix static integers
+/// and dynamic values (transform params or handles).
+void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
+                             ArrayRef<int64_t> paddingDimensions,
+                             ArrayRef<OpFoldResult> mixedPadToMultipleOf,
+                             ArrayRef<int64_t> packPaddings,
+                             ArrayRef<Attribute> transposePaddings,
+                             StringRef copyBackOp) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  // Split the mixed list into the static attribute and the dynamic operands.
+  SmallVector<int64_t> staticPadToMultipleOf;
+  SmallVector<Value> dynamicPadToMultipleOf;
+  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
+                             staticPadToMultipleOf);
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*padToMultipleOf=*/dynamicPadToMultipleOf,
+               /*staticPadToMultipleOf=*/staticPadToMultipleOf,
                /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
                /*transposePaddings=*/b.getArrayAttr(transposePaddings),
                /*copyBackOp=*/b.getStringAttr(copyBackOp));
 }
 
+/// Returns pad_to_multiple_of as OpFoldResults, built by getMixedValues from
+/// the static_pad_to_multiple_of attribute and the dynamic operands.
+SmallVector<OpFoldResult> transform::PadOp::getMixedPadToMultipleOf() {
+  OpBuilder b(getContext());
+  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
+}
+
+ParseResult transform::PadOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::UnresolvedOperand target;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicPadToMultipleOf;
+  DenseI64ArrayAttr padToMultipleOf;
+  FunctionType functionalType;
+  llvm::SMLoc operandLoc;
+  llvm::SMLoc typeLoc;
+
+  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
+    return ParseResult::failure();
+
+  if (succeeded(parser.parseOptionalKeyword(kPadToMultipleOfKeyword))) {
+    if (failed(parseDynamicIndexList(parser, dynamicPadToMultipleOf,
+                                     padToMultipleOf)))
+      return ParseResult::failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.getCurrentLocation(&typeLoc) ||
+      parser.parseColonType(functionalType))
+    return ParseResult::failure();
+
+  // The first input type is the target's; the remaining ones belong to the
+  // dynamic pad_to_multiple_of operands, so at least one must be present.
+  if (functionalType.getNumInputs() == 0)
+    return parser.emitError(typeLoc)
+           << "expected at least one input type (for the target)";
+
+  if (parser.resolveOperand(target, functionalType.getInputs().front(),
+                            result.operands) ||
+      parser.resolveOperands(dynamicPadToMultipleOf,
+                             functionalType.getInputs().drop_front(),
+                             operandLoc, result.operands))
+    return ParseResult::failure();
+
+  if (padToMultipleOf)
+    result.addAttribute(getStaticPadToMultipleOfAttrName(result.name),
+                        padToMultipleOf);
+
+  result.addTypes(functionalType.getResults());
+
+  return success();
+}
+
+void transform::PadOp::print(OpAsmPrinter &p) {
+  p << ' ' << getTarget();
+  if (!getStaticPadToMultipleOf().empty()) {
+    p << ' ' << kPadToMultipleOfKeyword << ' ';
+    printDynamicIndexList(p, getOperation(), getPadToMultipleOf(),
+                          getStaticPadToMultipleOfAttr(),
+                          /*valueTypes=*/{},
+                          /*scalables=*/{}, OpAsmParser::Delimiter::Square);
+  }
+
+  OpBuilder builder((*this)->getContext());
+  // static_pad_to_multiple_of is either printed inline above or empty, so it
+  // is always elided; the others are elided when they hold their defaults.
+  SmallVector<StringRef, 6> elidedAttrs({getStaticPadToMultipleOfAttrName()});
+  if (getCopyBackOpAttr() ==
+      builder.getStringAttr(
+          bufferization::MaterializeInDestinationOp::getOperationName()))
+    elidedAttrs.push_back(getCopyBackOpAttrName());
+  if (getPackPaddingsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getPackPaddingsAttrName());
+  if (getTransposePaddingsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getTransposePaddingsAttrName());
+  if (getPaddingDimensionsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getPaddingDimensionsAttrName());
+  if (getPaddingValuesAttr() == builder.getArrayAttr({}))
+    elidedAttrs.push_back(getPaddingValuesAttrName());
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/elidedAttrs);
+  p << " : ";
+  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
+}
+
 DiagnosedSilenceableFailure
 transform::PadOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &results,
@@ -1750,9 +1849,15 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     options.paddingDimensions =
         extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
     SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
-    if (getPadToMultipleOf().has_value())
-      padToMultipleOf =
-          extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
+    // TODO: resolve dynamic pad_to_multiple_of entries (transform params)
+    // through the transform state. Until then, fail loudly rather than
+    // consume whatever placeholder the static attribute holds at the dynamic
+    // positions.
+    if (!getPadToMultipleOf().empty())
+      return emitDefiniteFailure()
+             << "dynamic pad_to_multiple_of values are not supported yet";
+    if (!getStaticPadToMultipleOf().empty())
+      padToMultipleOf = llvm::to_vector(getStaticPadToMultipleOf());
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
@@ -1819,8 +1917,17 @@ LogicalResult transform::PadOp::verify() {
                             "integers, found "
                          << getPaddingDimensions();
   }
-  if (getPadToMultipleOf().has_value()) {
-    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
+  // Check the raw static/dynamic split before anything combines the two
+  // parts (e.g. getMixedPadToMultipleOf()).
+  ArrayRef<int64_t> staticPadToMultipleOf = getStaticPadToMultipleOf();
+  if (static_cast<size_t>(
+          llvm::count_if(staticPadToMultipleOf, ShapedType::isDynamic)) !=
+      getPadToMultipleOf().size()) {
+    return emitOpError() << "expects as many pad_to_multiple_of operands as "
+                            "dynamic entries in static_pad_to_multiple_of";
+  }
+  if (!staticPadToMultipleOf.empty()) {
+    if (staticPadToMultipleOf.size() != paddingDimensions.size()) {
       return emitOpError() << "expects as many multiples as padding_dimensions";
     }
   }
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index d7b41c0bd2207d..81bbd6ffb3d403 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -373,10 +373,11 @@ class PadOp(PadOp):
     def __init__(
         self,
         target: Union[Operation, OpView, Value],
+        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         *,
         padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
         padding_dimensions: OptionalIntList = None,
-        pad_to_multiple_of: OptionalIntList = None,
+        static_pad_to_multiple_of: OptionalIntList = None,
         pack_paddings: OptionalIntList = None,
         transpose_paddings: Optional[
             Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
@@ -385,6 +386,25 @@ def __init__(
         loc=None,
         ip=None,
     ):
+        # Normalize pad_to_multiple_of into dynamic operands and static values.
+        if (
+            static_pad_to_multiple_of is None
+            and pad_to_multiple_of is None
+        ):
+            dynamic_pad_to_multiple_of = []
+        elif static_pad_to_multiple_of is None:
+            (
+                dynamic_pad_to_multiple_of,
+                static_pad_to_multiple_of,
+                _,
+            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)
+        else:
+            # Static multiples were given explicitly, so the positional argument
+            # only carries dynamic operands and may be omitted.
+            dynamic_pad_to_multiple_of = (
+                [] if pad_to_multiple_of is None else pad_to_multiple_of
+            )
+
         transpose_paddings = _get_int_array_array_attr(transpose_paddings)
 
         any_op_type = transform.AnyOpType.get()
@@ -393,9 +413,10 @@ def __init__(
             any_op_type,
             any_op_type,
             target,
+            pad_to_multiple_of=dynamic_pad_to_multiple_of,
             padding_values=padding_values,
             padding_dimensions=padding_dimensions,
-            pad_to_multiple_of=pad_to_multiple_of,
+            static_pad_to_multiple_of=static_pad_to_multiple_of,
             pack_paddings=pack_paddings,
             transpose_paddings=transpose_paddings,
             copy_back_op=copy_back_op,
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index d27276cda49dc4..f82d4500090c5a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -73,10 +73,9 @@ func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %padded, %pad, %copy_back = transform.structured.pad %0 {
+ // pad_to_multiple_of is written inline, between the target and the attribute dictionary.
+ %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [2, 2, 1] {
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2],
- pad_to_multiple_of=[2, 2, 1],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 91ecd0fc38e174..418b1216df0532 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -315,9 +315,10 @@ def testPadOpNoArgs(target):
def testPadOpArgs(target):
structured.PadOp(
target,
+ # Static multiples are passed via static_pad_to_multiple_of below, so the
+ # positional pad_to_multiple_of list only holds dynamic values (none here).
+ [],
padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
padding_dimensions=Attribute.parse("[1]"),
- pad_to_multiple_of=[128],
+ static_pad_to_multiple_of=[128],
pack_paddings=[0],
transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
copy_back_op="linalg.copy",
@@ -325,9 +326,9 @@ def testPadOpArgs(target):
# CHECK-LABEL: TEST: testPadOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
+ # CHECK-DAG: pad_to_multiple_of [128]
# CHECK-DAG: copy_back_op = "linalg.copy"
# CHECK-DAG: pack_paddings = [0]
- # CHECK-DAG: pad_to_multiple_of = [128]
# CHECK-DAG: padding_dimensions = [1]
# CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
More information about the Mlir-commits
mailing list