[Mlir-commits] [mlir] [mlir][linalg] fix specialization of transposed matmul variants (PR #181387)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 13 12:05:39 PST 2026
https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/181387
>From 66fd81eedbfa7435b381ea81646a7b7eb14358c3 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 16:53:13 +0000
Subject: [PATCH 1/5] init
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 18 +-
.../Linalg/specialize-generic-ops.mlir | 171 ++++++++++++++++++
2 files changed, 187 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..512722196aecc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -147,10 +147,10 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
castAttrVec = {rewriter.getNamedAttr(
"cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
- LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+ auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]}, castAttrVec);
- return namedOp;
+ return cast<LinalgOp>(namedOp.getOperation());
}
// Returns the cast type to use for a matmul-like named op. If the generic
@@ -299,8 +299,22 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+ genericOp,
+ castTy);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+ genericOp,
+ castTy);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
+ castTy);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
+ castTy);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..5b2609da662c4 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -314,6 +314,73 @@ func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
// -----
+// Both A and B transposed: no named op variant exists for this combination.
+#mapABt0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#mapABt1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#mapABt2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_transpose_a_and_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapABt0, #mapABt1, #mapABt2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_transpose_a_and_b
+// CHECK: linalg.generic
+
+// -----
+
+// Output transposed: C is accessed as (n, m) instead of (m, n).
+#mapCt0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapCt1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapCt2 = affine_map<(d0, d1, d2) -> (d1, d0)>
+func.func @negative_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapCt0, #mapCt1, #mapCt2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_transposed_output
+// CHECK: linalg.generic
+
+// -----
+
+// Batch dim not in identity position: batch dim d0 appears at result
+// position 1 in A's map instead of position 0.
+#mapBni0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
+#mapBni1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapBni2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: tensor<2x8x16xf32>,
+ %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapBni0, #mapBni1, #mapBni2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<4x2x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x4x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x4x16xf32>
+ return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_non_identity_batch
+// CHECK: linalg.generic
+
+// -----
+
// TODO: matvec
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
@@ -331,3 +398,107 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
}
// CHECK-LABEL: op_matvec
// CHECK: linalg.generic
+
+// -----
+
+// Matmul transpose A: A is accessed as (k, m) instead of (m, k)
+#mapTA0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#mapTA1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapTA2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapTA0, #mapTA1, #mapTA2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_TA_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP_TA_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_TA_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TA_A]], #[[$MAP_TA_B]], #[[$MAP_TA_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Matmul transpose B: B is accessed as (n, k) instead of (k, n)
+#mapTB0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapTB1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#mapTB2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapTB0, #mapTB1, #mapTB2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_TB_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_TB_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TB_A]], #[[$MAP_TB_B]], #[[$MAP_TB_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul transpose A: A is accessed as (b, k, m) instead of (b, m, k)
+#mapBTA0 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#mapBTA1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapBTA2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul_transpose_a(%A: tensor<2x8x4xf32>, %B: tensor<2x8x16xf32>, %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapBTA0, #mapBTA1, #mapBTA2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x4x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x4x16xf32>
+ return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_BTA_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$MAP_BTA_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP_BTA_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTA_A]], #[[$MAP_BTA_B]], #[[$MAP_BTA_C]]] ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+
+// -----
+
+// Batch matmul transpose B: B is accessed as (b, n, k) instead of (b, k, n)
+#mapBTB0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapBTB1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#mapBTB2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul_transpose_b(%A: tensor<2x4x8xf32>, %B: tensor<2x16x8xf32>, %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapBTB0, #mapBTB1, #mapBTB2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%Out : tensor<2x4x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x4x16xf32>
+ return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_BTB_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_BTB_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_BTB_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTB_A]], #[[$MAP_BTB_B]], #[[$MAP_BTB_C]]] ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
>From 17ed89e7b4adbd62dea6e874106052b7694212b6 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 17:11:08 +0000
Subject: [PATCH 2/5] formatting
---
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 512722196aecc..5b60de6d96543 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -300,18 +300,18 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp,
- castTy);
+ return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(
+ rewriter, genericOp, castTy);
if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp,
- castTy);
+ return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(
+ rewriter, genericOp, castTy);
+
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
castTy);
+
if (b == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
castTy);
>From eefb9da774e09d4dbebf9cc762868753b1250e47 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:02:32 +0000
Subject: [PATCH 3/5] address comments
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 31 +++++------
.../Linalg/specialize-generic-ops.mlir | 52 ++++++++++++-------
2 files changed, 44 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 5b60de6d96543..4a00545384972 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -133,8 +133,8 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
// All the variants expressed as pseudo regular expression:
-// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-// have same number of ins/out, so its easy to stamp different versions.
+// `linalg.{batch_}?matmul` have the same number of ins/out, so it's easy to
+// stamp different versions.
// `castTy` is an optional type function that indicates whether (and which) cast
// attribute is needed for the named matmul op variant.
template <typename NamedOpTy>
@@ -147,10 +147,16 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
castAttrVec = {rewriter.getNamedAttr(
"cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
- auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]}, castAttrVec);
- return cast<LinalgOp>(namedOp.getOperation());
+
+ // Set the original generic's maps to preserve transposed operand semantics.
+ namedOp->setAttr("indexing_maps", indexingMaps);
+
+ return namedOp;
}
// Returns the cast type to use for a matmul-like named op. If the generic
@@ -298,23 +304,10 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
genericOp, "contains invalid cast ops for the named matmul op");
/// Codegen the different matmul variants.
- if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(
- rewriter, genericOp, castTy);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(
- rewriter, genericOp, castTy);
-
+ /// maps to express transposed operands on the standard op.
+ if (numOfBatchDims)
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
- }
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
- castTy);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
- castTy);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 5b2609da662c4..43d19734e534d 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -417,13 +417,16 @@ func.func @op_matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
return %0 : tensor<?x?xf32>
}
-// CHECK-DAG: #[[$MAP_TA_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK-DAG: #[[$MAP_TA_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$MAP_TA_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: op_matmul_transpose_a
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TA_A]], #[[$MAP_TA_B]], #[[$MAP_TA_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
@@ -443,13 +446,16 @@ func.func @op_matmul_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
return %0 : tensor<?x?xf32>
}
-// CHECK-DAG: #[[$MAP_TB_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$MAP_TB_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$MAP_TB_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: op_matmul_transpose_b
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TB_A]], #[[$MAP_TB_B]], #[[$MAP_TB_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
@@ -469,13 +475,16 @@ func.func @op_batch_matmul_transpose_a(%A: tensor<2x8x4xf32>, %B: tensor<2x8x16x
return %0 : tensor<2x4x16xf32>
}
-// CHECK-DAG: #[[$MAP_BTA_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
-// CHECK-DAG: #[[$MAP_BTA_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[$MAP_BTA_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: op_batch_matmul_transpose_a
-// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTA_A]], #[[$MAP_BTA_B]], #[[$MAP_BTA_C]]] ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
// -----
@@ -495,10 +504,13 @@ func.func @op_batch_matmul_transpose_b(%A: tensor<2x4x8xf32>, %B: tensor<2x16x8x
return %0 : tensor<2x4x16xf32>
}
-// CHECK-DAG: #[[$MAP_BTB_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[$MAP_BTB_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[$MAP_BTB_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: op_batch_matmul_transpose_b
-// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTB_A]], #[[$MAP_BTB_B]], #[[$MAP_BTB_C]]] ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
>From 306f4622c333fc6c94995e942562a17d0e69330a Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:04:36 +0000
Subject: [PATCH 4/5] style
---
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4a00545384972..9da04727c271e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -304,9 +304,9 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
genericOp, "contains invalid cast ops for the named matmul op");
/// Codegen the different matmul variants.
- /// maps to express transposed operands on the standard op.
- if (numOfBatchDims)
+ if (numOfBatchDims) {
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
+ }
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
>From b098addabda3b90ab88a96c9a340047d489e1121 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:05:24 +0000
Subject: [PATCH 5/5] whitespace
---
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 9da04727c271e..05370f4aae9eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -307,7 +307,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (numOfBatchDims) {
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
-
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
More information about the Mlir-commits
mailing list