[Mlir-commits] [mlir] [mlir][transform] Add support for transform.param pad multiples in `PadOp` (PR #90755)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 4 14:50:40 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/90755
>From 6aa9f0db3a39a877a87245206c7562ed27f77657 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 12:20:02 -0500
Subject: [PATCH 01/12] Add support for transform.param values in `PadOp`s
pad_to_multiple_of
---
.../Linalg/TransformOps/LinalgTransformOps.td | 17 ++-
.../TransformOps/LinalgTransformOps.cpp | 112 ++++++++++++++++--
.../mlir/dialects/transform/structured.py | 20 +++-
.../test/Dialect/Linalg/transform-op-pad.mlir | 3 +-
.../dialects/transform_structured_ext.py | 5 +-
5 files changed, 140 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d0ad4ccdf031d9..16fb8f4fcc9466 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1011,7 +1011,9 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
- OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$pad_to_multiple_of,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+ $static_pad_to_multiple_of,
DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
DefaultValuedAttr<
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
@@ -1021,8 +1023,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
TransformHandleTypeInterface:$pad,
TransformHandleTypeInterface:$copy);
- let assemblyFormat =
- "$target attr-dict `:` functional-type(operands, results)";
+ let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let builders = [
@@ -1033,7 +1034,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
// TODO: support other operations (e.g. min, max etc).
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$paddingDimensions,
- CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
+ CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
+ CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
+ CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
+ CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$paddingDimensions,
+ "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
@@ -1043,6 +1050,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
/// copy_back_op attribute value indicating that no copy back is desired.
static constexpr StringRef kCopyOpNone = "none";
+ SmallVector<OpFoldResult> getMixedPadToMultipleOf();
+
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 156784f0e67402..dc060f4c0641cb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1664,6 +1664,8 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
// PadOp
//===---------------------------------------------------------------------===//
+static const StringLiteral kPadToMultipleOfKeyword = "pad_to_multiple_of";
+
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> paddingDimensions,
ArrayRef<int64_t> padToMultipleOf,
@@ -1677,14 +1679,111 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
/*target=*/target,
/*paddingValues=*/ArrayAttr(), // let inference handle this
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+ /*padToMultipleOf=*/ValueRange{},
/*padToMultipleOf=*/
- (padToMultipleOf.empty() ? ArrayAttr()
- : b.getI64ArrayAttr(padToMultipleOf)),
+ (padToMultipleOf.empty()
+ ? DenseI64ArrayAttr()
+ : b.getDenseI64ArrayAttr(padToMultipleOf)),
+ /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
+ /*transposePaddings=*/b.getArrayAttr(transposePaddings),
+ /*copyBackOp=*/b.getStringAttr(copyBackOp));
+}
+
+void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
+ ArrayRef<int64_t> paddingDimensions,
+ ArrayRef<OpFoldResult> mixedPadToMultipleOf,
+ ArrayRef<int64_t> packPaddings,
+ ArrayRef<Attribute> transposePaddings,
+ StringRef copyBackOp) {
+ auto resultType = transform::AnyOpType::get(b.getContext());
+ SmallVector<int64_t> staticPadToMultipleOf;
+ SmallVector<Value> dynamicPadToMultipleOf;
+ dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
+ staticPadToMultipleOf);
+ return build(/*builder=*/b,
+ /*result=*/result,
+ /*types=*/TypeRange{resultType, resultType},
+ /*target=*/target,
+ /*paddingValues=*/ArrayAttr(), // let inference handle this
+ /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+ /*padToMultipleOf=*/dynamicPadToMultipleOf,
+ /*padToMultipleOf=*/staticPadToMultipleOf,
/*packPaddings=*/b.getI64ArrayAttr(packPaddings),
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
/*copyBackOp=*/b.getStringAttr(copyBackOp));
}
+SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
+ OpBuilder b(getContext());
+ return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
+}
+
+ParseResult transform::PadOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand target;
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicPadToMultipleOf;
+ DenseI64ArrayAttr padToMultipleOf;
+ FunctionType functionalType;
+ llvm::SMLoc operandLoc;
+
+ if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
+ return ParseResult::failure();
+
+ if (succeeded(parser.parseOptionalKeyword(kPadToMultipleOfKeyword))) {
+ if (failed(parseDynamicIndexList(parser, dynamicPadToMultipleOf,
+ padToMultipleOf)))
+ return ParseResult::failure();
+ }
+
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(functionalType) ||
+ parser.resolveOperand(target, functionalType.getInputs().front(),
+ result.operands) ||
+ parser.resolveOperands(dynamicPadToMultipleOf,
+ functionalType.getInputs().drop_front(),
+ operandLoc, result.operands))
+ return ParseResult::failure();
+
+ if (padToMultipleOf)
+ result.addAttribute(getStaticPadToMultipleOfAttrName(result.name),
+ padToMultipleOf);
+
+ result.addTypes(functionalType.getResults());
+
+ return success();
+}
+
+void transform::PadOp::print(OpAsmPrinter &p) {
+ p << ' ' << getTarget() << ' ';
+ if (!getMixedPadToMultipleOf().empty()) {
+ p << kPadToMultipleOfKeyword << ' ';
+ printDynamicIndexList(p, getOperation(), getPadToMultipleOf(),
+ getStaticPadToMultipleOfAttr(),
+ /*valueTypes=*/{},
+ /*scalables=*/{}, OpAsmParser::Delimiter::Square);
+ }
+
+ OpBuilder builder((*this)->getContext());
+ SmallVector<StringRef, 6> elidedAttrs({getStaticPadToMultipleOfAttrName()});
+ if (getCopyBackOpAttr() ==
+ builder.getStringAttr(
+ bufferization::MaterializeInDestinationOp::getOperationName()))
+ elidedAttrs.push_back(getCopyBackOpAttrName());
+ if (getPackPaddingsAttr() == builder.getI64ArrayAttr({}))
+ elidedAttrs.push_back(getPackPaddingsAttrName());
+ if (getTransposePaddingsAttr() == builder.getI64ArrayAttr({}))
+ elidedAttrs.push_back(getTransposePaddingsAttrName());
+ if (getPaddingDimensionsAttr() == builder.getI64ArrayAttr({}))
+ elidedAttrs.push_back(getPaddingDimensionsAttrName());
+ if (getPaddingValuesAttr() == builder.getArrayAttr({}))
+ elidedAttrs.push_back(getPaddingValuesAttrName());
+
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/elidedAttrs);
+ p << " : ";
+ p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
+}
+
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -1750,9 +1849,8 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
options.paddingDimensions =
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
- if (getPadToMultipleOf().has_value())
- padToMultipleOf =
- extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
+ if (!getStaticPadToMultipleOf().empty())
+ padToMultipleOf = llvm::to_vector(getStaticPadToMultipleOf());
options.padToMultipleOf = padToMultipleOf;
options.paddingValues = paddingValues;
options.packPaddings = packPaddings;
@@ -1819,8 +1917,8 @@ LogicalResult transform::PadOp::verify() {
"integers, found "
<< getPaddingDimensions();
}
- if (getPadToMultipleOf().has_value()) {
- if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
+ if (!getMixedPadToMultipleOf().empty()) {
+ if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
return emitOpError() << "expects as many multiples as padding_dimensions";
}
}
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index d7b41c0bd2207d..81bbd6ffb3d403 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -373,10 +373,11 @@ class PadOp(PadOp):
def __init__(
self,
target: Union[Operation, OpView, Value],
+ pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
*,
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
- pad_to_multiple_of: OptionalIntList = None,
+ static_pad_to_multiple_of: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
@@ -385,6 +386,20 @@ def __init__(
loc=None,
ip=None,
):
+ if (
+ static_pad_to_multiple_of is None
+ and pad_to_multiple_of is None
+ ):
+ dynamic_pad_to_multiple_of = []
+ elif static_pad_to_multiple_of is None:
+ (
+ dynamic_pad_to_multiple_of,
+ static_pad_to_multiple_of,
+ _,
+ ) = _dispatch_dynamic_index_list(pad_to_multiple_of)
+ else:
+ dynamic_pad_to_multiple_of = pad_to_multiple_of
+
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
any_op_type = transform.AnyOpType.get()
@@ -393,9 +408,10 @@ def __init__(
any_op_type,
any_op_type,
target,
+ pad_to_multiple_of=dynamic_pad_to_multiple_of,
padding_values=padding_values,
padding_dimensions=padding_dimensions,
- pad_to_multiple_of=pad_to_multiple_of,
+ static_pad_to_multiple_of=static_pad_to_multiple_of,
pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings,
copy_back_op=copy_back_op,
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index d27276cda49dc4..f82d4500090c5a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -73,10 +73,9 @@ func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %padded, %pad, %copy_back = transform.structured.pad %0 {
+ %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [2, 2, 1] {
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2],
- pad_to_multiple_of=[2, 2, 1],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 91ecd0fc38e174..418b1216df0532 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -315,9 +315,10 @@ def testPadOpNoArgs(target):
def testPadOpArgs(target):
structured.PadOp(
target,
+ [],
padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
padding_dimensions=Attribute.parse("[1]"),
- pad_to_multiple_of=[128],
+ static_pad_to_multiple_of=[128],
pack_paddings=[0],
transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
copy_back_op="linalg.copy",
@@ -325,9 +326,9 @@ def testPadOpArgs(target):
# CHECK-LABEL: TEST: testPadOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
+ # CHECK-DAG: pad_to_multiple_of [128]
# CHECK-DAG: copy_back_op = "linalg.copy"
# CHECK-DAG: pack_paddings = [0]
- # CHECK-DAG: pad_to_multiple_of = [128]
# CHECK-DAG: padding_dimensions = [1]
# CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
>From a1a7b170cff06461421b98d08b0942b60694351f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 13:57:34 -0500
Subject: [PATCH 02/12] fix for param pad_to_multiple_of
---
.../Linalg/TransformOps/LinalgTransformOps.td | 13 ++--
.../TransformOps/LinalgTransformOps.cpp | 62 ++++++++++++++++++-
.../test/Dialect/Linalg/transform-op-pad.mlir | 36 +++++++++++
3 files changed, 101 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 16fb8f4fcc9466..ada7f7666d5f60 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -978,8 +978,8 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
//===----------------------------------------------------------------------===//
def PadOp : Op<Transform_Dialect, "structured.pad",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>,
+ [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Pads the operations pointed to by the target handle using the options
@@ -1052,11 +1052,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
SmallVector<OpFoldResult> getMixedPadToMultipleOf();
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &results,
+ ::mlir::transform::TransformState &state);
}];
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dc060f4c0641cb..c68bc4c1f025ac 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1713,6 +1713,16 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
/*copyBackOp=*/b.getStringAttr(copyBackOp));
}
+void PadOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ onlyReadsHandle(getPadToMultipleOf(), effects);
+ producesHandle(getPadded(), effects);
+ producesHandle(getPad(), effects);
+ producesHandle(getCopy(), effects);
+ modifiesPayload(effects);
+}
+
SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
OpBuilder b(getContext());
return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
@@ -1848,9 +1858,55 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
LinalgPaddingOptions options;
options.paddingDimensions =
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
- SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
- if (!getStaticPadToMultipleOf().empty())
- padToMultipleOf = llvm::to_vector(getStaticPadToMultipleOf());
+
+ SmallVector<int64_t> padToMultipleOf;
+ for (OpFoldResult sz : getMixedPadToMultipleOf()) {
+ if (sz.is<Attribute>()) {
+ auto attr = sz.get<Attribute>();
+ padToMultipleOf.push_back(cast<IntegerAttr>(attr).getInt());
+ continue;
+ } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
+ ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
+ if (params.size() != 1)
+ return emitSilenceableFailure(getLoc()) << "expected a single param";
+ padToMultipleOf.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ auto szPayloads = state.getPayloadOps(sz.get<Value>());
+ if (!llvm::hasSingleElement(szPayloads)) {
+ auto diag = this->emitOpError("requires pad_to_multiple_of handle that "
+ "is mapped to 1 payload op");
+ diag.attachNote(sz.get<Value>().getLoc())
+ << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ Operation *szPayloadOp = *szPayloads.begin();
+ if (szPayloadOp->getNumResults() != 1 ||
+ !szPayloadOp->getResult(0).getType().isIndex()) {
+ auto diag = this->emitOpError(
+ "requires vector pad_to_multiple_of op with 1 index result");
+ diag.attachNote(szPayloadOp->getLoc())
+ << "pad_to_multiple_of payload op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ IntegerAttr attr;
+ if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
+ auto diag = this->emitOpError("requires constant pad_to_multiple_of");
+ diag.attachNote(szPayloadOp->getLoc())
+ << "pad_to_multiple_of payload op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ padToMultipleOf.push_back(attr.getInt());
+ }
+ if (padToMultipleOf.empty())
+ padToMultipleOf =
+ SmallVector<int64_t>(options.paddingDimensions.size(), 1);
+
options.padToMultipleOf = padToMultipleOf;
options.paddingValues = paddingValues;
options.packPaddings = packPaddings;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index f82d4500090c5a..47bb5ddf4afc3e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -86,6 +86,42 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+// CHECK-LABEL: @parametrized_pad_to_multiple
+func.func @parametrized_pad_to_multiple(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>,
+ %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+ %0 = affine.min #map()[%iv2]
+ %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+ %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+ %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+ // CHECK: linalg.matmul
+ // CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x7xf32>, tensor<7x6xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<4x6xf32>)
+ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+ func.return %5 : tensor<24x25xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+ %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [%c2, 2, 1] {
+ padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+ padding_dimensions=[0, 1, 2],
+ pack_paddings=[1, 1, 0]
+ } : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,
>From 5159eeb464c0e6858056cb877194d2293af8c81c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 13:58:54 -0500
Subject: [PATCH 03/12] format
---
mlir/python/mlir/dialects/transform/structured.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 81bbd6ffb3d403..4f4a0e598df7d3 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -386,10 +386,7 @@ def __init__(
loc=None,
ip=None,
):
- if (
- static_pad_to_multiple_of is None
- and pad_to_multiple_of is None
- ):
+ if static_pad_to_multiple_of is None and pad_to_multiple_of is None:
dynamic_pad_to_multiple_of = []
elif static_pad_to_multiple_of is None:
(
>From 23ef22ab40f4494aa93d6034f4c38714a9ec4a05 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 14:11:43 -0500
Subject: [PATCH 04/12] cleanup diagnostic messages
---
.../Linalg/TransformOps/LinalgTransformOps.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c68bc4c1f025ac..c3963aee828565 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1860,6 +1860,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
SmallVector<int64_t> padToMultipleOf;
+ // TODO: This should probably be a common utility function.
for (OpFoldResult sz : getMixedPadToMultipleOf()) {
if (sz.is<Attribute>()) {
auto attr = sz.get<Attribute>();
@@ -1876,8 +1877,9 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
auto szPayloads = state.getPayloadOps(sz.get<Value>());
if (!llvm::hasSingleElement(szPayloads)) {
- auto diag = this->emitOpError("requires pad_to_multiple_of handle that "
- "is mapped to 1 payload op");
+ auto diag = this->emitOpError()
+ << "requires " << kPadToMultipleOfKeyword
+ << " handle that is mapped to 1 payload op";
diag.attachNote(sz.get<Value>().getLoc())
<< "mapped to " << llvm::range_size(szPayloads) << " payload ops";
return DiagnosedSilenceableFailure::definiteFailure();
@@ -1886,18 +1888,20 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
Operation *szPayloadOp = *szPayloads.begin();
if (szPayloadOp->getNumResults() != 1 ||
!szPayloadOp->getResult(0).getType().isIndex()) {
- auto diag = this->emitOpError(
- "requires vector pad_to_multiple_of op with 1 index result");
+ auto diag = this->emitOpError()
+ << "requires " << kPadToMultipleOfKeyword
+ << " to be result of op with 1 index result";
diag.attachNote(szPayloadOp->getLoc())
- << "pad_to_multiple_of payload op";
+ << kPadToMultipleOfKeyword << " payload op";
return DiagnosedSilenceableFailure::definiteFailure();
}
IntegerAttr attr;
if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
- auto diag = this->emitOpError("requires constant pad_to_multiple_of");
+ auto diag = this->emitOpError()
+ << "requires constant " << kPadToMultipleOfKeyword;
diag.attachNote(szPayloadOp->getLoc())
- << "pad_to_multiple_of payload op";
+ << kPadToMultipleOfKeyword << " payload op";
return DiagnosedSilenceableFailure::definiteFailure();
}
>From 2dade05508512e565feccac381b185e749b3f735 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 19:36:40 -0500
Subject: [PATCH 05/12] refactor paramhandle reification
---
.../TransformOps/LinalgTransformOps.cpp | 143 +++++++-----------
1 file changed, 54 insertions(+), 89 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c3963aee828565..01d4d2a033830c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -171,6 +171,50 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
return DiagnosedSilenceableFailure::success();
}
+static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
+ TransformState &state, TransformOpInterface &transformOp,
+ const SmallVectorImpl<OpFoldResult> &mixedResults,
+ SmallVectorImpl<int64_t> &reified) {
+ for (OpFoldResult paramOrHandle : mixedResults) {
+ if (isa<Attribute>(paramOrHandle)) {
+ reified.push_back(
+ cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
+ continue;
+ } else if (isa<Value>(paramOrHandle) &&
+ isa<ParamType>(paramOrHandle.get<Value>().getType())) {
+ ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure() << "expected a single param";
+ reified.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ auto paramOrHandlePayloads =
+ state.getPayloadOps(paramOrHandle.get<Value>());
+ if (!llvm::hasSingleElement(paramOrHandlePayloads))
+ return transformOp.emitDefiniteFailure()
+ << "requires param or handle that is mapped to 1 payload op";
+
+ Operation *paramOrHandlePayloadOp = *paramOrHandlePayloads.begin();
+ if (paramOrHandlePayloadOp->getNumResults() != 1 ||
+ !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
+ return transformOp.emitDefiniteFailure()
+ << "requires param or handle to be result of op with 1 index "
+ "result";
+ }
+
+ IntegerAttr attr;
+ if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
+ return transformOp.emitDefiniteFailure()
+ << "requires param or handle to be the result of a constant like "
+ "op";
+
+ reified.push_back(attr.getInt());
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
@@ -1798,6 +1842,7 @@ DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
+ auto transformOp = cast<TransformOpInterface>(getOperation());
SmallVector<Operation *> paddedOps, padOps, copyBackOps;
for (Operation *target : state.getPayloadOps(getTarget())) {
@@ -1860,53 +1905,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
SmallVector<int64_t> padToMultipleOf;
- // TODO: This should probably be a common utility function.
- for (OpFoldResult sz : getMixedPadToMultipleOf()) {
- if (sz.is<Attribute>()) {
- auto attr = sz.get<Attribute>();
- padToMultipleOf.push_back(cast<IntegerAttr>(attr).getInt());
- continue;
- } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
- ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
- if (params.size() != 1)
- return emitSilenceableFailure(getLoc()) << "expected a single param";
- padToMultipleOf.push_back(
- cast<IntegerAttr>(params.front()).getValue().getSExtValue());
- continue;
- }
-
- auto szPayloads = state.getPayloadOps(sz.get<Value>());
- if (!llvm::hasSingleElement(szPayloads)) {
- auto diag = this->emitOpError()
- << "requires " << kPadToMultipleOfKeyword
- << " handle that is mapped to 1 payload op";
- diag.attachNote(sz.get<Value>().getLoc())
- << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- Operation *szPayloadOp = *szPayloads.begin();
- if (szPayloadOp->getNumResults() != 1 ||
- !szPayloadOp->getResult(0).getType().isIndex()) {
- auto diag = this->emitOpError()
- << "requires " << kPadToMultipleOfKeyword
- << " to be result of op with 1 index result";
- diag.attachNote(szPayloadOp->getLoc())
- << kPadToMultipleOfKeyword << " payload op";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- IntegerAttr attr;
- if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
- auto diag = this->emitOpError()
- << "requires constant " << kPadToMultipleOfKeyword;
- diag.attachNote(szPayloadOp->getLoc())
- << kPadToMultipleOfKeyword << " payload op";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- padToMultipleOf.push_back(attr.getInt());
- }
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
+ if (!status.succeeded())
+ return status;
if (padToMultipleOf.empty())
padToMultipleOf =
SmallVector<int64_t>(options.paddingDimensions.size(), 1);
@@ -3362,49 +3364,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
auto targets = state.getPayloadOps(getTarget());
if (std::empty(targets))
return DiagnosedSilenceableFailure::success();
-
+ auto transformOp = cast<TransformOpInterface>(getOperation());
SmallVector<int64_t> vectorSizes;
- for (OpFoldResult sz : getMixedVectorSizes()) {
- if (sz.is<Attribute>()) {
- auto attr = sz.get<Attribute>();
- vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
- continue;
- } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
- ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
- if (params.size() != 1)
- return emitSilenceableFailure(getLoc()) << "expected a single param";
- vectorSizes.push_back(
- cast<IntegerAttr>(params.front()).getValue().getSExtValue());
- continue;
- }
-
- auto szPayloads = state.getPayloadOps(sz.get<Value>());
- if (!llvm::hasSingleElement(szPayloads)) {
- auto diag = this->emitOpError(
- "requires vector size handle that is mapped to 1 payload op");
- diag.attachNote(sz.get<Value>().getLoc())
- << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- Operation *szPayloadOp = *szPayloads.begin();
- if (szPayloadOp->getNumResults() != 1 ||
- !szPayloadOp->getResult(0).getType().isIndex()) {
- auto diag = this->emitOpError(
- "requires vector size payload op with 1 index result");
- diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- IntegerAttr attr;
- if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
- auto diag = this->emitOpError("requires constant vector size");
- diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
-
- vectorSizes.push_back(attr.getInt());
- }
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedVectorSizes(), vectorSizes);
+ if (!status.succeeded())
+ return status;
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
>From 124957b0c21964f09add3d07761a2ff34737f6ad Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 23:57:16 -0500
Subject: [PATCH 06/12] Add python test for new functionality
---
.../python/dialects/transform_structured_ext.py | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 418b1216df0532..0667a2ce86e926 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -8,6 +8,7 @@
from mlir.dialects import pdl
from mlir.dialects.transform import structured
from mlir.dialects.transform import pdl as transform_pdl
+from mlir.dialects.transform.extras import constant_param
def run(f):
@@ -334,6 +335,22 @@ def testPadOpArgs(target):
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
+@run
+@create_sequence
+def testPadOpArgsParam(target):
+ structured.PadOp(
+ target,
+ [constant_param(128), Attribute.parse("2")],
+ padding_dimensions=Attribute.parse("[0, 1]"),
+ )
+ # CHECK-LABEL: TEST: testPadOpArgsParam
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[P:.*]] = transform.param.constant 128
+ # CHECK: transform.structured.pad
+ # CHECK-DAG: pad_to_multiple_of [%[[P]], 2]
+ # CHECK-DAG: padding_dimensions = [1]
+
+
@run
@create_sequence
def testScalarize(target):
>From f69f052762b71ca7c3a14921751db85d794ef94e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 2 May 2024 00:01:55 -0500
Subject: [PATCH 07/12] fix typo
---
mlir/test/python/dialects/transform_structured_ext.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 0667a2ce86e926..8deca33de6d99d 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -348,7 +348,7 @@ def testPadOpArgsParam(target):
# CHECK-DAG: %[[P:.*]] = transform.param.constant 128
# CHECK: transform.structured.pad
# CHECK-DAG: pad_to_multiple_of [%[[P]], 2]
- # CHECK-DAG: padding_dimensions = [1]
+ # CHECK-DAG: padding_dimensions = [0, 1]
@run
>From a9c511e00e3cd98fab61c1667b4492967e1023af Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 2 May 2024 09:26:28 -0500
Subject: [PATCH 08/12] use tablegen assembly format
---
.../Linalg/TransformOps/LinalgTransformOps.td | 8 ++-
.../TransformOps/LinalgTransformOps.cpp | 66 -------------------
2 files changed, 7 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ada7f7666d5f60..8ae9b3f7121b76 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1023,7 +1023,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
TransformHandleTypeInterface:$pad,
TransformHandleTypeInterface:$copy);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ $target oilist(
+ `pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of))
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
let hasVerifier = 1;
let builders = [
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 01d4d2a033830c..87b49accd340a8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1772,72 +1772,6 @@ SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
}
-ParseResult transform::PadOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::UnresolvedOperand target;
- SmallVector<OpAsmParser::UnresolvedOperand> dynamicPadToMultipleOf;
- DenseI64ArrayAttr padToMultipleOf;
- FunctionType functionalType;
- llvm::SMLoc operandLoc;
-
- if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
- return ParseResult::failure();
-
- if (succeeded(parser.parseOptionalKeyword(kPadToMultipleOfKeyword))) {
- if (failed(parseDynamicIndexList(parser, dynamicPadToMultipleOf,
- padToMultipleOf)))
- return ParseResult::failure();
- }
-
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(functionalType) ||
- parser.resolveOperand(target, functionalType.getInputs().front(),
- result.operands) ||
- parser.resolveOperands(dynamicPadToMultipleOf,
- functionalType.getInputs().drop_front(),
- operandLoc, result.operands))
- return ParseResult::failure();
-
- if (padToMultipleOf)
- result.addAttribute(getStaticPadToMultipleOfAttrName(result.name),
- padToMultipleOf);
-
- result.addTypes(functionalType.getResults());
-
- return success();
-}
-
-void transform::PadOp::print(OpAsmPrinter &p) {
- p << ' ' << getTarget() << ' ';
- if (!getMixedPadToMultipleOf().empty()) {
- p << kPadToMultipleOfKeyword << ' ';
- printDynamicIndexList(p, getOperation(), getPadToMultipleOf(),
- getStaticPadToMultipleOfAttr(),
- /*valueTypes=*/{},
- /*scalables=*/{}, OpAsmParser::Delimiter::Square);
- }
-
- OpBuilder builder((*this)->getContext());
- SmallVector<StringRef, 6> elidedAttrs({getStaticPadToMultipleOfAttrName()});
- if (getCopyBackOpAttr() ==
- builder.getStringAttr(
- bufferization::MaterializeInDestinationOp::getOperationName()))
- elidedAttrs.push_back(getCopyBackOpAttrName());
- if (getPackPaddingsAttr() == builder.getI64ArrayAttr({}))
- elidedAttrs.push_back(getPackPaddingsAttrName());
- if (getTransposePaddingsAttr() == builder.getI64ArrayAttr({}))
- elidedAttrs.push_back(getTransposePaddingsAttrName());
- if (getPaddingDimensionsAttr() == builder.getI64ArrayAttr({}))
- elidedAttrs.push_back(getPaddingDimensionsAttrName());
- if (getPaddingValuesAttr() == builder.getArrayAttr({}))
- elidedAttrs.push_back(getPaddingValuesAttrName());
-
- p.printOptionalAttrDict((*this)->getAttrs(),
- /*elidedAttrs=*/elidedAttrs);
- p << " : ";
- p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
-}
-
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
>From 3008e5de723514345390a55151f8ab9d9d58c3ba Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 2 May 2024 10:19:43 -0500
Subject: [PATCH 09/12] address some comments
---
.../Linalg/TransformOps/LinalgTransformOps.td | 1 +
.../TransformOps/LinalgTransformOps.cpp | 27 ++++++++++---------
2 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8ae9b3f7121b76..f23c65d827d168 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1056,6 +1056,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
/// copy_back_op attribute value indicating that no copy back is desired.
static constexpr StringRef kCopyOpNone = "none";
+ /// Returns a mix of dynamic `pad_to_multiple_of` and static `static_pad_to_multiple_of`.
SmallVector<OpFoldResult> getMixedPadToMultipleOf();
::mlir::DiagnosedSilenceableFailure apply(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 87b49accd340a8..8752e90bc7cad0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -171,42 +171,43 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
return DiagnosedSilenceableFailure::success();
}
+/// When possible, converts each `OpFoldResult` in `mixedResults` to
+/// an integer if the value can be statically inferred. If a result
+/// is a `Value` then it must be either a `ParamType` or a handle
+/// to a constant-like op.
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
TransformState &state, TransformOpInterface &transformOp,
- const SmallVectorImpl<OpFoldResult> &mixedResults,
- SmallVectorImpl<int64_t> &reified) {
+ ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
for (OpFoldResult paramOrHandle : mixedResults) {
if (isa<Attribute>(paramOrHandle)) {
reified.push_back(
cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
continue;
- } else if (isa<Value>(paramOrHandle) &&
- isa<ParamType>(paramOrHandle.get<Value>().getType())) {
+ } else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
if (params.size() != 1)
- return transformOp.emitDefiniteFailure() << "expected a single param";
+ return transformOp.emitSilenceableError() << "expected a single param";
reified.push_back(
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
continue;
}
- auto paramOrHandlePayloads =
- state.getPayloadOps(paramOrHandle.get<Value>());
- if (!llvm::hasSingleElement(paramOrHandlePayloads))
- return transformOp.emitDefiniteFailure()
+ auto payload = state.getPayloadOps(paramOrHandle.get<Value>());
+ if (!llvm::hasSingleElement(payload))
+ return transformOp.emitSilenceableError()
<< "requires param or handle that is mapped to 1 payload op";
- Operation *paramOrHandlePayloadOp = *paramOrHandlePayloads.begin();
+ Operation *paramOrHandlePayloadOp = *payload.begin();
if (paramOrHandlePayloadOp->getNumResults() != 1 ||
!paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
- return transformOp.emitDefiniteFailure()
+ return transformOp.emitSilenceableError()
<< "requires param or handle to be result of op with 1 index "
"result";
}
IntegerAttr attr;
if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
- return transformOp.emitDefiniteFailure()
+ return transformOp.emitSilenceableError()
<< "requires param or handle to be the result of a constant like "
"op";
@@ -1768,7 +1769,7 @@ void PadOp::getEffects(
}
SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
- OpBuilder b(getContext());
+ Builder b(getContext());
return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
}
>From f65cc751df750da7940c76895cd3de9f9277ff89 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 3 May 2024 09:58:23 -0500
Subject: [PATCH 10/12] address review comments
---
.../Dialect/Linalg/TransformOps/LinalgTransformOps.td | 4 ++--
mlir/python/mlir/dialects/transform/structured.py | 10 ++++------
mlir/test/python/dialects/transform_structured_ext.py | 11 +++++------
3 files changed, 11 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f23c65d827d168..55d82fd5825bf7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1024,8 +1024,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
TransformHandleTypeInterface:$copy);
let assemblyFormat = [{
- $target oilist(
- `pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of))
+ $target
+ (`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
attr-dict
`:` functional-type(operands, results)
}];
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 4f4a0e598df7d3..2c49ef0960c756 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -373,11 +373,10 @@ class PadOp(PadOp):
def __init__(
self,
target: Union[Operation, OpView, Value],
- pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
*,
+ pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
- static_pad_to_multiple_of: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
@@ -386,16 +385,15 @@ def __init__(
loc=None,
ip=None,
):
- if static_pad_to_multiple_of is None and pad_to_multiple_of is None:
+ if pad_to_multiple_of is None:
dynamic_pad_to_multiple_of = []
- elif static_pad_to_multiple_of is None:
+ static_pad_to_multiple_of = None
+ else:
(
dynamic_pad_to_multiple_of,
static_pad_to_multiple_of,
_,
) = _dispatch_dynamic_index_list(pad_to_multiple_of)
- else:
- dynamic_pad_to_multiple_of = pad_to_multiple_of
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 8deca33de6d99d..f4c092ba9ee98f 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -316,10 +316,9 @@ def testPadOpNoArgs(target):
def testPadOpArgs(target):
structured.PadOp(
target,
- [],
+ pad_to_multiple_of=[128],
padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
padding_dimensions=Attribute.parse("[1]"),
- static_pad_to_multiple_of=[128],
pack_paddings=[0],
transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
copy_back_op="linalg.copy",
@@ -340,15 +339,15 @@ def testPadOpArgs(target):
def testPadOpArgsParam(target):
structured.PadOp(
target,
- [constant_param(128), Attribute.parse("2")],
- padding_dimensions=Attribute.parse("[0, 1]"),
+ pad_to_multiple_of=[constant_param(128), Attribute.parse("2"), 10],
+ padding_dimensions=Attribute.parse("[0, 1, 2]"),
)
# CHECK-LABEL: TEST: testPadOpArgsParam
# CHECK: transform.sequence
# CHECK-DAG: %[[P:.*]] = transform.param.constant 128
# CHECK: transform.structured.pad
- # CHECK-DAG: pad_to_multiple_of [%[[P]], 2]
- # CHECK-DAG: padding_dimensions = [0, 1]
+ # CHECK-DAG: pad_to_multiple_of [%[[P]], 2, 10]
+ # CHECK-DAG: padding_dimensions = [0, 1, 2]
@run
>From eaf0e38fd69f336181ff943c4d3c1d556de54fd8 Mon Sep 17 00:00:00 2001
From: srcarroll <50210727+srcarroll at users.noreply.github.com>
Date: Sat, 4 May 2024 16:15:46 -0500
Subject: [PATCH 11/12] Update
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8752e90bc7cad0..c5b2ce197f62de 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -192,7 +192,10 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
continue;
}
- auto payload = state.getPayloadOps(paramOrHandle.get<Value>());
+ Value handle = paramOrHandle.get<Value>();
+ if (!isa<TransformHandleOpInterface>(handle.getType())
+ return transformOp.emitSilenceableError() << "unexpected value handle";
+ auto payload = state.getPayloadOps(handle);
if (!llvm::hasSingleElement(payload))
return transformOp.emitSilenceableError()
<< "requires param or handle that is mapped to 1 payload op";
>From 66f9d4dfef2662adba5a8271e459de8d27321b5b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 4 May 2024 16:50:22 -0500
Subject: [PATCH 12/12] fix typo
---
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c5b2ce197f62de..eadd819bee740c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -193,7 +193,7 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
}
Value handle = paramOrHandle.get<Value>();
- if (!isa<TransformHandleOpInterface>(handle.getType())
+ if (!isa<TransformHandleTypeInterface>(handle.getType()))
return transformOp.emitSilenceableError() << "unexpected value handle";
auto payload = state.getPayloadOps(handle);
if (!llvm::hasSingleElement(payload))
More information about the Mlir-commits
mailing list