[Mlir-commits] [mlir] 1d883c6 - [mlir][linalg] Fix linalg.index handeling in partial reduction tiling (#188261)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 24 08:54:35 PDT 2026
Author: Kunwar Grover
Date: 2026-03-24T15:54:29Z
New Revision: 1d883c675e786920d9e1c6af4fb5fcfb40b4f09a
URL: https://github.com/llvm/llvm-project/commit/1d883c675e786920d9e1c6af4fb5fcfb40b4f09a
DIFF: https://github.com/llvm/llvm-project/commit/1d883c675e786920d9e1c6af4fb5fcfb40b4f09a.diff
LOG: [mlir][linalg] Fix linalg.index handeling in partial reduction tiling (#188261)
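Partial reduction tiling was not handling `linalg.index` offsets: the tiled op's `linalg.index` results were not shifted by the tile offsets, so they referred to positions within the tile rather than within the original iteration space.
This patch offsets `linalg.index` ops in the partial-reduction tiled op, matching what the regular TilingInterface tiling implementation does.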
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fd9c8a7a8eba7..558ebdebd65c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -634,11 +634,13 @@ struct LinalgOpPartialReductionInterface
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
+ offsetIndices(b, genericOp, offsets);
partialReductionOp = genericOp.getOperation();
} else {
SmallVector<Value> operands = std::move(tiledInputs);
llvm::append_range(operands, tiledInits);
partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+ offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
}
return TilingResult{
{partialReductionOp},
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..e31d4f333557c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,39 @@ module {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK: return %[[R]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+func.func @reduction_tile_with_linalg_index(%arg0: tensor<8x128xf32>, %out: tensor<8xi32>) -> tensor<8xi32> {
+ %red = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<8x128xf32>)
+ outs(%out : tensor<8xi32>) {
+ ^bb0(%in: f32, %acc: i32):
+ %idx = linalg.index 1 : index
+ %idx_i32 = arith.index_cast %idx : index to i32
+ %sum = arith.addi %idx_i32, %acc : i32
+ linalg.yield %sum : i32
+ } -> tensor<8xi32>
+ return %red : tensor<8xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[$INDEX_MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: func @reduction_tile_with_linalg_index(
+// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK: linalg.generic
+// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+// CHECK: affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 62c82a15a5417..13f023ec9002b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -59,3 +59,43 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[REDUCE:.+]] = linalg.reduce
// CHECK-SAME: ins(%[[FORALL]] :
// CHECK: return %[[REDUCE]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+module {
+ func.func @partial_reduction_with_linalg_index(
+ %arg0 : tensor<8x128xf32>) -> tensor<8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %empty = tensor.empty() : tensor<8xi32>
+ %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<8xi32>) -> tensor<8xi32>
+ %generic = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<8x128xf32>) outs(%fill : tensor<8xi32>) {
+ ^bb0(%b0 : f32, %b1 : i32):
+ %idx = linalg.index 1 : index
+ %idx_i32 = arith.index_cast %idx : index to i32
+ %0 = arith.addi %idx_i32, %b1 : i32
+ linalg.yield %0 : i32
+ } -> tensor<8xi32>
+ return %generic : tensor<8xi32>
+ }
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
+ %generic tile_sizes = [32]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @partial_reduction_with_linalg_index(
+// CHECK: scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+// CHECK: affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[IV0]])[%[[LOCAL_IDX]]]
More information about the Mlir-commits
mailing list