[Mlir-commits] [mlir] [MLIR][Linalg] Remove matmul_transpose variants (PR #147961)

Renato Golin llvmlistbot at llvm.org
Fri Jul 11 10:09:44 PDT 2025


https://github.com/rengolin updated https://github.com/llvm/llvm-project/pull/147961

From ed24919966d1d62a2cb5ae908f40c9932d137ad9 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Tue, 8 Jul 2025 10:58:42 +0100
Subject: [PATCH 1/3] [MLIR][Linalg] Remove matmul_transpose variants

Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and
replaces them with `matmul indexing_maps = [...]` whenever appropriate.
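
For example (adapted from the updated named-ops tests below), the 2D
transpose-A variant

  linalg.matmul_transpose_a ins(%A, %B : memref<5x3xf32>, memref<5x7xf32>)
                            outs(%C : memref<3x7xf32>)

is now written as a plain matmul carrying explicit indexing maps; the
batched variants follow the same pattern with a leading batch dimension:

  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                      ]
                      ins(%A, %B : memref<5x3xf32>, memref<5x7xf32>)
                      outs(%C : memref<3x7xf32>)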
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 286 ------------------
 .../Linalg/Transforms/BlockPackMatmul.cpp     |   6 +-
 .../Linalg/Transforms/DropUnitDims.cpp        |  16 -
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  11 -
 .../Linalg/Transforms/TransposeMatmul.cpp     |  68 +++--
 .../Linalg/Transforms/Vectorization.cpp       |   8 +-
 .../linalg/opdsl/ops/core_named_ops.py        |  93 ------
 .../Linalg/block-pack-matmul-layout.mlir      |  50 ---
 .../Dialect/Linalg/block-pack-matmul.mlir     | 144 ---------
 .../Dialect/Linalg/fold-add-into-dest.mlir    |  30 --
 mlir/test/Dialect/Linalg/named-ops.mlir       | 111 -------
 .../Linalg/rank-reduce-contraction-ops.mlir   |  85 ------
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |   2 +-
 .../test/Dialect/Linalg/transform-op-pad.mlir |   6 +-
 .../transform-op-specialize-matmul.mlir       |  89 ------
 .../test/Dialect/Linalg/transpose-matmul.mlir |  38 ++-
 .../Linalg/CPU/ArmSME/matmul-transpose-a.mlir |   9 +-
 .../linalg/opdsl/test_core_named_ops.py       |   2 +-
 mlir/utils/tree-sitter-mlir/dialect/linalg.js |   2 -
 .../tree-sitter-mlir/queries/highlights.scm   |   2 -
 20 files changed, 88 insertions(+), 970 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_a
-  cpp_class_name: MatmulTransposeAOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with lhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_b
-  cpp_class_name: MatmulTransposeBOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with rhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
   cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_a
-  cpp_class_name: BatchMatmulTransposeAOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where lhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_b
-  cpp_class_name: BatchMatmulTransposeBOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where rhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
   cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
   patterns.add<BlockPackMatmul<linalg::GenericOp>,
                BlockPackMatmul<linalg::MatmulOp>,
-               BlockPackMatmul<linalg::BatchMatmulOp>,
-               BlockPackMatmul<linalg::MatmulTransposeAOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
-               BlockPackMatmul<linalg::MatmulTransposeBOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+               BlockPackMatmul<linalg::BatchMatmulOp>>(
       patterns.getContext(), controlFn);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   static bool constexpr reduceLeft =
       (std::is_same_v<FromOpTy, BatchMatmulOp> &&
        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
-      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
       (std::is_same_v<FromOpTy, MatmulOp> &&
        std::is_same_v<ToOpTy, VecmatOp>) ||
-      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, VecmatOp>) ||
       (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
 
   /// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
   MLIRContext *context = patterns.getContext();
   // Unbatching patterns for unit batch size
   patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-          context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-          context);
   patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
   patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
 
   // Non-batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
   // Batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
-      context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
-      context);
 
   // Non-batch rank 0 reducing patterns
   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    if (a == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
-                                                               genericOp);
-    if (b == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
-                                                               genericOp);
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
   }
-
-  if (a == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
-  if (b == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
 ///
 /// with
 ///
-///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.matmul indexing_maps = [#A^T, #B, #C] (linalg.transpose(a), b)
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
+/// FIXME: This API is not intuitive; replace transposeLHS=false with something better.
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
       dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{1, 0});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = matmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+    mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-        matmulOp.getOutputs());
+    newLHS = matmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+    mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+      loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      matmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(matmulOp, newMatmulOp);
   return newMatmulOp;
 }
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
 ///
 /// with
 ///
-///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.batch_matmul indexing_maps = [#A^T, #B, #C] (linalg.transpose(a), b)
 ///
 /// Only the non-batch dimensions are transposed. By default the LHS is
 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive; replace transposeLHS=false with something better.
 FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
       type.getElementType(), dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2, d3;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2, d3);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-        batchMatmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = batchMatmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-        batchMatmulOp.getOutputs());
+    newLHS = batchMatmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      batchMatmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
   return newMatmulOp;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
            "vectorization\n");
       return failure();
     }
-    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+    if (isa<linalg::MatmulOp>(op)) {
       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
            "is not supported\n");
       return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
       return failure();
   }
 
-  // Check to not let go the matmul with extended semantic, through this
-  // transform.
-  if (linalgOp.hasUserDefinedMaps())
-    return failure();
-
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
 }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 1b359da40a291..fd4a5a848f1e3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -373,42 +373,6 @@ def quantized_matmul(
     )
 
 
-@linalg_structured_op
-def matmul_transpose_a(
-    A=TensorDef(T1, S.K, S.N),
-    B=TensorDef(T2, S.K, S.M),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs with lhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
-
-
-@linalg_structured_op
-def matmul_transpose_b(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.N, S.K),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs with rhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
-
-
 @linalg_structured_op
 def mmt4d(
     lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
@@ -453,44 +417,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
-@linalg_structured_op
-def batch_matmul_transpose_a(
-    A=TensorDef(T1, Batch, S.K, S.M),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs where lhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
-@linalg_structured_op
-def batch_matmul_transpose_b(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.N, S.K),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs where rhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.n, D.k]
-    )
-
-
 @linalg_structured_op
 def quantized_batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
@@ -512,25 +438,6 @@ def quantized_batch_matmul(
     ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
-@linalg_structured_op
-def batch_reduce_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-):
-    """Performs a batch-reduce matrix multiplication of two 3D inputs.
-    The partial multiplication results are reduced into a 2D output.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def matvec(
     A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
index 4ba4b09f52163..2f30e8b9d01e7 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
@@ -20,20 +20,6 @@ func.func @block_matmul(
   return %0 : tensor<64x64xf32>
 }
 
-func.func @block_matmul_transpose_a(
-    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
-  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
-                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
-  return %0 : tensor<64x64xf32>
-}
-
-func.func @block_matmul_transpose_b(
-    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
-  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
-                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
-  return %0 : tensor<64x64xf32>
-}
-
 // MMT4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
 // MMT4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
 // MMT4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
@@ -43,18 +29,6 @@ func.func @block_matmul_transpose_b(
 // MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 // MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
 // MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_a
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_b
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
 
 // MM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
 // MM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -65,18 +39,6 @@ func.func @block_matmul_transpose_b(
 // MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 // MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
 // MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_a
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_b
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
 
 // MTM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
 // MTM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -87,15 +49,3 @@ func.func @block_matmul_transpose_b(
 // MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 // MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
 // MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_a
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_b
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index aa860dbd581a9..e16af1f6da0e3 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -197,150 +197,6 @@ func.func @block_batch_matmul(
 
 // -----
 
-func.func @block_matmul_transpose_a(
-    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
-  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
-                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
-  return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_a(
-// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_a(
-    %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
-  %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
-                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
-  return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_a(
-// CHECK-SAME:   %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
-func.func @block_matmul_transpose_b(
-    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
-  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
-                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
-  return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_b(
-// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_b(
-    %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
-  %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
-                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
-  return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_b(
-// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
index d8e92e40739dc..e90247d052b4b 100644
--- a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -157,36 +157,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-!type = tensor<2048x2048xf32>
-func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
-  %0 = arith.constant dense<1.111111e+00> : !type
-  %cst = arith.constant 0.000000e+00 : f32
-  %1 = tensor.empty() : !type
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
-  %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
-  %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
-  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
-  return %5 : !type
-}
-
-// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
-// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
-// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
-// CHECK-NOT: linalg.add
-// CHECK-NEXT: return %[[RES]]
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.linalg.fold_add_into_dest
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 !type = tensor<2048x2048xf32>
 func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
   %0 = arith.constant dense<1.111111e+00> : !type
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 412f40d501154..3da8fb950d8f7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1222,84 +1222,6 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
 
 // -----
 
-// CHECK-LABEL: func @matmul_transpose_a
-//       CHECK:   linalg.matmul_transpose_a
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @matmul_transpose_a_explicit
-//       CHECK:   linalg.matmul
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.matmul indexing_maps = [
-                       affine_map<(d0, d1, d2) -> (d2, d0)>,
-                       affine_map<(d0, d1, d2) -> (d2, d1)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
-                      outs(%arg2: memref<3x7xf32>)
-  return
-}
-
-// -----
-
-func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.matmul indexing_maps = [
-                       affine_map<(d0, d1, d2) -> (d0, d2)>,
-                       affine_map<(d0, d1, d2) -> (d1, d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
-                      outs(%arg2: memref<3x7xf32>)
-  return
-}
-
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
-// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
-// CHECK:           return
-// CHECK:         }
-
-// -----
-
-func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.matmul indexing_maps = [
-                       affine_map<(d0, d1, d2) -> (d2, d0)>,
-                       affine_map<(d0, d1, d2) -> (d1, d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
-                      outs(%arg2: memref<3x7xf32>)
-  return
-}
-
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
-// CHECK:           return
-// CHECK:         }
-
-// -----
-
 func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
   linalg.matmul indexing_maps = [
                        affine_map<(d0, d1, d2) -> (d2)>,
@@ -1478,17 +1400,6 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
 
 // -----
 
-// CHECK-LABEL: func @matmul_transpose_b
-//       CHECK:   linalg.matmul_transpose_b
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
-  return
-}
-
-// -----
-
 // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -1806,28 +1717,6 @@ func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: m
 
 // -----
 
-// CHECK-LABEL: func @batchmatmul_transpose_a
-//       CHECK:   linalg.batch_matmul_transpose_a
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
-  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @batchmatmul_transpose_b
-//       CHECK:   linalg.batch_matmul_transpose_b
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
-  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
-  return
-}
-
-// -----
-
 // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 43bddb075e649..704576de41960 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -92,38 +92,6 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
 
 // -----
 
-func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
-  //  CHECK-NEXT:   return
-  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
-  return
-}
-
-// -----
-
-func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
-  //  CHECK-NEXT:   return
-  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
-  return
-}
-
-// -----
-
 func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK-LABEL: @matmul_to_matvec_tensor
   //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -226,59 +194,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
 
 // -----
 
-func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
-  // CHECK-LABEL: @matmul_transpose_a_to_vecmat
-  // CHECK: collapse_shape {{.*}} into tensor<256xf32>
-  // CHECK: collapse_shape {{.*}} into tensor<512xf32>
-  // CHECK: linalg.vecmat
-  // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
-    %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
-    return %0 : tensor<1x512xf32>
-}
-
-// -----
-
-func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
-  // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
-  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
-  // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
-  // CHECK: linalg.batch_vecmat
-  // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
-    %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
-    return %0 : tensor<64x1x512xf32>
-}
-
-// -----
-
-func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
-  // CHECK-LABEL: @matmul_transpose_b_to_matvec
-  // CHECK: linalg.matvec
-    linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
-    return
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
-  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
-  // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
-  // CHECK: linalg.batch_matvec
-  // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
-    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> 
-    return %0 : tensor<64x128x1xf32>
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
-  // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
-  // CHECK: linalg.dot
-    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> 
-    return %0 : tensor<1x1x1xf32>
-}
-
-// -----
-
 func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
   // CHECK-LABEL: @nonsingleton_batch_matmul
   // CHECK-NOT:   collapse_shape
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 778d5bb8b9c84..1b0bade728b44 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -504,7 +504,7 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
     %c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
     %sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index f91eb9c30a51a..51bf4a23406d4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -465,14 +465,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[RHS:.*]] = tensor.pad
 // CHECK: scf.for
 // CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
-// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK-DAG: tensor.extract_slice %[[RHS]][%{{.*}}, 0] [32, %{{.*}}]
 func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     %tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
index f64953bceefe1..bd4c65512dfa6 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -30,66 +30,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.generic
-     {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
-     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
-      ^bb0(%in: f32, %in_0: f32, %out: f32):
-       %0 = arith.mulf %in, %in_0 : f32
-       %1 = arith.addf %out, %0 : f32
-       linalg.yield %1 : f32
-  }
-  return
-}
-
-// CHECK-LABEL: @matmul_transpose_a
-// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.generic
-          {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
-          ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.mulf %in, %in_0 : f32
-    %2 = arith.addf %out, %1 : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: @matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -117,32 +57,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
-// -----
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %0 = linalg.generic
-       {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-       ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-      %1 = arith.mulf %in, %in_0 : f32
-      %2 = arith.addf %out, %1 : f32
-      linalg.yield %2 : f32
-  } -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
-}
-
-// CHECK-LABEL: @batch_matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
index d2b7e9f7f1992..4ee87fb11d527 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,6 +1,20 @@
 // RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
 // RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
 
+// TRANSPOSE-A-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// TRANSPOSE-A-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// TRANSPOSE-A-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-A-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// TRANSPOSE-A-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// TRANSPOSE-A-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// TRANSPOSE-B-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// TRANSPOSE-B-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// TRANSPOSE-B-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-B-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// TRANSPOSE-B-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// TRANSPOSE-B-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
 // CHECK-LABEL:   func.func @matmul_static(
 // CHECK-SAME:                             %[[A:.*]]: tensor<16x8xf32>,
 // CHECK-SAME:                             %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
@@ -9,10 +23,10 @@
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // CHECK:           return %[[C]] : tensor<16x16xf32>
 // CHECK:         }
 func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -38,11 +52,11 @@ func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<
 // TRANSPOSE-A:     %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:           return %[[C]] : tensor<?x?xf32>
 // CHECK:         }
 func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -69,10 +83,10 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A:     %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A:     %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B:     %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B:     %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // CHECK:           return %[[B0]] : tensor<?x16xf32>
 // CHECK:         }
 func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
@@ -96,10 +110,10 @@ func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 // CHECK:           return %[[C]] : tensor<2x16x16xf32>
 // CHECK:         }
 func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x16x16xf32>) {
@@ -127,12 +141,12 @@ func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -
 // TRANSPOSE-A:     %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
 // TRANSPOSE-B:     %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:           return %[[C]] : tensor<?x?x?xf32>
 // CHECK:         }
 func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
@@ -161,10 +175,10 @@ func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) ->
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
 // TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
 // TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A:     %[[B0:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-A:     %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
 // TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
 // TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B:     %[[B0:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-B:     %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
 // CHECK:           return %[[B0]] : tensor<2x?x16xf32>
 // CHECK:         }
 func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index 06a6e2279b6a7..9d043573091eb 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -9,7 +9,12 @@
 // RUN: FileCheck %s
 
 func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
-  %res = linalg.matmul_transpose_a ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+  %res = linalg.matmul
+            indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>]
+            ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
                                    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
   %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
   call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
@@ -56,7 +61,7 @@ func.func @main() {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
-    %matmul_transpose_a = transform.structured.match ops{["linalg.matmul_transpose_a"]} in %module
+    %matmul_transpose_a = transform.structured.match ops{["linalg.matmul"]} in %module
       : (!transform.any_op) -> !transform.any_op
 
     // Step 1: Tile for size [4] x [4], which corresponds to SVLs x SVLs, where
diff --git a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
index ee76b6d25cae1..bc273bfa06f86 100644
--- a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
+++ b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
@@ -1,7 +1,7 @@
 # RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops | FileCheck %s
 
 # Just verify that at least one known op is generated.
-# CHECK: name: matmul
+# CHECK: name: copy
 
 # verify some special cases: negf->NegFOp, powf->PowFOp
 # CHECK cpp_class_name: NegFOp
diff --git a/mlir/utils/tree-sitter-mlir/dialect/linalg.js b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
index ddde92b2f692b..f4658085ce6f3 100644
--- a/mlir/utils/tree-sitter-mlir/dialect/linalg.js
+++ b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
@@ -4,7 +4,6 @@ module.exports = {
   linalg_dialect : $ => prec.right(choice(
                      seq(choice(
                              'linalg.batch_matmul',
-                             'linalg.batch_matmul_transpose_b',
                              'linalg.batch_matvec',
                              'linalg.batch_reduce_matmul', 'linalg.broadcast',
                              'linalg.conv_1d_ncw_fcw', 'linalg.conv_1d_nwc_wcf',
@@ -27,7 +26,6 @@ module.exports = {
                              'linalg.dot', 'linalg.elemwise_binary',
                              'linalg.elemwise_unary', 'linalg.fill',
                              'linalg.fill_rng_2d', 'linalg.matmul',
-                             'linalg.matmul_transpose_b',
                              'linalg.matmul_unsigned', 'linalg.matvec',
                              'linalg.mmt4d', 'linalg.pooling_nchw_max',
                              'linalg.pooling_nchw_sum',
diff --git a/mlir/utils/tree-sitter-mlir/queries/highlights.scm b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
index 4cbea7bbca031..59e280bab414a 100644
--- a/mlir/utils/tree-sitter-mlir/queries/highlights.scm
+++ b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
@@ -213,7 +213,6 @@
   "bufferization.to_tensor"
 
   "linalg.batch_matmul"
-  "linalg.batch_matmul_transpose_b"
   "linalg.batch_matvec"
   "linalg.batch_reduce_matmul"
   "linalg.broadcast"
@@ -244,7 +243,6 @@
   "linalg.fill"
   "linalg.fill_rng_2d"
   "linalg.matmul"
-  "linalg.matmul_transpose_b"
   "linalg.matmul_unsigned"
   "linalg.matvec"
   "linalg.mmt4d"

>From b2460adbe58be1456d092abdcc2bdf558657e6f2 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Thu, 10 Jul 2025 14:34:01 +0100
Subject: [PATCH 2/3] format

---
 mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 57f898458516e..b4507a93e0098 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,6 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
   patterns.add<BlockPackMatmul<linalg::GenericOp>,
                BlockPackMatmul<linalg::MatmulOp>,
-               BlockPackMatmul<linalg::BatchMatmulOp>>(
-      patterns.getContext(), controlFn);
+               BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+                                                       controlFn);
 }

>From 8744e49402e73887849c78c7865910f48526bb42 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 11 Jul 2025 12:36:37 +0100
Subject: [PATCH 3/3] Specialization for transpose variants

No code changes are needed anymore; everything is handled by the new
specialization classes. This also fixes a bug in the previous matmul
builder, where the affine maps argument was being ignored.
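
For illustration, the explicit indexing_maps form that replaces
`linalg.matmul_transpose_a` looks like this (a sketch mirroring the updated
ArmSME integration test in this patch; %A, %B and %C stand for the usual
LHS/RHS/accumulator values):

  %res = linalg.matmul
           indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2, d0)>,
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>]
           ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
           outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>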
---
 mlir/include/mlir/Dialect/Linalg/IR/Linalg.h  | 118 ++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 173 +++++++++++++++++-
 .../Linalg/Transforms/TransposeMatmul.cpp     |  68 ++-----
 3 files changed, 308 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..d18c1fdcb171e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -144,4 +144,122 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
 
+namespace mlir::linalg {
+
+/// Specialization of `linalg.matmul` op that has a transpose map on A
+class MatmulTransposeAOp : public MatmulOp {
+  /// Create an affine map for a transpose-A matmul. Used only in the builders.
+  static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+  using MatmulOp::MatmulOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+  /// Build a transpose A matmul.
+  static void build(OpBuilder &builder, OperationState &result,
+                    ValueRange inputs, ValueRange outputs,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A matmul with a specific result type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A matmul with a specific result type and a cast type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, Attribute cast,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.matmul` op that has a transpose map on B
+class MatmulTransposeBOp : public MatmulOp {
+  /// Create an affine map for a transpose-B matmul. Used only in the builders.
+  static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+  using MatmulOp::MatmulOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+  /// Build a transpose B matmul.
+  static void build(OpBuilder &builder, OperationState &result,
+                    ValueRange inputs, ValueRange outputs,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B matmul with a specific result type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B matmul with a specific result type and a cast type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, Attribute cast,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
+class BatchMatmulTransposeAOp : public BatchMatmulOp {
+  /// Create an affine map for a transpose-A batch_matmul. Used only in the
+  /// builders.
+  static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+  using BatchMatmulOp::BatchMatmulOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+  /// Build a transpose A batch matmul.
+  static void build(OpBuilder &builder, OperationState &result,
+                    ValueRange inputs, ValueRange outputs,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A batch matmul with a specific result type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A batch matmul with a specific result type and a cast type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, Attribute cast,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
+class BatchMatmulTransposeBOp : public BatchMatmulOp {
+  /// Create an affine map for a transpose-B batch_matmul. Used only in the
+  /// builders.
+  static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+  using BatchMatmulOp::BatchMatmulOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+  /// Build a transpose B batch matmul.
+  static void build(OpBuilder &builder, OperationState &result,
+                    ValueRange inputs, ValueRange outputs,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B batch matmul with a specific result type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B batch matmul with a specific result type and a cast type.
+  static void build(OpBuilder &builder, OperationState &result,
+                    TypeRange resultTensorTypes, ValueRange inputs,
+                    ValueRange outputs, Attribute cast,
+                    ArrayRef<NamedAttribute> attributes = {});
+
+  static bool classof(Operation *op);
+};
+
+} // namespace mlir::linalg
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALG_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..d4b46037267d8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -193,9 +193,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                           ArrayRef<AffineMap> indexingMaps) {
   // Initialize indexingMaps attribute, for MatmulOp.
   SmallVector<Attribute, 3> indexingMapsAttrVal;
-  indexingMapsAttrVal = llvm::map_to_vector(
-      MatmulOp::getDefaultIndexingMaps(b.getContext()),
-      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
   state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                            attributes, regionBuilder);
@@ -3881,6 +3882,172 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+SmallVector<AffineMap> MatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
+  AffineExpr d0, d1, d2;
+  auto context = builder.getContext();
+  bindDims(context, d0, d1, d2);
+  AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+  AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+  SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+  return affineMaps;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       ValueRange inputs, ValueRange outputs,
+                                       ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       TypeRange resultTensorTypes,
+                                       ValueRange inputs, ValueRange outputs,
+                                       ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       TypeRange resultTensorTypes,
+                                       ValueRange inputs, ValueRange outputs,
+                                       Attribute cast,
+                                       ArrayRef<NamedAttribute> attributes) {
+  result.addAttribute("cast", cast);
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+  return isa_and_nonnull<linalg::MatmulOp>(op);
+}
+
+SmallVector<AffineMap> MatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
+  AffineExpr d0, d1, d2;
+  auto context = builder.getContext();
+  bindDims(context, d0, d1, d2);
+  AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+  AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+  SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+  return affineMaps;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       ValueRange inputs, ValueRange outputs,
+                                       ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       TypeRange resultTensorTypes,
+                                       ValueRange inputs, ValueRange outputs,
+                                       ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+                                       OperationState &result,
+                                       TypeRange resultTensorTypes,
+                                       ValueRange inputs, ValueRange outputs,
+                                       Attribute cast,
+                                       ArrayRef<NamedAttribute> attributes) {
+  result.addAttribute("cast", cast);
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+  return isa_and_nonnull<linalg::MatmulOp>(op);
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
+  AffineExpr d0, d1, d2, d3;
+  auto context = builder.getContext();
+  bindDims(context, d0, d1, d2, d3);
+  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+  SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+  return affineMaps;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs,
+    ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs, Attribute cast,
+    ArrayRef<NamedAttribute> attributes) {
+  result.addAttribute("cast", cast);
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+  return isa_and_nonnull<linalg::BatchMatmulOp>(op);
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
+  AffineExpr d0, d1, d2, d3;
+  auto context = builder.getContext();
+  bindDims(context, d0, d1, d2, d3);
+  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+  SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+  return affineMaps;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs,
+    ArrayRef<NamedAttribute> attributes) {
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs, Attribute cast,
+    ArrayRef<NamedAttribute> attributes) {
+  result.addAttribute("cast", cast);
+  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+                BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+  return isa_and_nonnull<linalg::BatchMatmulOp>(op);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 086f9e5d05e6f..934781d1cab75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,11 +23,10 @@ using namespace mlir::linalg;
 ///
 /// with
 ///
-///   linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
-/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
@@ -58,31 +57,18 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
       dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{1, 0});
-  Value newLHS, newRHS;
-  AffineMap mapLHS, mapRHS, mapOut;
-  AffineExpr d0, d1, d2;
-  auto context = rewriter.getContext();
-  bindDims(context, d0, d1, d2);
+  Operation *newMatmulOp;
   if (transposeLHS) {
-    newLHS = transposeOp->getResult(0);
-    newRHS = matmulOp.getInputs()[1];
-    mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
-    mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
-    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
   } else {
-    newLHS = matmulOp.getInputs()[0];
-    newRHS = transposeOp->getResult(0);
-    mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
-    mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
-    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+        matmulOp.getOutputs());
   }
-  Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
-      loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
-      matmulOp.getOutputs());
-  newMatmulOp->setAttr("indexing_maps",
-                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
-                                              AffineMapAttr::get(mapRHS),
-                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(matmulOp, newMatmulOp);
   return newMatmulOp;
 }
@@ -93,11 +79,10 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
 ///
 /// with
 ///
-///   linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
+///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
 ///
 /// Only the non-batch dimensions are transposed. By default the LHS is
 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
-/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
@@ -129,31 +114,18 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
       type.getElementType(), dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-  Value newLHS, newRHS;
-  AffineMap mapLHS, mapRHS, mapOut;
-  AffineExpr d0, d1, d2, d3;
-  auto context = rewriter.getContext();
-  bindDims(context, d0, d1, d2, d3);
+  Operation *newMatmulOp;
   if (transposeLHS) {
-    newLHS = transposeOp->getResult(0);
-    newRHS = batchMatmulOp.getInputs()[1];
-    mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
-    mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
-    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+        batchMatmulOp.getOutputs());
   } else {
-    newLHS = batchMatmulOp.getInputs()[0];
-    newRHS = transposeOp->getResult(0);
-    mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
-    mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
-    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+        batchMatmulOp.getOutputs());
   }
-  Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
-      loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
-      batchMatmulOp.getOutputs());
-  newMatmulOp->setAttr("indexing_maps",
-                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
-                                              AffineMapAttr::get(mapRHS),
-                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
   return newMatmulOp;
 }
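
A minimal usage sketch of the new specialization classes (not part of the
patch; it assumes a `RewriterBase &rewriter`, a `Location loc`, an existing
`linalg::MatmulOp matmulOp`, and `transposedLhs` as the result of a
`linalg.transpose` of the original LHS, as in the rewrite above):

  // Build A^T * B through the specialized builder; the transposed
  // indexing maps are installed by the builder itself, so the caller
  // does not need to set an explicit `indexing_maps` attribute.
  Operation *newOp = rewriter.create<linalg::MatmulTransposeAOp>(
      loc, matmulOp.getResultTypes(),
      ValueRange{transposedLhs, matmulOp.getInputs()[1]},
      matmulOp.getOutputs());

Since the specialization shares `linalg.matmul`'s TypeID, the created op
still prints and round-trips as a plain `linalg.matmul` with explicit
indexing maps.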


