[Mlir-commits] [mlir] [mlir][linalg] Fix linalg.index handling in partial reduction tiling (PR #188261)

Kunwar Grover llvmlistbot at llvm.org
Tue Mar 24 08:01:29 PDT 2026


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/188261

>From aa58715b268bfe3b7957ef1bcef425f5cdf6b3bd Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 24 Mar 2026 14:36:59 +0000
Subject: [PATCH 1/2] [mlir][linalg] Fix linalg.index handling in partial
 reduction tiling

---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  2 +
 .../Linalg/transform-tile-reduction.mlir      | 35 ++++++++++++++++
 .../tile-and-fuse-with-reduction-tiling.mlir  | 40 +++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fd9c8a7a8eba7..558ebdebd65c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -634,11 +634,13 @@ struct LinalgOpPartialReductionInterface
       IRMapping mapping;
       op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                  genericOp.getRegion().begin(), mapping);
+      offsetIndices(b, genericOp, offsets);
       partialReductionOp = genericOp.getOperation();
     } else {
       SmallVector<Value> operands = std::move(tiledInputs);
       llvm::append_range(operands, tiledInits);
       partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+      offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
     }
     return TilingResult{
         {partialReductionOp},
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..7f25886778bbc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,38 @@ module {
 //      CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]]
 // CHECK-SAME:       outs(%[[ARG2]] :
 //      CHECK:   return %[[R]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+func.func @reduction_tile_with_linalg_index(%arg0: tensor<8x128xf32>, %out: tensor<8xi32>) -> tensor<8xi32> {
+  %red = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<8x128xf32>)
+      outs(%out : tensor<8xi32>) {
+    ^bb0(%in: f32, %acc: i32):
+      %idx = linalg.index 1 : index
+      %idx_i32 = arith.index_cast %idx : index to i32
+      %sum = arith.addi %idx_i32, %acc : i32
+      linalg.yield %sum : i32
+  } -> tensor<8xi32>
+  return %red : tensor<8xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @reduction_tile_with_linalg_index(
+//       CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
+//       CHECK:     linalg.generic
+//       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+//       CHECK:       affine.apply {{.*}}(%[[IV]])[%[[LOCAL_IDX]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 62c82a15a5417..13f023ec9002b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -59,3 +59,43 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[REDUCE:.+]] = linalg.reduce
 //  CHECK-SAME:       ins(%[[FORALL]] :
 //       CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+module {
+  func.func @partial_reduction_with_linalg_index(
+      %arg0 : tensor<8x128xf32>) -> tensor<8xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %empty = tensor.empty() : tensor<8xi32>
+    %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<8xi32>) -> tensor<8xi32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                         affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0 : tensor<8x128xf32>) outs(%fill : tensor<8xi32>) {
+      ^bb0(%b0 : f32, %b1 : i32):
+        %idx = linalg.index 1 : index
+        %idx_i32 = arith.index_cast %idx : index to i32
+        %0 = arith.addi %idx_i32, %b1 : i32
+        linalg.yield %0 : i32
+    } -> tensor<8xi32>
+    return %generic : tensor<8xi32>
+  }
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
+      %generic tile_sizes = [32]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @partial_reduction_with_linalg_index(
+//       CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+//       CHECK:       affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[IV0]])[%[[LOCAL_IDX]]]

>From def9aac5125ad43024c6bc9755a3aaeeff682b7b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 24 Mar 2026 15:01:06 +0000
Subject: [PATCH 2/2] match affine_map

---
 mlir/test/Dialect/Linalg/transform-tile-reduction.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 7f25886778bbc..e31d4f333557c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -722,8 +722,9 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// CHECK-DAG: #[[$INDEX_MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECK-LABEL: func @reduction_tile_with_linalg_index(
 //       CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
 //       CHECK:     linalg.generic
 //       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
-//       CHECK:       affine.apply {{.*}}(%[[IV]])[%[[LOCAL_IDX]]]
+//       CHECK:       affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]



More information about the Mlir-commits mailing list