[Mlir-commits] [mlir] [mlir][linalg][elementwise] Fold broadcast into new elementwise (PR #167626)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 18:07:32 PDT 2026


https://github.com/someoneinjd updated https://github.com/llvm/llvm-project/pull/167626

>From eb33115245da21c509b6f43b9c9124529afdbf62 Mon Sep 17 00:00:00 2001
From: someoneinjd <someoneinjd at outlook.com>
Date: Wed, 12 Nov 2025 11:02:26 +0800
Subject: [PATCH] [mlir][linalg][elementwise] Fold broadcast into new
 elementwise

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  11 +-
 .../Linalg/Transforms/FoldIntoElementwise.cpp |  44 ++--
 .../test/Dialect/Linalg/elementwise/fold.mlir | 218 +++++++++++++++++-
 3 files changed, 248 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 26638b2a644c4..b873f260e7d92 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -169,14 +169,15 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
 }
 
 def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
-  let summary = "Fold transpose ops into elementwise";
+  let summary = "Fold transpose and broadcast ops into elementwise";
   let dependentDialects = ["linalg::LinalgDialect"];
 
   let description = [{
-    Fold transpose ops that feed `linalg.elementwise` into the elementwise op
-    by updating its indexing maps. `linalg.transpose` producers whose consumer
-    indexing map is the identity are absorbed, turning the permutation into
-    the elementwise map itself. Other operands remain untouched.
+    Fold `linalg.transpose` and `linalg.broadcast` ops that feed a
+    `linalg.elementwise` op into it. When the elementwise indexing map of such
+    an operand is a projected permutation, the operand is replaced by the
+    producer's input and its map by the producer's input indexing map composed
+    with the original one. Other operands remain untouched.
   }];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
index b1c0c3b161b20..0be128c3b5e87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -29,7 +29,27 @@ using namespace mlir::linalg;
 #define DEBUG_TYPE "linalg-fold-into-elementwise"
 
 namespace {
-struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
+template <typename ProducerOpTy>
+struct ElementwiseOpFolder {
+  // If `elwiseOperand` is defined by a `ProducerOpTy` and `elwiseMap` is a
+  // projected permutation, appends the producer's input and its input map
+  // composed with `elwiseMap`, returning true. Otherwise appends nothing.
+  static bool fold(OpOperand *elwiseOperand, AffineMap elwiseMap,
+                   SmallVectorImpl<Value> &newIns,
+                   SmallVectorImpl<AffineMap> &newMaps) {
+    auto producerOp = elwiseOperand->get().getDefiningOp<ProducerOpTy>();
+    if (!producerOp || !elwiseMap.isProjectedPermutation())
+      return false;
+    newIns.push_back(producerOp.getInput());
+    newMaps.push_back(
+        producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
+            .compose(elwiseMap));
+    return true;
+  }
+};
+
+template <typename... ProducerOps>
+struct FoldIntoElementwisePattern : public OpRewritePattern<ElementwiseOp> {
   using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ElementwiseOp op,
@@ -38,20 +58,17 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
     SmallVector<Value> newIns;
     SmallVector<AffineMap> newMaps;
     for (OpOperand *operand : op.getDpsInputOperands()) {
-      AffineMap map = op.getMatchingIndexingMap(operand);
-      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
-
-      if (!map.isIdentity() || !transposeOp) {
+      AffineMap consumerMap = op.getMatchingIndexingMap(operand);
+      const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
+                               operand, consumerMap, newIns, newMaps) ||
+                           ...);
+      if (folded) {
+        changed = true;
+      } else {
         // push in original operand and its map.
         newIns.push_back(operand->get());
-        newMaps.push_back(map);
-        continue;
+        newMaps.push_back(consumerMap);
       }
-      newIns.push_back(transposeOp.getInput());
-      // push in transposeOp's inverse permutation map.
-      newMaps.push_back(transposeOp.getMatchingIndexingMap(
-          transposeOp.getDpsInputOperand(0)));
-      changed = true;
     }
     if (!changed)
       return failure();
@@ -83,5 +100,6 @@ struct LinalgFoldIntoElementwisePass
 
 void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldTransposePattern>(patterns.getContext());
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(context);
 }
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
index e83c32fb6a2cf..80fd90f3d4dbe 100644
--- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -9,11 +9,11 @@
 // CHECK-SAME:       ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
 // CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
 //
-func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
   %empty = tensor.empty() : tensor<8x16x32xf32>
-  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty :  tensor<8x16x32xf32>) permutation = [1, 0, 2]
+  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
   %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
-                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
   return %result : tensor<8x16x32xf32>
 }
 
@@ -28,16 +28,220 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
 // CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
 //
-func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) ->  tensor<?x?xf32> {
+func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
   %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
 
   %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
-  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty :  tensor<?x?xf32>) permutation = [1, 0]
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
   %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
-                          ins(%A, %transposed_B : tensor<?x?xf32>,  tensor<?x?xf32>)
-                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+                          ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+                          outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// Broadcast of %A along dim 1 is absorbed into the elementwise input map.
+// CHECK:  func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %init = tensor.empty() : tensor<8x16x32xf32>
+  %a_bcast = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%init : tensor<8x16x32xf32>) dimensions = [1]
+  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%a_bcast : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %exp : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
+// The broadcast init takes %A's shape (dim0 x dim1) so all operands agree.
+// CHECK:  func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
+                          outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+// Broadcast map (d0, d1) -> (d1) composed with the transposed use gives (d0).
+// CHECK:  func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:              indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<16x32xf32>
+//
+#ident = affine_map<(d0, d1) -> (d0, d1)>
+#swap = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %init = tensor.empty() : tensor<32x16xf32>
+
+  %a_bcast = linalg.broadcast ins(%A : tensor<16xf32>) outs(%init : tensor<32x16xf32>) dimensions = [0]
+
+  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          indexing_maps = [#swap, #ident]
+                          ins(%a_bcast : tensor<32x16xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %exp : tensor<16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// Transpose map (d0, d1) -> (d1, d0) composed with the broadcasting use gives (d2, d1).
+// CHECK:  func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:              indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<8x16x32xf32>
+//
+#ident = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#bcast = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %init = tensor.empty() : tensor<16x32xf32>
+  %a_t = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%init : tensor<16x32xf32>) permutation = [1, 0]
+
+  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          indexing_maps = [#bcast, #ident]
+                          ins(%a_t : tensor<16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %exp : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+// Binary variant: only the broadcast operand is folded; %B keeps its map.
+// CHECK:  func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+#ident = affine_map<(d0, d1) -> (d0, d1)>
+#swap = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %rows = tensor.dim %B, %c0 : tensor<?x?xf32>
+  %cols = tensor.dim %B, %c1 : tensor<?x?xf32>
+
+  %init = tensor.empty(%cols, %rows) : tensor<?x?xf32>
+  %a_bcast = linalg.broadcast ins(%A : tensor<?xf32>) outs(%init : tensor<?x?xf32>) dimensions = [0]
+
+  %sum = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          indexing_maps = [#swap, #ident, #ident]
+                          ins(%a_bcast, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  return %sum : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// Binary variant: only the transpose operand is folded; %B keeps its map.
+// CHECK:  func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?x?xf32>
+//
+#ident = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#bcast = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %size0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
+  %size1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
+  %size2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
+
+  %init = tensor.empty(%size1, %size2) : tensor<?x?xf32>
+  %a_t = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) permutation = [1, 0]
+
+  %sum = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          indexing_maps = [#bcast, #ident, #ident]
+                          ins(%a_t, %B : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %sum : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[DIAGONAL:.+]] = affine_map<(d0) -> (d0, d0)>
+// Not folded: the diagonal map (d0) -> (d0, d0) is not a projected permutation.
+// CHECK:  func.func @fold_failed_diagonal_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
+// CHECK-NEXT:  %[[EMPTY:.+]] = tensor.empty() : tensor<16x16xf32>
+// CHECK-NEXT:  %[[BROADCASTED_B:.+]] = linalg.broadcast ins(%[[B]] : tensor<16xf32>) outs(%[[EMPTY]] : tensor<16x16xf32>) dimensions = [0]
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[DIAGONAL]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[BROADCASTED_B]] : tensor<16xf32>, tensor<16x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<16xf32>
+//
+#ident = affine_map<(d0) -> (d0)>
+#diag = affine_map<(d0) -> (d0, d0)>
+
+func.func @fold_failed_diagonal_map(%A: tensor<16xf32>, %B: tensor<16xf32>, %C: tensor<16xf32>) -> tensor<16xf32> {
+  %init = tensor.empty() : tensor<16x16xf32>
+  %b_bcast = linalg.broadcast ins(%B : tensor<16xf32>) outs(%init : tensor<16x16xf32>) dimensions = [0]
+  %sum = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          indexing_maps = [#ident, #diag, #ident]
+                          ins(%A, %b_bcast : tensor<16xf32>, tensor<16x16xf32>) outs(%C : tensor<16xf32>) -> tensor<16xf32>
+  return %sum : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[CONSTANT:.+]] = affine_map<(d0) -> (0, d0)>
+// Not folded: the constant-result map (d0) -> (0, d0) is not a projected permutation.
+// CHECK:  func.func @fold_failed_constant_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
+// CHECK-NEXT:  %[[EMPTY:.+]] = tensor.empty() : tensor<32x16xf32>
+// CHECK-NEXT:  %[[TRANSPOSED_B:.+]] = linalg.transpose ins(%[[B]] : tensor<16x32xf32>) outs(%[[EMPTY]] : tensor<32x16xf32>) permutation = [1, 0]
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[CONSTANT]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[TRANSPOSED_B]] : tensor<16xf32>, tensor<32x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<16xf32>
+//
+#ident = affine_map<(d0) -> (d0)>
+#pinned = affine_map<(d0) -> (0, d0)>
+
+func.func @fold_failed_constant_map(%A: tensor<16xf32>, %B: tensor<16x32xf32>, %C: tensor<16xf32>) -> tensor<16xf32> {
+  %init = tensor.empty() : tensor<32x16xf32>
+  %b_t = linalg.transpose ins(%B : tensor<16x32xf32>) outs(%init : tensor<32x16xf32>) permutation = [1, 0]
+  %sum = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          indexing_maps = [#ident, #pinned, #ident]
+                          ins(%A, %b_t : tensor<16xf32>, tensor<32x16xf32>) outs(%C : tensor<16xf32>) -> tensor<16xf32>
+  return %sum : tensor<16xf32>
+}



More information about the Mlir-commits mailing list