[Mlir-commits] [mlir] [mlir] [linalg] fix failure on specializing matmul with permuted loops (PR #184294)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 5 09:18:59 PST 2026


https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/184294

>From 0a9bbdbbd2f227be82d53fb8fa7384e091f5d316 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Mon, 2 Mar 2026 20:37:47 +0000
Subject: [PATCH 1/5] fix(linalg-specialize): support reordered loops

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 42 +++++++++++----
 .../Linalg/specialize-generic-ops.mlir        | 51 +++++++++++++++++++
 2 files changed, 84 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d74335e3c08c9..1f2acca0230e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -139,7 +139,8 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 // attribute is needed for the named matmul op variant.
 template <typename NamedOpTy>
 static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
-                                         std::optional<TypeFn> castTy) {
+                                         std::optional<TypeFn> castTy,
+                                         ArrayRef<AffineMap> indexingMaps) {
   SmallVector<NamedAttribute> castAttrVec;
   // Only explicitly specify the cast attribute for unsigned cast; signed is
   // the default for linalg.matmul/linalg.batch_matmul.
@@ -147,14 +148,11 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
     castAttrVec = {rewriter.getNamedAttr(
         "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
 
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-
-  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+  auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
       ValueRange{op.getDpsInits()[0]}, castAttrVec);
 
-  // Set the original generic's maps to preserve transposed operand semantics.
-  namedOp->setAttr("indexing_maps", indexingMaps);
+  namedOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(indexingMaps));
 
   return namedOp;
 }
@@ -299,11 +297,37 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(
         genericOp, "contains invalid cast ops for the named matmul op");
 
-  /// Codegen the different matmul variants.
+  // Build the named op's indexing maps in canonical (batch..., m, n, k) order.
+  auto *ctx = genericOp.getContext();
+  unsigned numLoopDims = numOfBatchDims + 3;
+  unsigned mIdx = numOfBatchDims;
+  unsigned nIdx = mIdx + 1;
+  unsigned kIdx = mIdx + 2;
+
+  auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
+    SmallVector<unsigned> tensorDims;
+    for (unsigned i = 0; i < numOfBatchDims; ++i)
+      tensorDims.push_back(i);
+    if (match == IndexMatchResult::Transposed)
+      llvm::append_values(tensorDims, colIdx, rowIdx);
+    else
+      llvm::append_values(tensorDims, rowIdx, colIdx);
+    return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
+  };
+
+  auto mapA = makeMap(a, mIdx, kIdx);
+  auto mapB = makeMap(b, kIdx, nIdx);
+  auto mapC = makeMap(c, mIdx, nIdx);
+
+  SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
+
+  // Codegen the different matmul variants.
   if (numOfBatchDims) {
-    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
+                                                   namedOpMaps);
   }
-  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
+                                            namedOpMaps);
 }
 
 /// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87218844c5c39..aaa47dd2a47ca 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -530,3 +530,54 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_TC]]]
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Matmul with non-conventional loop ordering (d0=m, d1=k, d2=n).
+#map_nc_a = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nc_b = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map_nc_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                                            %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map_nc_a, #map_nc_b, #map_nc_c],
+          iterator_types = ["parallel", "reduction", "parallel"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_non_conventional_dims
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-conventional loop ordering (d0=batch, d1=m, d2=k, d3=n).
+#map_bnc_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map_bnc_b = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map_bnc_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
+                                                  %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#map_bnc_a, #map_bnc_b, #map_bnc_c],
+            iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+           ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: op_batch_matmul_non_conventional_dims
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+

>From 0cf797976d56a3c8633fb7870f4aa88c15b35607 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Mon, 2 Mar 2026 21:25:53 +0000
Subject: [PATCH 2/5] test: use named loop dims (m, k, n) in non-canonical loop tests

---
 .../Dialect/Linalg/specialize-generic-ops.mlir   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index aaa47dd2a47ca..bd052c3f98097 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -533,10 +533,10 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 
 // -----
 
-// Matmul with non-conventional loop ordering (d0=m, d1=k, d2=n).
-#map_nc_a = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map_nc_b = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map_nc_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+// Matmul with non-canonical loop ordering.
+#map_nc_a = affine_map<(m, k, n) -> (m, k)>
+#map_nc_b = affine_map<(m, k, n) -> (k, n)>
+#map_nc_c = affine_map<(m, k, n) -> (m, n)>
 func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
                                             %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic
@@ -558,10 +558,10 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
 
 // -----
 
-// Batch matmul with non-conventional loop ordering (d0=batch, d1=m, d2=k, d3=n).
-#map_bnc_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-#map_bnc_b = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map_bnc_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// Batch matmul with non-canonical loop ordering.
+#map_bnc_a = affine_map<(batch, m, k, n) -> (batch, m, k)>
+#map_bnc_b = affine_map<(batch, m, k, n) -> (batch, k, n)>
+#map_bnc_c = affine_map<(batch, m, k, n) -> (batch, m, n)>
 func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
                                                   %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
   %0 = linalg.generic

>From 7cfbdcab24ec1f25911a0cade2078ae76631e94f Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Tue, 3 Mar 2026 07:34:22 +0000
Subject: [PATCH 3/5] test: rename *_non_conventional_dims tests to *_non_canonical_loops

---
 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index bd052c3f98097..8058d56607136 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -537,7 +537,7 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 #map_nc_a = affine_map<(m, k, n) -> (m, k)>
 #map_nc_b = affine_map<(m, k, n) -> (k, n)>
 #map_nc_c = affine_map<(m, k, n) -> (m, n)>
-func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+func.func @op_matmul_non_canonical_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
                                             %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic
          {indexing_maps = [#map_nc_a, #map_nc_b, #map_nc_c],
@@ -551,7 +551,7 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
    return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: op_matmul_non_conventional_dims
+// CHECK-LABEL: op_matmul_non_canonical_loops
 // CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -562,7 +562,7 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
 #map_bnc_a = affine_map<(batch, m, k, n) -> (batch, m, k)>
 #map_bnc_b = affine_map<(batch, m, k, n) -> (batch, k, n)>
 #map_bnc_c = affine_map<(batch, m, k, n) -> (batch, m, n)>
-func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
+func.func @op_batch_matmul_non_canonical_loops(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
                                                   %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
   %0 = linalg.generic
            {indexing_maps = [#map_bnc_a, #map_bnc_b, #map_bnc_c],
@@ -576,7 +576,7 @@ func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: ten
   return %0 : tensor<2x16x16xf32>
 }
 
-// CHECK-LABEL: op_batch_matmul_non_conventional_dims
+// CHECK-LABEL: op_batch_matmul_non_canonical_loops
 // CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>

>From e8f7259e6851800f6fe97fa33a475786653f542d Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Thu, 5 Mar 2026 08:03:41 +0000
Subject: [PATCH 4/5] test: add transposed-B, fully permuted and broadcast (negative) cases

---
 .../Linalg/specialize-generic-ops.mlir        | 159 ++++++++++++++++++
 1 file changed, 159 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 8058d56607136..5c58a5fedd639 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -581,3 +581,162 @@ func.func @op_batch_matmul_non_canonical_loops(%A: tensor<2x16x8xf32>, %B: tenso
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 
+// -----
+
+// Matmul with non-canonical loop ordering (d0=m, d1=k, d2=n) and B transposed.
+#map_nc_tb_a = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nc_tb_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_nc_tb_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @op_matmul_non_canonical_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                                                %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map_nc_tb_a, #map_nc_tb_b, #map_nc_tb_c],
+          iterator_types = ["parallel", "reduction", "parallel"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-canonical loop ordering (d0=batch, d1=m, d2=k, d3=n)
+// and B transposed.
+#map_bnc_tb_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map_bnc_tb_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bnc_tb_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+func.func @op_batch_matmul_non_canonical_transpose_b(%A: tensor<2x16x8xf32>, %B: tensor<2x16x8xf32>,
+                                                      %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#map_bnc_tb_a, #map_bnc_tb_b, #map_bnc_tb_c],
+            iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+           ins(%A, %B : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// Matmul with fully permuted loop ordering (d0=k, d1=m, d2=n).
+#map_fs_a = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map_fs_b = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_fs_c = affine_map<(d0, d1, d2) -> (d1, d2)>
+func.func @op_matmul_fully_shuffled_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                                           %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map_fs_a, #map_fs_b, #map_fs_c],
+          iterator_types = ["reduction", "parallel", "parallel"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_fully_shuffled_loops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// TODO: A is broadcast over m (indexed only by k); specialize to linalg.matmul.
+#map_bcast_a = affine_map<(d0, d1, d2) -> (d2)>
+#map_bcast_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_bcast_c = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_broadcast_a(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+                                        %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map_bcast_a, #map_bcast_b, #map_bcast_c],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+// TODO: A is broadcast over the batch dim; specialize to linalg.batch_matmul.
+#map_bbcast_a = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+#map_bbcast_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bbcast_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_a(%A: tensor<16x8xf32>, %B: tensor<2x8x16xf32>,
+                                              %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#map_bbcast_a, #map_bbcast_b, #map_bbcast_c],
+            iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+           ins(%A, %B : tensor<16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul
+
+// -----
+
+// TODO: B is indexed only by k (broadcast over batch and n); specialize to linalg.batch_matmul.
+#map_bbcast2_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map_bbcast2_b = affine_map<(d0, d1, d2, d3) -> (d3)>
+#map_bbcast2_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>,
+                                              %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#map_bbcast2_a, #map_bbcast2_b, #map_bbcast2_c],
+            iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+           ins(%A, %B : tensor<2x16x8xf32>, tensor<8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_b
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul

>From 21dc0451995fa50f707a88b781d714acc9af5c05 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Thu, 5 Mar 2026 17:18:44 +0000
Subject: [PATCH 5/5] add TODO for broadcast indexing_maps support in matmul specialization

---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 1f2acca0230e2..b4de2bb1e1169 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -304,6 +304,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   unsigned nIdx = mIdx + 1;
   unsigned kIdx = mIdx + 2;
 
+  // TODO: add support for indexing_maps with broadcasts.
   auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
     SmallVector<unsigned> tensorDims;
     for (unsigned i = 0; i < numOfBatchDims; ++i)



More information about the Mlir-commits mailing list