[Mlir-commits] [mlir] fb0881f - [mlir][Tensor] Add rank-reducing slice in generatedSlices (#174248)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 14 15:29:45 PST 2026
Author: Bangtian Liu
Date: 2026-01-14T18:29:41-05:00
New Revision: fb0881f62891e4042725731ec2e4cbc7e8f37e7c
URL: https://github.com/llvm/llvm-project/commit/fb0881f62891e4042725731ec2e4cbc7e8f37e7c
DIFF: https://github.com/llvm/llvm-project/commit/fb0881f62891e4042725731ec2e4cbc7e8f37e7c.diff
LOG: [mlir][Tensor] Add rank-reducing slice in generatedSlices (#174248)
When `replaceExtractSliceWithTiledProducer` creates a rank-reducing
slice to handle type mismatches, it should be tracked in
`generatedSlices` so downstream cleanup patterns (like IREE's
`FoldExtractSliceOfBroadcast`) can process it.
This PR also fixes an infinite loop in `getUntiledProducerFromSliceSource`:
once the slice was added to `generatedSlices`, the fusion worklist
repeatedly tried to re-fuse producers that were already inside the
innermost loop. The fix uses an `isProperAncestor` check to skip
producers that are already inside the innermost loop.
Added a lit test (`@fuse_fill_through_rank_reducing_slice`) demonstrating
correct fusion through rank-reducing slices. Note that demonstrating the
`generatedSlices` tracking benefit requires a cleanup pattern
(`SwapExtractSliceWithFillPatterns`) to consume the slice. IREE's full CI
suite (iree-org/iree#23012) validates that this works correctly in practice
with patterns like `FoldExtractSliceOfBroadcast`.
---------
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0dcaeed70aa5e..4d22a5e97ba4a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1327,7 +1327,18 @@ getUntiledProducerFromSliceSource(OpOperand *source,
}
if (loopIt == loops.rend())
destinationIterArg = source;
- return {dyn_cast<OpResult>(source->get()), destinationIterArg};
+
+ auto result = dyn_cast<OpResult>(source->get());
+ if (result) {
+ Operation *producer = result.getOwner();
+ Operation *innermostLoop = loops.back();
+ // If the producer is already inside the innermost loop (where the slice
+ // is), it has already been fused. Skip it to avoid infinite loops.
+ if (innermostLoop->isProperAncestor(producer))
+ return {OpResult(), std::nullopt};
+ }
+
+ return {result, destinationIterArg};
}
/// Implementation of fusing producer of a single slice by computing the
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 549ac7afca8ca..7903f3c51b73b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -53,6 +53,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
builder, sliceOp.getLoc(), sliceOp.getType(),
tiledResult->tiledValues[0], offsets, sliceOp.getMixedSizes(), strides);
tiledResult->tiledValues[0] = newSliceOp;
+ tiledResult->generatedSlices.push_back(newSliceOp);
}
return *tiledResult;
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 21d7816934bf9..43a6705777623 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -675,3 +675,63 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: ins(%[[TILEDARG0]]
// CHECK-SAME: outs(%[[TILEDARG1]]
// CHECK: tensor.insert_slice %[[RES:.*]]
+
+// -----
+
+// Tests that rank-reducing slices created during fusion are tracked in
+// generatedSlices (via SwapExtractSliceWithProducerPatterns.cpp).
+// This enables cleanup patterns to transform them during tile-and-fuse.
+//
+// Affected cleanup patterns:
+// - SwapExtractSliceWithFillPatterns: swaps extract_slice(fill) -> fill(extract_slice).
+
+#map2d = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @fuse_fill_through_rank_reducing_slice(%arg0: tensor<4x96xf32>) -> tensor<4x96xf32> {
+ %cst = arith.constant 1.0 : f32
+
+ %empty_3d = tensor.empty() : tensor<4x1x96xf32>
+ %fill_3d = linalg.fill ins(%cst : f32) outs(%empty_3d : tensor<4x1x96xf32>) -> tensor<4x1x96xf32>
+
+ %slice_2d = tensor.extract_slice %fill_3d[0, 0, 0] [4, 1, 96] [1, 1, 1]
+ : tensor<4x1x96xf32> to tensor<4x96xf32>
+
+ %result = linalg.generic {
+ indexing_maps = [#map2d, #map2d],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%slice_2d : tensor<4x96xf32>)
+ outs(%arg0 : tensor<4x96xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %sum = arith.addf %in, %out : f32
+ linalg.yield %sum : f32
+ } -> tensor<4x96xf32>
+
+ return %result : tensor<4x96xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %consumer = transform.structured.match ops{["linalg.generic"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %tiled, %loop = transform.structured.fuse %consumer tile_sizes [0, 32] {apply_cleanup}
+ : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @fuse_fill_through_rank_reducing_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x96xf32>
+// CHECK-DAG: %[[EMPTY_3D:.+]] = tensor.empty() : tensor<4x1x96xf32>
+// CHECK: scf.for %[[IV:[a-zA-Z0-9_]+]] = {{.*}} iter_args(%[[ITERARG:.+]] = %[[ARG0]])
+
+// CHECK: %[[TILE_3D:.+]] = tensor.extract_slice %[[EMPTY_3D]][0, 0, %[[IV]]] [4, 1, 32] [1, 1, 1]
+// CHECK-SAME: tensor<4x1x96xf32> to tensor<4x1x32xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[TILE_3D]] : tensor<4x1x32xf32>)
+// CHECK: %[[RANK_REDUCED:.+]] = tensor.extract_slice %[[FILL]][0, 0, 0] [4, 1, 32] [1, 1, 1]
+// CHECK-SAME: tensor<4x1x32xf32> to tensor<4x32xf32>
+
+// CHECK: %[[CONSUMER_OUT:.+]] = tensor.extract_slice %[[ITERARG]][0, %[[IV]]] [4, 32]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[RANK_REDUCED]] : tensor<4x32xf32>)
+// CHECK-SAME: outs(%[[CONSUMER_OUT]] : tensor<4x32xf32>)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 8cace28d441c6..62c82a15a5417 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -1,6 +1,9 @@
// RUN: mlir-opt -transform-interpreter -cse -mlir-print-local-scope -split-input-file -verify-diagnostics %s | FileCheck %s
-// Check tile+ fuse works with partial reduction outer parallel strategy.
+// Check tile + fuse works with partial reduction outer parallel strategy.
+// This also tests that the fusion logic correctly skips producers that are
+// already inside the innermost loop (e.g., the rank-reducing slice of the
+// fused fill), avoiding infinite loops in the fusion worklist.
module{
func.func @tile_and_fuse_with_partial_reduction_outer_parallel(
More information about the Mlir-commits
mailing list