[Mlir-commits] [mlir] [mlir][Tensor] Add rank-reducing slice in generatedSlices (PR #174248)
Bangtian Liu
llvmlistbot at llvm.org
Tue Jan 6 09:41:19 PST 2026
https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/174248
>From 603341c6b35a20c402da8d653bd4f3fdf2e2e7da Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Fri, 2 Jan 2026 14:54:43 -0800
Subject: [PATCH 1/3] [mlir][Tensor] Add rank-reducing slice in generatedSlices
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
.../Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 549ac7afca8ca..7903f3c51b73b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -53,6 +53,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
builder, sliceOp.getLoc(), sliceOp.getType(),
tiledResult->tiledValues[0], offsets, sliceOp.getMixedSizes(), strides);
tiledResult->tiledValues[0] = newSliceOp;
+ tiledResult->generatedSlices.push_back(newSliceOp);
}
return *tiledResult;
>From fbd91be92a7d9a521cfb7bd13f7858f11ce164f2 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Fri, 2 Jan 2026 17:27:51 -0800
Subject: [PATCH 2/3] fix timeout issue (infinite fusion loop on already-fused producers)
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
.../Dialect/SCF/Transforms/TileUsingInterface.cpp | 13 ++++++++++++-
.../tile-and-fuse-with-reduction-tiling.mlir | 5 ++++-
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0dcaeed70aa5e..884cc1cc2ac09 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1327,7 +1327,18 @@ getUntiledProducerFromSliceSource(OpOperand *source,
}
if (loopIt == loops.rend())
destinationIterArg = source;
- return {dyn_cast<OpResult>(source->get()), destinationIterArg};
+
+ OpResult result = dyn_cast<OpResult>(source->get());
+ if (result) {
+ Operation *producer = result.getOwner();
+ Operation *innermostLoop = loops.back();
+ // If the producer is already inside the innermost loop (where the slice
+ // is), it has already been fused. Skip it to avoid infinite loops.
+ if (innermostLoop->isProperAncestor(producer))
+ return {OpResult(), std::nullopt};
+ }
+
+ return {result, destinationIterArg};
}
/// Implementation of fusing producer of a single slice by computing the
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 8cace28d441c6..62c82a15a5417 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -1,6 +1,9 @@
// RUN: mlir-opt -transform-interpreter -cse -mlir-print-local-scope -split-input-file -verify-diagnostics %s | FileCheck %s
-// Check tile+ fuse works with partial reduction outer parallel strategy.
+// Check tile + fuse works with partial reduction outer parallel strategy.
+// This also tests that the fusion logic correctly skips producers that are
+// already inside the innermost loop (e.g., the fused fill that feeds a
+// rank-reducing slice), avoiding infinite loops in the fusion worklist.
module{
func.func @tile_and_fuse_with_partial_reduction_outer_parallel(
>From edd4706c42831440edc01d8f850c516770069f4a Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Tue, 6 Jan 2026 09:05:32 -0800
Subject: [PATCH 3/3] add a test for rank-reducing slices tracked in generatedSlices
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
.../tile-and-fuse-using-interface.mlir | 60 +++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 21d7816934bf9..43a6705777623 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -675,3 +675,63 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: ins(%[[TILEDARG0]]
// CHECK-SAME: outs(%[[TILEDARG1]]
// CHECK: tensor.insert_slice %[[RES:.*]]
+
+// -----
+
+// Tests that rank-reducing slices created during fusion are tracked in
+// generatedSlices (via SwapExtractSliceWithProducerPatterns.cpp).
+// This enables cleanup patterns to transform them during tile-and-fuse.
+//
+// Affected cleanup patterns:
+// - SwapExtractSliceWithFillPatterns: swaps extract_slice(fill) -> fill(extract_slice).
+
+#map2d = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @fuse_fill_through_rank_reducing_slice(%arg0: tensor<4x96xf32>) -> tensor<4x96xf32> {
+ %cst = arith.constant 1.0 : f32
+
+ %empty_3d = tensor.empty() : tensor<4x1x96xf32>
+ %fill_3d = linalg.fill ins(%cst : f32) outs(%empty_3d : tensor<4x1x96xf32>) -> tensor<4x1x96xf32>
+
+ %slice_2d = tensor.extract_slice %fill_3d[0, 0, 0] [4, 1, 96] [1, 1, 1]
+ : tensor<4x1x96xf32> to tensor<4x96xf32>
+
+ %result = linalg.generic {
+ indexing_maps = [#map2d, #map2d],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%slice_2d : tensor<4x96xf32>)
+ outs(%arg0 : tensor<4x96xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %sum = arith.addf %in, %out : f32
+ linalg.yield %sum : f32
+ } -> tensor<4x96xf32>
+
+ return %result : tensor<4x96xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %consumer = transform.structured.match ops{["linalg.generic"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %tiled, %loop = transform.structured.fuse %consumer tile_sizes [0, 32] {apply_cleanup}
+ : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @fuse_fill_through_rank_reducing_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x96xf32>
+// CHECK-DAG: %[[EMPTY_3D:.+]] = tensor.empty() : tensor<4x1x96xf32>
+// CHECK: scf.for %[[IV:[a-zA-Z0-9_]+]] = {{.*}} iter_args(%[[ITERARG:.+]] = %[[ARG0]])
+
+// CHECK: %[[TILE_3D:.+]] = tensor.extract_slice %[[EMPTY_3D]][0, 0, %[[IV]]] [4, 1, 32] [1, 1, 1]
+// CHECK-SAME: tensor<4x1x96xf32> to tensor<4x1x32xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[TILE_3D]] : tensor<4x1x32xf32>)
+// CHECK: %[[RANK_REDUCED:.+]] = tensor.extract_slice %[[FILL]][0, 0, 0] [4, 1, 32] [1, 1, 1]
+// CHECK-SAME: tensor<4x1x32xf32> to tensor<4x32xf32>
+
+// CHECK: %[[CONSUMER_OUT:.+]] = tensor.extract_slice %[[ITERARG]][0, %[[IV]]] [4, 32]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[RANK_REDUCED]] : tensor<4x32xf32>)
+// CHECK-SAME: outs(%[[CONSUMER_OUT]] : tensor<4x32xf32>)
More information about the Mlir-commits
mailing list