[Mlir-commits] [mlir] 1d883c6 - [mlir][linalg] Fix linalg.index handling in partial reduction tiling (#188261)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 24 08:54:35 PDT 2026


Author: Kunwar Grover
Date: 2026-03-24T15:54:29Z
New Revision: 1d883c675e786920d9e1c6af4fb5fcfb40b4f09a

URL: https://github.com/llvm/llvm-project/commit/1d883c675e786920d9e1c6af4fb5fcfb40b4f09a
DIFF: https://github.com/llvm/llvm-project/commit/1d883c675e786920d9e1c6af4fb5fcfb40b4f09a.diff

LOG: [mlir][linalg] Fix linalg.index handling in partial reduction tiling (#188261)

Partial reduction tiling did not offset linalg.index values by the tile offsets.
This patch fixes it by calling offsetIndices on the tiled op, matching what regular TilingInterface tiling does.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fd9c8a7a8eba7..558ebdebd65c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -634,11 +634,13 @@ struct LinalgOpPartialReductionInterface
       IRMapping mapping;
       op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                  genericOp.getRegion().begin(), mapping);
+      offsetIndices(b, genericOp, offsets);
       partialReductionOp = genericOp.getOperation();
     } else {
       SmallVector<Value> operands = std::move(tiledInputs);
       llvm::append_range(operands, tiledInits);
       partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+      offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
     }
     return TilingResult{
         {partialReductionOp},

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..e31d4f333557c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,39 @@ module {
 //      CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]]
 // CHECK-SAME:       outs(%[[ARG2]] :
 //      CHECK:   return %[[R]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+func.func @reduction_tile_with_linalg_index(%arg0: tensor<8x128xf32>, %out: tensor<8xi32>) -> tensor<8xi32> {
+  %red = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<8x128xf32>)
+      outs(%out : tensor<8xi32>) {
+    ^bb0(%in: f32, %acc: i32):
+      %idx = linalg.index 1 : index
+      %idx_i32 = arith.index_cast %idx : index to i32
+      %sum = arith.addi %idx_i32, %acc : i32
+      linalg.yield %sum : i32
+  } -> tensor<8xi32>
+  return %red : tensor<8xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[$INDEX_MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: func @reduction_tile_with_linalg_index(
+//       CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
+//       CHECK:     linalg.generic
+//       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+//       CHECK:       affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
index 62c82a15a5417..13f023ec9002b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -59,3 +59,43 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   %[[REDUCE:.+]] = linalg.reduce
 //  CHECK-SAME:       ins(%[[FORALL]] :
 //       CHECK:   return %[[REDUCE]]
+
+// -----
+
+// Check that linalg.index is correctly offset after partial reduction tiling.
+
+module {
+  func.func @partial_reduction_with_linalg_index(
+      %arg0 : tensor<8x128xf32>) -> tensor<8xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %empty = tensor.empty() : tensor<8xi32>
+    %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<8xi32>) -> tensor<8xi32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                         affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0 : tensor<8x128xf32>) outs(%fill : tensor<8xi32>) {
+      ^bb0(%b0 : f32, %b1 : i32):
+        %idx = linalg.index 1 : index
+        %idx_i32 = arith.index_cast %idx : index to i32
+        %0 = arith.addi %idx_i32, %b1 : i32
+        linalg.yield %0 : i32
+    } -> tensor<8xi32>
+    return %generic : tensor<8xi32>
+  }
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
+      %generic tile_sizes = [32]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @partial_reduction_with_linalg_index(
+//       CHECK:   scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
+//       CHECK:       affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[IV0]])[%[[LOCAL_IDX]]]


        


More information about the Mlir-commits mailing list