[Mlir-commits] [mlir] [MLIR] Add fusability query to TilingInterface (PR #166502)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Nov 26 09:15:12 PST 2025
https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/166502
>From 10e225e03a4bca6b63f79af88d18e5c3ee0a2c05 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 21 Oct 2025 18:08:14 -0400
Subject: [PATCH] [MLIR] Add fusability query to TilingInterface
This introduces `isOpFusableWithProducerSlices` and `isOpFusableWithConsumerSlice`
methods to the TilingInterface. They query whether a tilable op can be fused
with a given set of producer slices, or with a consumer slice, without
generating IR. This is needed to use the tiling interface in pattern
rewrites: without such a query, a pattern has to invoke the tiling methods
themselves, which are allowed to generate IR and then fail.
---
.../mlir/Interfaces/TilingInterface.td | 37 ++++++
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 3 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 28 ++++-
.../TilingInterface/query-fusability.mlir | 70 ++++++++++++
.../TestTilingInterfaceTransformOps.cpp | 105 ++++++++++++++++++
.../TestTilingInterfaceTransformOps.td | 46 +++++++-
6 files changed, 285 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Interfaces/TilingInterface/query-fusability.mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index e0516abdfcf0c..c30782a25e40f 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
+ >,
+ //===------------------------------------------------------------------===//
+ // Interface methods for querying fusability without creating IR.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Indicates whether this operation can be fused into a consumer that uses
+ the slice (`offsets`, `sizes`) of result `resultNumber`. Must not create IR.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isOpFusableWithConsumerSlice",
+ /*args=*/(ins
+ "unsigned":$resultNumber,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Indicates whether this operation can be fused with producers of the slices
+ (`allOffsets`, `allSizes`) of operands `operandNumbers`. Must not create IR.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isOpFusableWithProducerSlices",
+ /*args=*/(ins
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
>
];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 705d6f26efd29..8e14ef4a2ea12 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
SmallVector<OpFoldResult> allShapeSizes =
op.createFlatListOfOperandDims(b, op.getLoc());
AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
- if (!shapeSizesToLoopsMap)
- return failure();
+ assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap"); // NOTE(review): presumably guaranteed by the LinalgOp verifier.
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8a0440bcc6fb9..50a84ace09258 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -167,7 +167,7 @@ struct LinalgOpTilingInterface
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
if (!dimExpr)
- continue;
+ return failure();
unsigned position = dimExpr.getPosition();
auto it = mappedOffsets.find(position);
if (it != mappedOffsets.end()) {
@@ -357,6 +357,63 @@ struct LinalgOpTilingInterface
/// Inline the op payload and store the result.
return inlinePayload(builder, linalgOp, ivs, indexedValues);
}
+
+ /// Returns true if this op can be fused into a consumer that uses the slice
+ /// (`offsets`, `sizes`) of result `resultNumber`. Does not create IR.
+ bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ if (resultNumber >= op->getNumResults())
+ return false;
+ // NOTE(review): tiling from a result tile is expected to require a
+ // projected-permutation result map (see getIterationDomainTileFromResultTile).
+ auto linalgOp = cast<LinalgOp>(op);
+ AffineMap indexingMap =
+ linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
+ return indexingMap.isProjectedPermutation() &&
+ offsets.size() == indexingMap.getNumResults() &&
+ sizes.size() == indexingMap.getNumResults();
+ }
+
+ /// Returns true if this op can be fused with producers of the slices
+ /// (`allOffsets`, `allSizes`) of operands `operandNumbers`. Does not create IR.
+ bool isOpFusableWithProducerSlices(
+ Operation *op, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (allOffsets.size() != operandNumbers.size() ||
+ allSizes.size() != operandNumbers.size())
+ return false;
+ // Same consistency check as getMappedOffsetAndSize, done inline: that
+ // helper takes an OpBuilder and may materialize IR (presumably while
+ // computing the iteration domain for dynamic dims), which this query must not.
+ auto linalgOp = cast<LinalgOp>(op);
+ DenseMap<unsigned, std::pair<OpFoldResult, OpFoldResult>> seenTiles;
+ for (auto [operandNumber, offsets, sizes] :
+ llvm::zip_equal(operandNumbers, allOffsets, allSizes)) {
+ if (operandNumber >= op->getNumOperands())
+ return false;
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (offsets.size() != indexingMap.getNumResults() ||
+ sizes.size() != indexingMap.getNumResults())
+ return false;
+ for (auto [resultExpr, offset, size] :
+ llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+ // Only plain loop dimensions can be mapped back to the iteration space.
+ auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+ if (!dimExpr)
+ return false;
+ auto [it, inserted] =
+ seenTiles.try_emplace(dimExpr.getPosition(), offset, size);
+ // A loop dimension reached through several operands must agree.
+ if (!inserted &&
+ (it->second.first != offset || it->second.second != size))
+ return false;
+ }
+ }
+ return true;
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
new file mode 100644
index 0000000000000..d7b0528a764bb
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+ %c0 = arith.constant 0 : index
+
+ %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+ %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+ // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
+ %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_producer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %c20 = arith.constant 20 : index
+
+ %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+ %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+ // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
+ %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_producer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @fusable_with_consumer_extract_slice(%arg0: tensor<100x200xf32>, %arg1: tensor<100x200xf32>, %dest: tensor<100x200xf32>) -> tensor<10x20xf32> {
+ // expected-remark @+1 {{can be fused with consumer tensor.extract_slice op}}
+ %add = linalg.add ins(%arg0, %arg1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ %c0 = arith.constant 0 : index
+ %slice = tensor.extract_slice %add[%c0, %c0] [10, 20] [1, 1] : tensor<100x200xf32> to tensor<10x20xf32>
+
+ return %slice : tensor<10x20xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_consumer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 74bdaaa3d7c57..583d68b83cf0b 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -683,6 +684,110 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestQueryProducerFusability
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
+ TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp) {
+ return emitSilenceableError()
+ << "target operation does not implement TilingInterface";
+ }
+
+ // Collect operand numbers and their corresponding producer insert_slice
+ // offsets and sizes.
+ SmallVector<unsigned> operandNumbers;
+ SmallVector<SmallVector<OpFoldResult>> allOffsets;
+ SmallVector<SmallVector<OpFoldResult>> allSizes;
+
+ for (OpOperand &operand : target->getOpOperands()) {
+ Value operandValue = operand.get();
+ Operation *definingOp = operandValue.getDefiningOp();
+
+ // Only tensor.insert_slice producers are considered: they are a test-only
+ // way to obtain operand slices to pass to the query.
+ if (auto insertSliceOp =
+ dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
+ operandNumbers.push_back(operand.getOperandNumber());
+ allOffsets.push_back(insertSliceOp.getMixedOffsets());
+ allSizes.push_back(insertSliceOp.getMixedSizes());
+ }
+ }
+
+ if (!operandNumbers.empty()) {
+ bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
+ operandNumbers, allOffsets, allSizes);
+
+ if (isFusable) {
+ target->emitRemark()
+ << "can be fused with producer tensor.insert_slice ops";
+ } else {
+ target->emitRemark()
+ << "cannot be fused with producer tensor.insert_slice ops";
+ }
+ }
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestQueryProducerFusability::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// TestQueryConsumerFusability
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply(
+ TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp) {
+ return emitSilenceableError()
+ << "target operation does not implement TilingInterface";
+ }
+
+ // Look for tensor.extract_slice ops that consume results of the tilable op.
+ for (OpResult result : target->getResults()) {
+ for (OpOperand &use : result.getUses()) {
+ Operation *user = use.getOwner();
+
+ // Only tensor.extract_slice consumers are considered: they are a test-only
+ // way to obtain result slices to pass to the query.
+ if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
+ bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice(
+ result.getResultNumber(), extractSliceOp.getMixedOffsets(),
+ extractSliceOp.getMixedSizes());
+
+ if (isFusable) {
+ target->emitRemark()
+ << "can be fused with consumer tensor.extract_slice op";
+ } else {
+ target->emitRemark()
+ << "cannot be fused with consumer tensor.extract_slice op";
+ }
+ }
+ }
+ }
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestQueryConsumerFusability::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsPayload(effects);
+}
+
#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 29669bd0930ed..8c4f64de47795 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -197,11 +197,55 @@ def TestTileUsingCustomLoopOp : Op<
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs TransformHandleTypeInterface:$tiled_ops,
Variadic<TransformHandleTypeInterface>:$loops);
-
+
let assemblyFormat = [{
$root_op `tile_sizes` `=` $tile_sizes
attr-dict `:` functional-type(operands, results)
}];
}
+def TestQueryProducerFusability : Op<
+ Transform_Dialect, "test.query_producer_fusability",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Test operation for the producer fusability query method in the
+ TilingInterface.
+
+ For each operation in the target handle, this looks for tensor.insert_slice
+ ops that produce operands of the tilable op. The offsets/sizes of those
+ inserts are passed to `isOpFusableWithProducerSlices`, and a remark is
+ emitted with the result of the query.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+}
+
+def TestQueryConsumerFusability
+ : Op<Transform_Dialect, "test.query_consumer_fusability",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Test operation for the consumer fusability query method in the
+ TilingInterface.
+
+ For each operation in the target handle, this looks for tensor.extract_slice
+ ops that consume results of the tilable op. The offsets/sizes of those
+ extracts are passed to `isOpFusableWithConsumerSlice`, and a remark is
+ emitted with the result of the query.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+}
+
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
More information about the Mlir-commits
mailing list