[Mlir-commits] [mlir] [mlir] [linalg] fix failure on specializing matmul with permuted loops (PR #184294)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 5 09:18:59 PST 2026
https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/184294
>From 0a9bbdbbd2f227be82d53fb8fa7384e091f5d316 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Mon, 2 Mar 2026 20:37:47 +0000
Subject: [PATCH 1/5] fix(linalg-specialize): support reordered loops
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 42 +++++++++++----
.../Linalg/specialize-generic-ops.mlir | 51 +++++++++++++++++++
2 files changed, 84 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d74335e3c08c9..1f2acca0230e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -139,7 +139,8 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// attribute is needed for the named matmul op variant.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
- std::optional<TypeFn> castTy) {
+ std::optional<TypeFn> castTy,
+ ArrayRef<AffineMap> indexingMaps) {
SmallVector<NamedAttribute> castAttrVec;
// Only explicitly specify the cast attribute for unsigned cast; signed is
// the default for linalg.matmul/linalg.batch_matmul.
@@ -147,14 +148,11 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
castAttrVec = {rewriter.getNamedAttr(
"cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
- ArrayAttr indexingMaps = op.getIndexingMaps();
-
- LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+ auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]}, castAttrVec);
- // Set the original generic's maps to preserve transposed operand semantics.
- namedOp->setAttr("indexing_maps", indexingMaps);
+ namedOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(indexingMaps));
return namedOp;
}
@@ -299,11 +297,37 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
genericOp, "contains invalid cast ops for the named matmul op");
- /// Codegen the different matmul variants.
+ // Build indexing maps for the named op in its canonical dimension ordering.
+ auto *ctx = genericOp.getContext();
+ unsigned numLoopDims = numOfBatchDims + 3;
+ unsigned mIdx = numOfBatchDims;
+ unsigned nIdx = mIdx + 1;
+ unsigned kIdx = mIdx + 2;
+
+ auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
+ SmallVector<unsigned> tensorDims;
+ for (unsigned i = 0; i < numOfBatchDims; ++i)
+ tensorDims.push_back(i);
+ if (match == IndexMatchResult::Transposed)
+ llvm::append_values(tensorDims, colIdx, rowIdx);
+ else
+ llvm::append_values(tensorDims, rowIdx, colIdx);
+ return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
+ };
+
+ auto mapA = makeMap(a, mIdx, kIdx);
+ auto mapB = makeMap(b, kIdx, nIdx);
+ auto mapC = makeMap(c, mIdx, nIdx);
+
+ SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
+
+ // Codegen the different matmul variants.
if (numOfBatchDims) {
- return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
+ namedOpMaps);
}
- return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
+ namedOpMaps);
}
/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87218844c5c39..aaa47dd2a47ca 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -530,3 +530,54 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_TC]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Matmul with non-conventional loop ordering (d0=m, d1=k, d2=n).
+#map_nc_a = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nc_b = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map_nc_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_nc_a, #map_nc_b, #map_nc_c],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_non_conventional_dims
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-conventional loop ordering (d0=batch, d1=m, d2=k, d3=n).
+#map_bnc_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map_bnc_b = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map_bnc_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bnc_a, #map_bnc_b, #map_bnc_c],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: op_batch_matmul_non_conventional_dims
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
>From 0cf797976d56a3c8633fb7870f4aa88c15b35607 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Mon, 2 Mar 2026 21:25:53 +0000
Subject: [PATCH 2/5] rename
---
.../Dialect/Linalg/specialize-generic-ops.mlir | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index aaa47dd2a47ca..bd052c3f98097 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -533,10 +533,10 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
// -----
-// Matmul with non-conventional loop ordering (d0=m, d1=k, d2=n).
-#map_nc_a = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map_nc_b = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map_nc_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+// Matmul with non-canonical loop ordering.
+#map_nc_a = affine_map<(m, k, n) -> (m, k)>
+#map_nc_b = affine_map<(m, k, n) -> (k, n)>
+#map_nc_c = affine_map<(m, k, n) -> (m, n)>
func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
%Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic
@@ -558,10 +558,10 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
// -----
-// Batch matmul with non-conventional loop ordering (d0=batch, d1=m, d2=k, d3=n).
-#map_bnc_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-#map_bnc_b = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map_bnc_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// Batch matmul with non-canonical loop ordering.
+#map_bnc_a = affine_map<(batch, m, k, n) -> (batch, m, k)>
+#map_bnc_b = affine_map<(batch, m, k, n) -> (batch, k, n)>
+#map_bnc_c = affine_map<(batch, m, k, n) -> (batch, m, n)>
func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
%Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
%0 = linalg.generic
>From 7cfbdcab24ec1f25911a0cade2078ae76631e94f Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Tue, 3 Mar 2026 07:34:22 +0000
Subject: [PATCH 3/5] rename
---
mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index bd052c3f98097..8058d56607136 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -537,7 +537,7 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
#map_nc_a = affine_map<(m, k, n) -> (m, k)>
#map_nc_b = affine_map<(m, k, n) -> (k, n)>
#map_nc_c = affine_map<(m, k, n) -> (m, n)>
-func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+func.func @op_matmul_non_canonical_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
%Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic
{indexing_maps = [#map_nc_a, #map_nc_b, #map_nc_c],
@@ -551,7 +551,7 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: op_matmul_non_conventional_dims
+// CHECK-LABEL: op_matmul_non_canonical_loops
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -562,7 +562,7 @@ func.func @op_matmul_non_conventional_dims(%A: tensor<?x?xf32>, %B: tensor<?x?xf
#map_bnc_a = affine_map<(batch, m, k, n) -> (batch, m, k)>
#map_bnc_b = affine_map<(batch, m, k, n) -> (batch, k, n)>
#map_bnc_c = affine_map<(batch, m, k, n) -> (batch, m, n)>
-func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
+func.func @op_batch_matmul_non_canonical_loops(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
%Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
%0 = linalg.generic
{indexing_maps = [#map_bnc_a, #map_bnc_b, #map_bnc_c],
@@ -576,7 +576,7 @@ func.func @op_batch_matmul_non_conventional_dims(%A: tensor<2x16x8xf32>, %B: ten
return %0 : tensor<2x16x16xf32>
}
-// CHECK-LABEL: op_batch_matmul_non_conventional_dims
+// CHECK-LABEL: op_batch_matmul_non_canonical_loops
// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
>From e8f7259e6851800f6fe97fa33a475786653f542d Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Thu, 5 Mar 2026 08:03:41 +0000
Subject: [PATCH 4/5] more tests
---
.../Linalg/specialize-generic-ops.mlir | 159 ++++++++++++++++++
1 file changed, 159 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 8058d56607136..5c58a5fedd639 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -581,3 +581,162 @@ func.func @op_batch_matmul_non_canonical_loops(%A: tensor<2x16x8xf32>, %B: tenso
// CHECK-NOT: linalg.generic
// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// -----
+
+// Matmul with non-canonical loop ordering (d0=m, d1=k, d2=n) and B transposed.
+#map_nc_tb_a = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nc_tb_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_nc_tb_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @op_matmul_non_canonical_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_nc_tb_a, #map_nc_tb_b, #map_nc_tb_c],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-canonical loop ordering (d0=batch, d1=m, d2=k, d3=n)
+// and B transposed.
+#map_bnc_tb_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map_bnc_tb_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bnc_tb_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+func.func @op_batch_matmul_non_canonical_transpose_b(%A: tensor<2x16x8xf32>, %B: tensor<2x16x8xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bnc_tb_a, #map_bnc_tb_b, #map_bnc_tb_c],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// Matmul with fully permuted loop ordering.
+#map_fs_a = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map_fs_b = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_fs_c = affine_map<(d0, d1, d2) -> (d1, d2)>
+func.func @op_matmul_fully_shuffled_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_fs_a, #map_fs_b, #map_fs_c],
+ iterator_types = ["reduction", "parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_fully_shuffled_loops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// TODO: a matmul with a broadcast A operand could also be specialized to a named matmul.
+#map_bcast_a = affine_map<(d0, d1, d2) -> (d2)>
+#map_bcast_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_bcast_c = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_broadcast_a(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bcast_a, #map_bcast_b, #map_bcast_c],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+// TODO: a batch matmul with a broadcast A operand could also be specialized to a named batch_matmul.
+#map_bbcast_a = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+#map_bbcast_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bbcast_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_a(%A: tensor<16x8xf32>, %B: tensor<2x8x16xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bbcast_a, #map_bbcast_b, #map_bbcast_c],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul
+
+// -----
+
+// TODO: a batch matmul with a broadcast B operand could also be specialized to a named batch_matmul.
+#map_bbcast2_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map_bbcast2_b = affine_map<(d0, d1, d2, d3) -> (d3)>
+#map_bbcast2_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bbcast2_a, #map_bbcast2_b, #map_bbcast2_c],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_b
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul
>From 21dc0451995fa50f707a88b781d714acc9af5c05 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Thu, 5 Mar 2026 17:18:44 +0000
Subject: [PATCH 5/5] add todo
---
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 1f2acca0230e2..b4de2bb1e1169 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -304,6 +304,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
unsigned nIdx = mIdx + 1;
unsigned kIdx = mIdx + 2;
+ // TODO: add support for indexing_maps with broadcasts.
auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
SmallVector<unsigned> tensorDims;
for (unsigned i = 0; i < numOfBatchDims; ++i)
More information about the Mlir-commits
mailing list