[Mlir-commits] [mlir] [mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (PR #140892)
Simone Pellegrini
llvmlistbot at llvm.org
Wed May 21 06:06:42 PDT 2025
https://github.com/simpel01 created https://github.com/llvm/llvm-project/pull/140892
When fusing two `linalg.generic` ops where the producer has index semantics, invalid `affine.apply` ops can be generated in which the number of indices does not match the number of loops in the fused `linalg.generic` op.
This patch fixes the issue by taking the number of loops directly from the generated fused op, instead of using the maximum of the producer's and consumer's loop counts.
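The arithmetic behind the bug can be sketched with plain integers. The C++ snippet below is a simplified model, not MLIR code. `LoopCounts`, `applyIsWellFormed`, `oldNumFusedIndices` and `newNumFusedIndices` are hypothetical names. The model assumes, as in the new test, that the fused op keeps the consumer's loops and that `consumerToProducerLoopsMap` has one dimension per consumer loop. Under those assumptions, a 2-D producer and a 1-D consumer make the old `std::max` formula create two indices for a one-dimension map, while `fusedOp.getNumLoops()` creates exactly one.

```cpp
#include <algorithm>

// Hypothetical, heavily simplified model of the counts handled in
// generateFusedElementwiseOpRegion. These are plain integers, not MLIR APIs.
struct LoopCounts {
  unsigned producerLoops;  // models producer.getNumLoops()
  unsigned consumerLoops;  // models consumer.getNumLoops()
  unsigned fusedLoops;     // models fusedOp.getNumLoops()
};

// Assumption: consumerToProducerLoopsMap has one dimension per consumer loop
// and no symbols, so an affine.apply built from it is well-formed only when
// it receives exactly that many index operands.
constexpr bool applyIsWellFormed(const LoopCounts &c, unsigned numOperands) {
  return numOperands == c.consumerLoops;
}

// Number of linalg.index values created before the patch.
constexpr unsigned oldNumFusedIndices(const LoopCounts &c) {
  return std::max(c.producerLoops, c.consumerLoops);
}

// Number of linalg.index values created after the patch.
constexpr unsigned newNumFusedIndices(const LoopCounts &c) {
  return c.fusedLoops;
}

// Loop counts from @fusion_different_axes_indexed: a 2-D producer, a 1-D
// consumer, and a fused op that keeps the consumer's single loop.
constexpr LoopCounts kTest{2, 1, 1};

static_assert(!applyIsWellFormed(kTest, oldNumFusedIndices(kTest)),
              "old formula passes 2 indices to a 1-dim map");
static_assert(applyIsWellFormed(kTest, newNumFusedIndices(kTest)),
              "new formula matches the map's dimension count");
```

The checks run at compile time through `static_assert`, so the snippet only needs to compile (C++14 or later, where `std::max` is `constexpr`).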
From cbd670e408b75daa952b525d811ea0691f0e9897 Mon Sep 17 00:00:00 2001
From: Simone Pellegrini <simone.pellegrini at arm.com>
Date: Wed, 14 May 2025 10:15:21 +0200
Subject: [PATCH] [mlir][Linalg] Fix fusing of indexed linalg consumer with
different axes
When fusing two `linalg.generic` ops where the producer has index
semantics, invalid `affine.apply` ops can be generated in which the
number of indices does not match the number of loops in the fused op.
This patch fixes the issue by taking the number of loops directly from
the generated fused op.
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 3 +-
.../Linalg/fusion-elementwise-ops.mlir | 40 +++++++++++++++++++
2 files changed, 41 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1f5af39e604e7..f97ed3d6d5111 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fa..031cb350bfba4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -860,6 +860,46 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
// -----
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<4xi32> {
+ %0 = tensor.empty() : tensor<2x2xi32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ %2 = linalg.index 1 : index
+ %3 = arith.index_cast %2 : index to i32
+ linalg.yield %3 : i32
+ } -> tensor<2x2xi32>
+ %4 = tensor.empty() : tensor<4xi32>
+ %5 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0 floordiv 2, d0 mod 2)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<4xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ linalg.yield %in : i32
+ } -> tensor<4xi32>
+ return %5 : tensor<4xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 mod 2)>
+// CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<4xi32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]]]
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i32
+// CHECK-DAG: %[[T0:.+]] = linalg.index 0
+// CHECK-DAG: %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[T0]]]
+// CHECK-DAG: %[[CAST:.+]] = arith.index_cast %[[T1]] : index to i32
+// CHECK: linalg.yield %[[CAST]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @fold_fill_generic_basic
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NOT: linalg.fill
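As a sanity check on the new test, the C++17 sketch below evaluates `@fusion_different_axes_indexed` by hand. It is not part of the patch, and `evalUnfused`, `evalFused` and `sameResult` are hypothetical names. The first function runs the two unfused `linalg.generic` ops as plain loops. The producer writes `linalg.index 1` (the column index) and the consumer reads `%1[d0 floordiv 2, d0 mod 2]`. The second function runs the single fused body that the CHECK lines expect, `linalg.index 0` followed by `affine.apply` with `(s0 mod 2)`. Both should yield `[0, 1, 0, 1]`.

```cpp
#include <array>

// Unfused evaluation: producer p[i][j] = j (the value of linalg.index 1),
// then consumer out[d0] = p[d0 floordiv 2][d0 mod 2].
constexpr std::array<int, 4> evalUnfused() {
  int p[2][2] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      p[i][j] = j;  // linalg.index 1 cast to i32
  std::array<int, 4> out{};
  for (int d0 = 0; d0 < 4; ++d0)
    out[d0] = p[d0 / 2][d0 % 2];  // floordiv/mod agree with / and % for d0 >= 0
  return out;
}

// Fused evaluation matching the CHECK lines: one loop whose body applies
// (s0 mod 2) to linalg.index 0 and casts the result to i32.
constexpr std::array<int, 4> evalFused() {
  std::array<int, 4> out{};
  for (int d0 = 0; d0 < 4; ++d0)
    out[d0] = d0 % 2;
  return out;
}

// Element-wise comparison (std::array::operator== is only constexpr in C++20).
constexpr bool sameResult() {
  auto a = evalUnfused();
  auto b = evalFused();
  for (int k = 0; k < 4; ++k)
    if (a[k] != b[k])
      return false;
  return true;
}

static_assert(sameResult(), "fusion must preserve the computed values");
static_assert(evalFused()[0] == 0 && evalFused()[1] == 1 &&
                  evalFused()[2] == 0 && evalFused()[3] == 1,
              "expected [0, 1, 0, 1]");
```

C++17 is needed because the non-const `std::array::operator[]` only became `constexpr` in that standard.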