[Mlir-commits] [mlir] [mlir][linalg] fix specialization of transposed matmul variants (PR #181387)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 13 12:05:39 PST 2026


https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/181387

>From 66fd81eedbfa7435b381ea81646a7b7eb14358c3 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 16:53:13 +0000
Subject: [PATCH 1/5] init

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  18 +-
 .../Linalg/specialize-generic-ops.mlir        | 171 ++++++++++++++++++
 2 files changed, 187 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..512722196aecc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -147,10 +147,10 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
     castAttrVec = {rewriter.getNamedAttr(
         "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
 
-  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+  auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
       ValueRange{op.getDpsInits()[0]}, castAttrVec);
-  return namedOp;
+  return cast<LinalgOp>(namedOp.getOperation());
 }
 
 // Returns the cast type to use for a matmul-like named op. If the generic
@@ -299,8 +299,22 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp,
+                                                               castTy);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp,
+                                                               castTy);
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
+                                                        castTy);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
+                                                        castTy);
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..5b2609da662c4 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -314,6 +314,73 @@ func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
 
 // -----
 
+// Both A and B transposed: not specialized, so the linalg.generic is kept.
+#mapABt0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#mapABt1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#mapABt2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_transpose_a_and_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                                              %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#mapABt0, #mapABt1, #mapABt2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_transpose_a_and_b
+// CHECK: linalg.generic
+
+// -----
+
+// Output transposed: C is accessed as (n, m) instead of (m, n).
+#mapCt0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapCt1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapCt2 = affine_map<(d0, d1, d2) -> (d1, d0)>
+func.func @negative_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                                              %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#mapCt0, #mapCt1, #mapCt2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_transposed_output
+// CHECK: linalg.generic
+
+// -----
+
+// Batch dim not in identity position: batch dim d0 appears at result
+// position 1 in A's map instead of position 0.
+#mapBni0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
+#mapBni1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapBni2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: tensor<2x8x16xf32>,
+                                                     %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#mapBni0, #mapBni1, #mapBni2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+           ins(%A, %B : tensor<4x2x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x4x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x4x16xf32>
+  return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_non_identity_batch
+// CHECK: linalg.generic
+
+// -----
+
 // TODO: matvec
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1)>
@@ -331,3 +398,107 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 }
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
+
+// -----
+
+// Matmul transpose A: A is accessed as (k, m) instead of (m, k)
+#mapTA0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#mapTA1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapTA2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#mapTA0, #mapTA1, #mapTA2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_TA_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP_TA_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_TA_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TA_A]], #[[$MAP_TA_B]], #[[$MAP_TA_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Matmul transpose B: B is accessed as (n, k) instead of (k, n)
+#mapTB0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapTB1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#mapTB2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#mapTB0, #mapTB1, #mapTB2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %1 = arith.mulf %in, %in_0 : f32
+     %2 = arith.addf %out, %1 : f32
+     linalg.yield %2 : f32
+   } -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_TB_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_TB_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TB_A]], #[[$MAP_TB_B]], #[[$MAP_TB_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul transpose A: A is accessed as (b, k, m) instead of (b, m, k)
+#mapBTA0 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#mapBTA1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapBTA2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul_transpose_a(%A: tensor<2x8x4xf32>, %B: tensor<2x8x16xf32>, %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#mapBTA0, #mapBTA1, #mapBTA2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+           ins(%A, %B : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x4x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x4x16xf32>
+  return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_BTA_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$MAP_BTA_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP_BTA_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTA_A]], #[[$MAP_BTA_B]], #[[$MAP_BTA_C]]] ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+
+// -----
+
+// Batch matmul transpose B: B is accessed as (b, n, k) instead of (b, k, n)
+#mapBTB0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapBTB1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#mapBTB2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul_transpose_b(%A: tensor<2x4x8xf32>, %B: tensor<2x16x8xf32>, %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+  %0 = linalg.generic
+           {indexing_maps = [#mapBTB0, #mapBTB1, #mapBTB2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+           ins(%A, %B : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%Out : tensor<2x4x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x4x16xf32>
+  return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_BTB_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_BTB_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_BTB_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTB_A]], #[[$MAP_BTB_B]], #[[$MAP_BTB_C]]] ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>

>From 17ed89e7b4adbd62dea6e874106052b7694212b6 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 17:11:08 +0000
Subject: [PATCH 2/5] formatting

---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 512722196aecc..5b60de6d96543 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -300,18 +300,18 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
     if (a == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
-                                                               genericOp,
-                                                               castTy);
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(
+          rewriter, genericOp, castTy);
     if (b == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
-                                                               genericOp,
-                                                               castTy);
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(
+          rewriter, genericOp, castTy);
+
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
   if (a == IndexMatchResult::Transposed)
     return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
                                                         castTy);
+
   if (b == IndexMatchResult::Transposed)
     return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
                                                         castTy);

>From eefb9da774e09d4dbebf9cc762868753b1250e47 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:02:32 +0000
Subject: [PATCH 3/5] address comments

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 31 +++++------
 .../Linalg/specialize-generic-ops.mlir        | 52 ++++++++++++-------
 2 files changed, 44 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 5b60de6d96543..4a00545384972 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -133,8 +133,8 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 
 // Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
 // All the variants expressed as pseudo regular expression:
-// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-// have same number of ins/out, so its easy to stamp different versions.
+// `linalg.{batch_}?matmul` have the same number of ins/out, so it's easy to
+// stamp different versions.
 // `castTy` is an optional type function that indicates whether (and which) cast
 // attribute is needed for the named matmul op variant.
 template <typename NamedOpTy>
@@ -147,10 +147,16 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
     castAttrVec = {rewriter.getNamedAttr(
         "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
 
-  auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
       ValueRange{op.getDpsInits()[0]}, castAttrVec);
-  return cast<LinalgOp>(namedOp.getOperation());
+
+  // Set the original generic's maps to preserve transposed operand semantics.
+  namedOp->setAttr("indexing_maps", indexingMaps);
+
+  return namedOp;
 }
 
 // Returns the cast type to use for a matmul-like named op. If the generic
@@ -298,23 +304,10 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
         genericOp, "contains invalid cast ops for the named matmul op");
 
   /// Codegen the different matmul variants.
-  if (numOfBatchDims) {
-    if (a == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(
-          rewriter, genericOp, castTy);
-    if (b == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(
-          rewriter, genericOp, castTy);
-
+  /// maps to express transposed operands on the standard op.
+  if (numOfBatchDims)
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
-  }
-  if (a == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp,
-                                                        castTy);
 
-  if (b == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp,
-                                                        castTy);
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 5b2609da662c4..43d19734e534d 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -417,13 +417,16 @@ func.func @op_matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
    return %0 : tensor<?x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP_TA_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK-DAG: #[[$MAP_TA_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$MAP_TA_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: op_matmul_transpose_a
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TA_A]], #[[$MAP_TA_B]], #[[$MAP_TA_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
 // -----
 
@@ -443,13 +446,16 @@ func.func @op_matmul_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
    return %0 : tensor<?x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP_TB_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$MAP_TB_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$MAP_TB_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: op_matmul_transpose_b
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul indexing_maps = [#[[$MAP_TB_A]], #[[$MAP_TB_B]], #[[$MAP_TB_C]]] ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
 // -----
 
@@ -469,13 +475,16 @@ func.func @op_batch_matmul_transpose_a(%A: tensor<2x8x4xf32>, %B: tensor<2x8x16x
   return %0 : tensor<2x4x16xf32>
 }
 
-// CHECK-DAG: #[[$MAP_BTA_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
-// CHECK-DAG: #[[$MAP_BTA_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[$MAP_BTA_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-LABEL: op_batch_matmul_transpose_a
-// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTA_A]], #[[$MAP_BTA_B]], #[[$MAP_BTA_C]]] ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
 
 // -----
 
@@ -495,10 +504,13 @@ func.func @op_batch_matmul_transpose_b(%A: tensor<2x4x8xf32>, %B: tensor<2x16x8x
   return %0 : tensor<2x4x16xf32>
 }
 
-// CHECK-DAG: #[[$MAP_BTB_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[$MAP_BTB_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[$MAP_BTB_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-LABEL: op_batch_matmul_transpose_b
-// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul indexing_maps = [#[[$MAP_BTB_A]], #[[$MAP_BTB_B]], #[[$MAP_BTB_C]]] ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>) outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>

>From 306f4622c333fc6c94995e942562a17d0e69330a Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:04:36 +0000
Subject: [PATCH 4/5] style

---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4a00545384972..9da04727c271e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -304,9 +304,9 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
         genericOp, "contains invalid cast ops for the named matmul op");
 
   /// Codegen the different matmul variants.
-  /// maps to express transposed operands on the standard op.
-  if (numOfBatchDims)
+  if (numOfBatchDims) {
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
+  }
 
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }

>From b098addabda3b90ab88a96c9a340047d489e1121 Mon Sep 17 00:00:00 2001
From: default <ziereis at roofline.ai>
Date: Fri, 13 Feb 2026 20:05:24 +0000
Subject: [PATCH 5/5] whitespace

---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 9da04727c271e..05370f4aae9eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -307,7 +307,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (numOfBatchDims) {
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
-
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 



More information about the Mlir-commits mailing list