[Mlir-commits] [mlir] [MLIR][Transform] FuseOp: accept transform params, add use_forall argument (PR #161883)
Tuomas Kärnä
llvmlistbot at llvm.org
Fri Oct 3 10:15:45 PDT 2025
https://github.com/tkarna created https://github.com/llvm/llvm-project/pull/161883
Changes to linalg `structured.fuse` transform op:
* Adds an optional `use_forall` boolean argument which generates a tiled `scf.forall` loop instead of `scf.for` loops.
* `tile_sizes` can now be any parameter or handle.
* `tile_interchange` can now be any parameter or handle.
* IR formatting changes
- from `transform.structured.fuse %0 [4, 8] ...`
- to `transform.structured.fuse %0 tile_sizes [4, 8] ...`
>From 60c130117abafff60febff00a3f728bf43e3a630 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Wed, 1 Oct 2025 14:52:38 +0300
Subject: [PATCH 1/2] FuseOp: add use_forall argument that generates scf.forall
loops
---
.../Linalg/TransformOps/LinalgTransformOps.td | 6 +++--
.../TransformOps/LinalgTransformOps.cpp | 13 +++++++---
.../mlir/dialects/transform/structured.py | 6 ++++-
.../Dialect/Linalg/transform-op-fuse.mlir | 25 +++++++++++++++++++
.../dialects/transform_structured_ext.py | 13 ++++++++++
5 files changed, 57 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0d6ebc087e2f3..7887a864139fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -410,13 +410,15 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
- DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
+ DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
+ DefaultValuedAttr<BoolAttr, "false">:$use_forall);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);
let assemblyFormat = [{
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
- (`apply_cleanup` `=` $apply_cleanup^)? attr-dict
+ (`apply_cleanup` `=` $apply_cleanup^)?
+ (`use_forall` `=` $use_forall^)? attr-dict
`:` functional-type(operands, results)
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2490ef4..3555825120262 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -637,6 +637,10 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
+ bool useForall = getUseForall();
+ tilingOptions.setLoopType(useForall
+ ? scf::SCFTilingOptions::LoopType::ForallOp
+ : scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +656,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
+ size_t numLoops =
+ useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
- rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
+ transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -676,7 +682,8 @@ LogicalResult transform::FuseOp::verify() {
SmallVector<int64_t> sizes =
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index e3bacb5777d9f..ed17465365397 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -147,6 +147,7 @@ def __init__(
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
@@ -160,6 +161,7 @@ def __init__(
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
@@ -173,6 +175,7 @@ def __init__(
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
@@ -180,7 +183,7 @@ def __init__(
tile_interchange = tile_interchange if tile_interchange else []
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
- num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
+ num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in tile_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -200,6 +203,7 @@ def __init__(
tile_sizes=tile_sizes,
tile_interchange=tile_interchange,
apply_cleanup=apply_cleanup,
+ use_forall=use_forall,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 9a44f95afb586..645fe6563cd69 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -57,6 +57,31 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @fuse_unary_forall
+func.func @fuse_unary_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: %[[RES:.*]] = scf.forall
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
+ // CHECK: return %[[RES]]
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop = transform.structured.fuse %0 [32, 32] use_forall = true
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @interchange_reduction
// CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 8785d6d360074..4ad125a60cd96 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -114,6 +114,19 @@ def testFuseOpCompact(target):
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+@run
+@create_sequence
+def testFuseOpCompactForall(target):
+ structured.FuseOp(
+ target, tile_sizes=[4, 8], apply_cleanup=True, use_forall=True,
+ )
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK-SAME: apply_cleanup = true use_forall = true
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
@run
@create_sequence
def testFuseOpNoArg(target):
>From 2f64a1426b29cee624755651114b321cf0decd48 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Fri, 3 Oct 2025 14:34:33 +0300
Subject: [PATCH 2/2] FuseOp tile sizes and interchange args accept dynamic
values
---
.../Linalg/TransformOps/LinalgTransformOps.td | 57 ++++++-
.../TransformOps/LinalgTransformOps.cpp | 158 ++++++++++++++++--
.../mlir/dialects/transform/structured.py | 32 ++--
.../Dialect/Linalg/transform-op-fuse.mlir | 65 +++++--
mlir/test/Dialect/Tensor/tiling.mlir | 2 +-
.../tile-and-fuse-using-interface.mlir | 24 +--
.../dialects/transform_structured_ext.py | 24 ++-
7 files changed, 295 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 7887a864139fa..40588afa6477a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -395,33 +395,72 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//
def FuseOp : Op<Transform_Dialect, "structured.fuse",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>,
- ReportTrackingListenerFailuresOpTrait]> {
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Tiles the operations pointed to by the target handle and fuses their
producers greedily using the options provided as attributes.
If `apply_cleanup` is true then slice canonicalization is applied between
- fusion steps.
+ fusion steps. If `use_forall` is true then the tiling method generates a
+ `scf.forall` loop instead of `scf.for` loops.
}];
let arguments =
(ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
- DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
- DefaultValuedAttr<BoolAttr, "false">:$use_forall);
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
+ DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
+ DefaultValuedAttr<BoolAttr, "false">:$use_forall);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);
+ let builders = [
+ OpBuilder<(ins "TypeRange":$loopTypes,
+ "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ "ArrayRef<int64_t>":$staticTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "TypeRange":$loopTypes,
+ "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedTileSizes,
+ "ArrayRef<OpFoldResult>":$mixedTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ "ArrayRef<int64_t>":$staticTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedTileSizes,
+ "ArrayRef<OpFoldResult>":$mixedTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ ];
let assemblyFormat = [{
- $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
+ $target
+ (`tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes)^)?
+ (`interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)^)?
(`apply_cleanup` `=` $apply_cleanup^)?
(`use_forall` `=` $use_forall^)? attr-dict
`:` functional-type(operands, results)
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
+ ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange();
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3555825120262..0d365f29a51a3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result, loopTypes,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ // Loop types are automatically splat by the callee, setting up one is
+ // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+ build(builder, result, loopTypes, target, mixedTileSizes,
+ mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+ SmallVector<int64_t> staticTileInterchange;
+ SmallVector<Value> dynamicTileInterchange;
+ dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+ staticTileInterchange);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileInterchangeAttr =
+ builder.getDenseI64ArrayAttr(staticTileInterchange);
+ unsigned numExpectedLoops =
+ useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(numExpectedLoops);
+ assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+ "expected one loop type or as many as loops");
+ if (loopTypes.size() == 1)
+ resultTypes.append(numExpectedLoops, loopTypes[0]);
+ else
+ llvm::append_range(resultTypes, loopTypes);
+ build(builder, result, /*transformed=*/target.getType(),
+ /*loops=*/resultTypes,
+ /*target=*/target,
+ /*tile_sizes=*/dynamicTileSizes,
+ /*tile_interchange=*/dynamicTileInterchange,
+ /*static_tile_sizes=*/staticTileSizesAttr,
+ /*static_tile_interchange=*/staticTileInterchangeAttr,
+ /*apply_cleanup=*/applyCleanup,
+ /*use_forall=*/useForall);
+}
+
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
@@ -630,10 +710,18 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- SmallVector<int64_t> tileInterchange =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int64_t> tileSizes;
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileSizes(), tileSizes);
+ if (!status.succeeded())
+ return status;
+ SmallVector<int64_t> tileInterchange;
+ status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileInterchange(), tileInterchange);
+ if (!status.succeeded())
+ return status;
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
@@ -671,17 +759,18 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+ ArrayRef<int64_t> permutation = getStaticTileInterchange();
+ if (!llvm::any_of(permutation,
+ [](int64_t v) { return ShapedType::isDynamic(v); })) {
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ return emitOpError() << "expects interchange to be a permutation, found "
+ << getTileInterchange();
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
size_t numExpectedLoops =
getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
@@ -690,6 +779,49 @@ LogicalResult transform::FuseOp::verify() {
return success();
}
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ ValueRange dynamicValues = getTileSizes();
+ ArrayRef<int64_t> staticValues = getStaticTileSizes();
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (size == ShapedType::kDynamic) {
+ results.push_back(dynamicValues[dynamicPos++]);
+ } else {
+ results.push_back(builder.getIndexAttr(size));
+ }
+ }
+ return results;
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ ValueRange dynamicValues = getTileInterchange();
+ ArrayRef<int64_t> staticValues = getStaticTileInterchange();
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (size == ShapedType::kDynamic) {
+ results.push_back(dynamicValues[dynamicPos++]);
+ } else {
+ results.push_back(builder.getIndexAttr(size));
+ }
+ }
+ return results;
+}
+
+void transform::FuseOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getTileInterchangeMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index ed17465365397..d3fe3d5f085bf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -144,8 +144,8 @@ def __init__(
loop_types: Union[Type, Sequence[Type]],
target: Union[Operation, Value, OpView],
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
use_forall: Optional[bool] = False,
loc=None,
@@ -158,8 +158,8 @@ def __init__(
self,
target: Union[Operation, Value, OpView],
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
use_forall: Optional[bool] = False,
loc=None,
@@ -172,8 +172,8 @@ def __init__(
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
use_forall: Optional[bool] = False,
loc=None,
@@ -181,9 +181,17 @@ def __init__(
):
tile_sizes = tile_sizes if tile_sizes else []
tile_interchange = tile_interchange if tile_interchange else []
- _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
- _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
- num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in tile_sizes)
+ (
+ dynamic_tile_sizes,
+ static_tile_sizes,
+ _,
+ ) = _dispatch_dynamic_index_list(tile_sizes)
+ (
+ dynamic_tile_interchange,
+ static_tile_interchange,
+ _,
+ ) = _dispatch_dynamic_index_list(tile_interchange)
+ num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in static_tile_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -200,8 +208,10 @@ def __init__(
target.type,
loop_types,
target,
- tile_sizes=tile_sizes,
- tile_interchange=tile_interchange,
+ tile_sizes=dynamic_tile_sizes,
+ tile_interchange=dynamic_tile_interchange,
+ static_tile_sizes=static_tile_sizes,
+ static_tile_interchange=static_tile_interchange,
apply_cleanup=apply_cleanup,
use_forall=use_forall,
loc=loc,
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 645fe6563cd69..d472f75bfcb9a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -18,7 +18,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -48,7 +48,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -57,6 +57,35 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @fuse_unary_param
+func.func @fuse_unary_param(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: %[[RES:.*]] = scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
+ // CHECK: return %[[RES]]
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.param.constant 32 : i32 -> !transform.param<i32>
+ %2 = transform.param.constant 1 : i32 -> !transform.param<i32>
+ %3, %loops:2 = transform.structured.fuse %0 tile_sizes [%1, 32] interchange [0, %2]
+ : (!transform.any_op, !transform.param<i32>, !transform.param<i32>) ->
+ (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @fuse_unary_forall
func.func @fuse_unary_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -74,7 +103,7 @@ func.func @fuse_unary_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loop = transform.structured.fuse %0 [32, 32] use_forall = true
+ %1, %loop = transform.structured.fuse %0 tile_sizes [32, 32] use_forall = true
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -118,7 +147,7 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [5, 0, 7] interchange [0, 2, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%2, %loops_2 = transform.structured.tile_using_for %1 tile_sizes [0, 4]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -146,7 +175,7 @@ func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [16, 32] interchange [0, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -172,7 +201,7 @@ func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [3, 5, 0, 0]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -198,7 +227,7 @@ func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
+ %1, %loops:3 = transform.structured.fuse %0 tile_sizes [3, 5, 2, 0]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -229,7 +258,7 @@ func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] apply_cleanup = true
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
@@ -263,7 +292,7 @@ func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] apply_cleanup = true
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
@@ -298,7 +327,7 @@ func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] apply_cleanup = true
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
@@ -324,7 +353,7 @@ func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> t
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
+ %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -349,7 +378,7 @@ func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim(%0: tenso
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true :
+ %transformed, %loops:2 = transform.structured.fuse %0 tile_sizes [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
@@ -373,7 +402,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous(%0: te
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true :
+ %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -404,7 +433,7 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true :
+ %transformed, %loops:4 = transform.structured.fuse %0 tile_sizes [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -433,7 +462,7 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true :
+ %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
transform.yield
}
@@ -458,7 +487,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false(%0:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false :
+ %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -481,7 +510,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+ %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
transform.yield
}
@@ -507,7 +536,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+ %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
(!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
transform.yield
}
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 04a99b5fd0d68..32fb0c9e41c39 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -149,7 +149,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%copy = transform.structured.match ops{["linalg.copy"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %copy [2, 3]
+ %a, %b, %c = transform.structured.fuse %copy tile_sizes [2, 3]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 8a0390a4379cf..0a056158d70b3 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -17,7 +17,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %matmul [10, 20]
+ %a, %b, %c = transform.structured.fuse %matmul tile_sizes [10, 20]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -69,7 +69,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %generic [10, 20]
+ %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%mm1, %mm2 = transform.split_handle %matmuls
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.structured.fuse %mm2 [10]
+ %a, %b = transform.structured.fuse %mm2 tile_sizes [10]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -188,7 +188,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %generic [10, 20]
+ %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -248,7 +248,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0]
+ %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] interchange[1, 0]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -307,7 +307,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %generic [10, 20]
+ %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -367,7 +367,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b, %c = transform.structured.fuse %generic [10, 20]
+ %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -423,7 +423,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%mm1, %mm2, %mm3 = transform.split_handle %matmuls
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
- %a, %b = transform.structured.fuse %mm3 [10]
+ %a, %b = transform.structured.fuse %mm3 tile_sizes [10]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -512,7 +512,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%generic1, %generic2, %generic3 = transform.split_handle %generics
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
- %a, %b = transform.structured.fuse %generic3 [10]
+ %a, %b = transform.structured.fuse %generic3 tile_sizes [10]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -568,7 +568,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%pad = transform.structured.match ops{["tensor.pad"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.structured.fuse %pad [8]
+ %a, %b = transform.structured.fuse %pad tile_sizes [8]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -614,7 +614,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+ %a, %b = transform.structured.fuse %matmul tile_sizes [0, 1, 0]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -652,7 +652,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %loops:4 = transform.structured.fuse %generic {tile_sizes = [1, 16, 16, 16], tile_interchange = [0, 1, 2, 3], apply_cleanup = false}
+ %a, %loops:4 = transform.structured.fuse %generic tile_sizes [1, 16, 16, 16] interchange [0, 1, 2, 3] apply_cleanup = false
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 4ad125a60cd96..216e7a5caa31f 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -109,7 +109,7 @@ def testFuseOpCompact(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1] apply_cleanup = true
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -122,7 +122,7 @@ def testFuseOpCompactForall(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: apply_cleanup = true use_forall = true
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -137,6 +137,24 @@ def testFuseOpNoArg(target):
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
+@run
+@create_sequence
+def testFuseOpParams(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[constant_param(4), Attribute.parse("8")],
+ tile_interchange=[constant_param(0), Attribute.parse("1")]
+ )
+ # CHECK-LABEL: TEST: testFuseOpParams
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[P:.*]] = transform.param.constant 4
+ # CHECK-DAG: %[[I:.*]] = transform.param.constant 0
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[P]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
@run
@create_sequence
def testFuseOpAttributes(target):
@@ -145,7 +163,7 @@ def testFuseOpAttributes(target):
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
More information about the Mlir-commits
mailing list