[Mlir-commits] [mlir] [mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (PR #140892)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 21 06:07:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Simone Pellegrini (simpel01)

<details>
<summary>Changes</summary>

When fusing two `linalg.generic` ops where the producer has index semantics, invalid `affine.apply` ops can be generated in which the number of indices does not match the number of loops in the fused `linalg.generic` op.

This patch fixes the issue by taking the number of loops directly from the generated fused op, instead of using the maximum of the producer's and consumer's loop counts.

---
Full diff: https://github.com/llvm/llvm-project/pull/140892.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-2) 
- (modified) mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (+40) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1f5af39e604e7..f97ed3d6d5111 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fa..031cb350bfba4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -860,6 +860,46 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 
 // -----
 
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) ->  tensor<4xi32> {
+  %0 = tensor.empty() : tensor<2x2xi32>
+  %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+          ^bb0(%in: i32, %out: i32):
+            %2 = linalg.index 1 : index
+            %3 = arith.index_cast %2 : index to i32
+            linalg.yield %3 : i32
+        } -> tensor<2x2xi32>
+  %4 = tensor.empty() : tensor<4xi32>
+  %5 = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0 floordiv 2, d0 mod 2)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<4xi32>) {
+          ^bb0(%in: i32, %out: i32):
+            linalg.yield %in : i32
+        } -> tensor<4xi32>
+  return %5 : tensor<4xi32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 mod 2)>
+//      CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x2xi32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty() : tensor<4xi32>
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]]]
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:.+]]: i32
+//  CHECK-DAG:     %[[T0:.+]] = linalg.index 0
+//  CHECK-DAG:     %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[T0]]]
+//  CHECK-DAG:     %[[CAST:.+]] = arith.index_cast %[[T1]] : index to i32
+//      CHECK:     linalg.yield %[[CAST]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_basic
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 //   CHECK-NOT: linalg.fill

``````````

</details>


https://github.com/llvm/llvm-project/pull/140892


More information about the Mlir-commits mailing list