[Mlir-commits] [mlir] [mlir][linalg] Fix linalg.index handling in partial reduction tiling (PR #188261)
Kunwar Grover
llvmlistbot at llvm.org
Tue Mar 24 07:42:46 PDT 2026
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/188261
PartialReduction tiling wasn't handling linalg.index offsets properly. This patch fixes it by offsetting the indices of the tiled op, the same way TilingInterface does.
>From aa58715b268bfe3b7957ef1bcef425f5cdf6b3bd Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 24 Mar 2026 14:36:59 +0000
Subject: [PATCH] [mlir][linalg] Fix linalg.index handling in partial
reduction tiling
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 2 +
.../Linalg/transform-tile-reduction.mlir | 35 ++++++++++++++++
.../tile-and-fuse-with-reduction-tiling.mlir | 40 +++++++++++++++++++
3 files changed, 77 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fd9c8a7a8eba7..558ebdebd65c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -634,11 +634,13 @@ struct LinalgOpPartialReductionInterface
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
+ offsetIndices(b, genericOp, offsets);
partialReductionOp = genericOp.getOperation();
} else {
SmallVector<Value> operands = std::move(tiledInputs);
llvm::append_range(operands, tiledInits);
partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+ offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
}
return TilingResult{
{partialReductionOp},
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..7f25886778bbc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,38 @@ module {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK: return %[[R]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+func.func @reduction_tile_with_linalg_index(%arg0: tensor<8x128xf32>, %out: tensor<8xi32>) -> tensor<8xi32> {
+ %red = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<8x128xf32>)
+ outs(%out : tensor<8xi32>) {
+ ^bb0(%in: f32, %acc: i32):
+ %idx = linalg.index 1 : index
+ %idx_i32 = arith.index_cast %idx : index to i32
+ %sum = arith.addi %idx_i32, %acc : i32
+ linalg.yield %sum : i32
+ } -> tensor<8xi32>
+ return %red : tensor<8xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @reduction_tile_with_linalg_index(
+// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK: linalg.generic
+// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+// CHECK: affine.apply {{.*}}(%[[IV]])[%[[LOCAL_IDX]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 62c82a15a5417..13f023ec9002b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -59,3 +59,43 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[REDUCE:.+]] = linalg.reduce
// CHECK-SAME: ins(%[[FORALL]] :
// CHECK: return %[[REDUCE]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+module {
+ func.func @partial_reduction_with_linalg_index(
+ %arg0 : tensor<8x128xf32>) -> tensor<8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %empty = tensor.empty() : tensor<8xi32>
+ %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<8xi32>) -> tensor<8xi32>
+ %generic = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<8x128xf32>) outs(%fill : tensor<8xi32>) {
+ ^bb0(%b0 : f32, %b1 : i32):
+ %idx = linalg.index 1 : index
+ %idx_i32 = arith.index_cast %idx : index to i32
+ %0 = arith.addi %idx_i32, %b1 : i32
+ linalg.yield %0 : i32
+ } -> tensor<8xi32>
+ return %generic : tensor<8xi32>
+ }
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
+ %generic tile_sizes = [32]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @partial_reduction_with_linalg_index(
+// CHECK: scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+// CHECK: affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[IV0]])[%[[LOCAL_IDX]]]
More information about the Mlir-commits
mailing list