[Mlir-commits] [mlir] [MLIR][Linalg] Remove matmul_transpose variants (PR #147961)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 10 06:31:12 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Renato Golin (rengolin)

<details>
<summary>Changes</summary>

Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` where appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863), and is possible now that #104783 has been merged.

See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245

Issues investigated:
* Pad transform tests that could use plain `matmul` instead were changed to use it.
* The ArmSME test that used a transpose variant genuinely needed the transposition, so it was changed to `matmul` with explicit affine maps.

Arm tests validated by @banach-space (thanks!!).

---

Patch is 75.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147961.diff


20 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-286) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (+1-5) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (-16) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (-11) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+48-20) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-7) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-93) 
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir (-50) 
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul.mlir (-144) 
- (modified) mlir/test/Dialect/Linalg/fold-add-into-dest.mlir (-30) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (-111) 
- (modified) mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir (-85) 
- (modified) mlir/test/Dialect/Linalg/tile-to-forall.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+3-3) 
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir (-89) 
- (modified) mlir/test/Dialect/Linalg/transpose-matmul.mlir (+26-12) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+7-2) 
- (modified) mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py (+1-1) 
- (modified) mlir/utils/tree-sitter-mlir/dialect/linalg.js (-2) 
- (modified) mlir/utils/tree-sitter-mlir/queries/highlights.scm (-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_a
-  cpp_class_name: MatmulTransposeAOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with lhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_b
-  cpp_class_name: MatmulTransposeBOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with rhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
   cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_a
-  cpp_class_name: BatchMatmulTransposeAOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where lhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_b
-  cpp_class_name: BatchMatmulTransposeBOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where rhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
   cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
   patterns.add<BlockPackMatmul<linalg::GenericOp>,
                BlockPackMatmul<linalg::MatmulOp>,
-               BlockPackMatmul<linalg::BatchMatmulOp>,
-               BlockPackMatmul<linalg::MatmulTransposeAOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
-               BlockPackMatmul<linalg::MatmulTransposeBOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+               BlockPackMatmul<linalg::BatchMatmulOp>>(
       patterns.getContext(), controlFn);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   static bool constexpr reduceLeft =
       (std::is_same_v<FromOpTy, BatchMatmulOp> &&
        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
-      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
       (std::is_same_v<FromOpTy, MatmulOp> &&
        std::is_same_v<ToOpTy, VecmatOp>) ||
-      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, VecmatOp>) ||
       (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
 
   /// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
   MLIRContext *context = patterns.getContext();
   // Unbatching patterns for unit batch size
   patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-          context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-          context);
   patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
   patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
 
   // Non-batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
   // Batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
-      context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
-      context);
 
   // Non-batch rank 0 reducing patterns
   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    if (a == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
-                                                               genericOp);
-    if (b == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
-                                                               genericOp);
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
   }
-
-  if (a == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
-  if (b == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
 ///
 /// with
 ///
-///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
       dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{1, 0});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = matmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+    mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-        matmulOp.getOutputs());
+    newLHS = matmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+    mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+      loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      matmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(matmulOp, newMatmulOp);
   return newMatmulOp;
 }
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
 ///
 /// with
 ///
-///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
 ///
 /// Only the non-batch dimensions are transposed. By default the LHS is
 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
       type.getElementType(), dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2, d3;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2, d3);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-        batchMatmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = batchMatmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-        batchMatmulOp.getOutputs());
+    newLHS = batchMatmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      batchMatmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
   return newMatmulOp;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
            "vectorization\n");
       return failure();
     }
-    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+    if (isa<linalg::MatmulOp>(op)) {
       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
            "is not supported\n");
       return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
       return failure();
   }
 
-  // Check to not let go the matmul with extended semantic, through this
-  // transform.
-  if (linalgOp.hasUserDefinedMaps())
-    return failure();
-
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwc...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/147961


More information about the Mlir-commits mailing list