[Mlir-commits] [mlir] 4b59b7b - [mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (#140892)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 13 02:03:11 PDT 2025
Author: Simone Pellegrini
Date: 2025-06-13T10:03:09+01:00
New Revision: 4b59b7b94608ddbd21d14bec68400f2eb21f510d
URL: https://github.com/llvm/llvm-project/commit/4b59b7b94608ddbd21d14bec68400f2eb21f510d
DIFF: https://github.com/llvm/llvm-project/commit/4b59b7b94608ddbd21d14bec68400f2eb21f510d.diff
LOG: [mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (#140892)
When fusing two `linalg.generic` ops where the producer has index
semantics, invalid `affine.apply` ops can be generated in which the
number of indices does not match the number of loops in the fused
`linalg.generic` op.
This patch fixes the issue by taking the number of loops directly from
the generated fused op.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1f5af39e604e7..f97ed3d6d5111 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
// `consumerToProducerLoopsMap` to map the producer indices.
if (producer.hasIndexSemantics()) {
// Add an index operation for every fused loop dimension.
- unsigned numFusedOpLoops =
- std::max(producer.getNumLoops(), consumer.getNumLoops());
+ unsigned numFusedOpLoops = fusedOp.getNumLoops();
SmallVector<Value> fusedIndices;
fusedIndices.reserve(numFusedOpLoops);
llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fa..66fc55fadf8fa 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
// -----
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
+ %0 = tensor.empty() : tensor<2x2xi32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ %2 = linalg.index 1 : index
+ %3 = arith.index_cast %2 : index to i32
+ linalg.yield %3 : i32
+ } -> tensor<2x2xi32>
+ %4 = tensor.empty() : tensor<2xi32>
+ %5 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ linalg.yield %in : i32
+ } -> tensor<2xi32>
+ return %5 : tensor<2xi32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i32
+// CHECK: linalg.yield %[[CST]] : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @fold_fill_generic_basic
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NOT: linalg.fill
More information about the Mlir-commits mailing list