[Mlir-commits] [mlir] [mlir][linalg][elementwise] Fold broadcast into new elementwise (PR #167626)
llvmlistbot at llvm.org
Sun Nov 30 17:50:08 PST 2025
https://github.com/someoneinjd updated https://github.com/llvm/llvm-project/pull/167626
From 92ebdad8813c4b1b2c3d2e322ee8d735693fb51d Mon Sep 17 00:00:00 2001
From: someoneinjd <someoneinjd at outlook.com>
Date: Wed, 12 Nov 2025 11:02:26 +0800
Subject: [PATCH] [mlir][linalg][elementwise] Fold broadcast into new
elementwise
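This teaches the fold-into-elementwise patterns to absorb a producing
linalg.broadcast (in addition to linalg.transpose) by composing the
producer's input indexing map with the consumer's map. As a minimal
before/after sketch of the new fold, mirroring the unary_broadcasted test
added below (value names are illustrative, not part of the patch):

  // Before: the broadcast materializes an 8x16x32 intermediate tensor.
  %bcast = linalg.broadcast ins(%A : tensor<8x32xf32>)
             outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
  %res = linalg.elementwise kind=#linalg.elementwise_kind<exp>
           ins(%bcast : tensor<8x16x32xf32>)
           outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>

  // After: the broadcast becomes a broadcasting indexing map on the
  // elementwise input, and the intermediate tensor disappears.
  %res = linalg.elementwise kind=#linalg.elementwise_kind<exp>
           indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                            affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
           ins(%A : tensor<8x32xf32>)
           outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>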
---
.../Linalg/Transforms/FoldIntoElementwise.cpp | 42 +++--
.../test/Dialect/Linalg/elementwise/fold.mlir | 162 +++++++++++++++++-
2 files changed, 184 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
index b1c0c3b161b20..d023b7af9a33b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -29,7 +29,25 @@ using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-fold-into-elementwise"
namespace {
-struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
+template <typename ProducerOpTy>
+struct ElementwiseOpFolder {
+ static bool fold(OpOperand *operand, AffineMap consumerMap,
+ SmallVectorImpl<Value> &newIns,
+ SmallVectorImpl<AffineMap> &newMaps) {
+ auto producerOp = operand->get().getDefiningOp<ProducerOpTy>();
+ if (!producerOp)
+ return false;
+ newIns.push_back(producerOp.getInput());
+ // push in the producer's input map composed with the consumer's map.
+ newMaps.push_back(
+ producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
+ .compose(consumerMap));
+ return true;
+ }
+};
+
+template <typename... ProducerOps>
+struct FoldIntoElementwisePattern : public OpRewritePattern<ElementwiseOp> {
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ElementwiseOp op,
@@ -38,20 +56,17 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
SmallVector<Value> newIns;
SmallVector<AffineMap> newMaps;
for (OpOperand *operand : op.getDpsInputOperands()) {
- AffineMap map = op.getMatchingIndexingMap(operand);
- auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
-
- if (!map.isIdentity() || !transposeOp) {
+ AffineMap consumerMap = op.getMatchingIndexingMap(operand);
+ const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
+ operand, consumerMap, newIns, newMaps) ||
+ ...);
+ if (folded) {
+ changed = true;
+ } else {
// push in original operand and its map.
newIns.push_back(operand->get());
- newMaps.push_back(map);
- continue;
+ newMaps.push_back(consumerMap);
}
- newIns.push_back(transposeOp.getInput());
- // push in transposeOp's inverse permutation map.
- newMaps.push_back(transposeOp.getMatchingIndexingMap(
- transposeOp.getDpsInputOperand(0)));
- changed = true;
}
if (!changed)
return failure();
@@ -83,5 +98,6 @@ struct LinalgFoldIntoElementwisePass
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldTransposePattern>(patterns.getContext());
+ patterns.add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
index e83c32fb6a2cf..732b8a90f51d2 100644
--- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -9,11 +9,11 @@
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
-func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty = tensor.empty() : tensor<8x16x32xf32>
- %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
+ %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
- ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}
@@ -28,16 +28,164 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
-func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
- %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+ %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
- ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//
+// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %empty = tensor.empty() : tensor<8x16x32xf32>
+ %broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
+//
+func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %empty_b = tensor.empty() : tensor<32x16xf32>
+ %broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0]
+
+ %empty_t = tensor.empty() : tensor<16x32xf32>
+ %transposed_A = linalg.transpose ins(%broadcasted_A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]
+
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %result : tensor<16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %empty_t = tensor.empty() : tensor<16x32xf32>
+ %transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]
+
+ %empty_b = tensor.empty() : tensor<8x16x32xf32>
+ %broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<16x32xf32>) outs(%empty_b : tensor<8x16x32xf32>) dimensions = [0]
+
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %B, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+
+ %empty_b = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+ %broadcasted_A = linalg.broadcast ins(%A : tensor<?xf32>) outs(%empty_b : tensor<?x?xf32>) dimensions = [0]
+
+ %empty_t = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %transposed_A = linalg.transpose ins(%broadcasted_A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]
+
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%transposed_A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
+//
+func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
+
+ %empty_t = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %transposed_A = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]
+
+ %empty_b = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
+ %broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<?x?xf32>) outs(%empty_b : tensor<?x?x?xf32>) dimensions = [0]
+
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%broadcasted_A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %result : tensor<?x?x?xf32>
+}