[Mlir-commits] [mlir] [mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (PR #140892)

Simone Pellegrini llvmlistbot at llvm.org
Thu Jun 12 22:54:51 PDT 2025


https://github.com/simpel01 updated https://github.com/llvm/llvm-project/pull/140892

>From 7a6a8f061e74c2795ff8c18e2aac98f18d0c2607 Mon Sep 17 00:00:00 2001
From: Simone Pellegrini <simone.pellegrini at arm.com>
Date: Wed, 14 May 2025 10:15:21 +0200
Subject: [PATCH] [mlir][Linalg] Fix fusing of indexed linalg consumer with
 different axes

When fusing two `linalg.generic` ops where the producer has index
semantics, invalid `affine.apply` ops can be generated in which the number
of indices does not match the number of loops in the fused generic op.

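Previously the number of fused loops was computed as the maximum of the
producer's and the consumer's loop counts. This overestimates it whenever
the producer has more loops than the fused op, e.g. a 2-D producer whose
result is read by a 1-D consumer with one axis fixed to a constant.
This patch fixes the issue by directly using the number of loops from the
generated fused op.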
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  3 +-
 .../Linalg/fusion-elementwise-ops.mlir        | 37 +++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1f5af39e604e7..f97ed3d6d5111 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fa..66fc55fadf8fa 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 
 // -----
 
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) ->  tensor<2xi32> {
+  %0 = tensor.empty() : tensor<2x2xi32>
+  %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+          ^bb0(%in: i32, %out: i32):
+            %2 = linalg.index 1 : index
+            %3 = arith.index_cast %2 : index to i32
+            linalg.yield %3 : i32
+        } -> tensor<2x2xi32>
+  %4 = tensor.empty() : tensor<2xi32>
+  %5 = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
+          ^bb0(%in: i32, %out: i32):
+            linalg.yield %in : i32
+        } -> tensor<2xi32>
+  return %5 : tensor<2xi32>
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+//      CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x2xi32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 1 : i32
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP]]]
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:.+]]: i32
+//      CHECK:     linalg.yield %[[CST]] : i32
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_basic
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 //   CHECK-NOT: linalg.fill



More information about the Mlir-commits mailing list