[Mlir-commits] [mlir] [MLIR][Linalg] Remove matmul_transpose variants (PR #147961)
Renato Golin
llvmlistbot at llvm.org
Fri Jul 11 10:09:44 PDT 2025
https://github.com/rengolin updated https://github.com/llvm/llvm-project/pull/147961
From ed24919966d1d62a2cb5ae908f40c9932d137ad9 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Tue, 8 Jul 2025 10:58:42 +0100
Subject: [PATCH 1/3] [MLIR][Linalg] Remove matmul_transpose variants
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and
replaces them with plain `(batch_)matmul` carrying explicit `indexing_maps`
(affine maps) whenever appropriate.
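
For illustration, a minimal sketch of the mapping for the 2D transpose-A case
(shapes borrowed from the updated `named-ops.mlir` test; the per-map comments
are editorial):

  // Before:
  linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                            outs(%arg2 : memref<3x7xf32>)

  // After: a plain matmul whose LHS map reads A transposed as (k, m).
  linalg.matmul indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2, d0)>,  // A: (k, m)
      affine_map<(d0, d1, d2) -> (d2, d1)>,  // B: (k, n)
      affine_map<(d0, d1, d2) -> (d0, d1)>   // C: (m, n)
    ]
    ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
    outs(%arg2 : memref<3x7xf32>)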
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 286 ------------------
.../Linalg/Transforms/BlockPackMatmul.cpp | 6 +-
.../Linalg/Transforms/DropUnitDims.cpp | 16 -
.../Dialect/Linalg/Transforms/Specialize.cpp | 11 -
.../Linalg/Transforms/TransposeMatmul.cpp | 68 +++--
.../Linalg/Transforms/Vectorization.cpp | 8 +-
.../linalg/opdsl/ops/core_named_ops.py | 93 ------
.../Linalg/block-pack-matmul-layout.mlir | 50 ---
.../Dialect/Linalg/block-pack-matmul.mlir | 144 ---------
.../Dialect/Linalg/fold-add-into-dest.mlir | 30 --
mlir/test/Dialect/Linalg/named-ops.mlir | 111 -------
.../Linalg/rank-reduce-contraction-ops.mlir | 85 ------
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 2 +-
.../test/Dialect/Linalg/transform-op-pad.mlir | 6 +-
.../transform-op-specialize-matmul.mlir | 89 ------
.../test/Dialect/Linalg/transpose-matmul.mlir | 38 ++-
.../Linalg/CPU/ArmSME/matmul-transpose-a.mlir | 9 +-
.../linalg/opdsl/test_core_named_ops.py | 2 +-
mlir/utils/tree-sitter-mlir/dialect/linalg.js | 2 -
.../tree-sitter-mlir/queries/highlights.scm | 2 -
20 files changed, 88 insertions(+), 970 deletions(-)
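
The batch variants follow the same pattern; a minimal sketch, with shapes
taken from the removed `named-ops.mlir` test and editorial comments on the
maps:

  // Before:
  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
                                  outs(%arg2 : memref<2x3x7xf32>)

  // After: the batch dim d0 is untouched; only the non-batch dims of A swap.
  linalg.batch_matmul indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,  // A: (batch, k, m)
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,  // B: (batch, k, n)
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>   // C: (batch, m, n)
    ]
    ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
    outs(%arg2 : memref<2x3x7xf32>)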
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_a
- cpp_class_name: MatmulTransposeAOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_b
- cpp_class_name: MatmulTransposeBOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_a
- cpp_class_name: BatchMatmulTransposeAOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_b
- cpp_class_name: BatchMatmulTransposeBOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+ BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(), controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
///
/// with
///
-/// linalg.matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
- matmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = matmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
} else {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
- matmulOp.getOutputs());
+ newLHS = matmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ matmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
///
/// with
///
-/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2, d3;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2, d3);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
- batchMatmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = batchMatmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
} else {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
- batchMatmulOp.getOutputs());
+ newLHS = batchMatmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ batchMatmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization\n");
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported\n");
return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 1b359da40a291..fd4a5a848f1e3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -373,42 +373,6 @@ def quantized_matmul(
)
-@linalg_structured_op
-def matmul_transpose_a(
- A=TensorDef(T1, S.K, S.N),
- B=TensorDef(T2, S.K, S.M),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
-
-
-@linalg_structured_op
-def matmul_transpose_b(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.N, S.K),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
-
-
@linalg_structured_op
def mmt4d(
lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
@@ -453,44 +417,6 @@ def batch_mmt4d(
) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
-@linalg_structured_op
-def batch_matmul_transpose_a(
- A=TensorDef(T1, Batch, S.K, S.M),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
-@linalg_structured_op
-def batch_matmul_transpose_b(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.N, S.K),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.n, D.k]
- )
-
-
@linalg_structured_op
def quantized_batch_matmul(
A=TensorDef(T1, Batch, S.M, S.K),
@@ -512,25 +438,6 @@ def quantized_batch_matmul(
) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
-@linalg_structured_op
-def batch_reduce_matmul(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
-):
- """Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
@linalg_structured_op
def matvec(
A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
index 4ba4b09f52163..2f30e8b9d01e7 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
@@ -20,20 +20,6 @@ func.func @block_matmul(
return %0 : tensor<64x64xf32>
}
-func.func @block_matmul_transpose_a(
- %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-func.func @block_matmul_transpose_b(
- %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
// MMT4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// MMT4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// MMT4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
@@ -43,18 +29,6 @@ func.func @block_matmul_transpose_b(
// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_a
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_b
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
// MM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// MM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -65,18 +39,6 @@ func.func @block_matmul_transpose_b(
// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_a
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_b
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
// MTM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
// MTM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -87,15 +49,3 @@ func.func @block_matmul_transpose_b(
// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_a
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_b
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index aa860dbd581a9..e16af1f6da0e3 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -197,150 +197,6 @@ func.func @block_batch_matmul(
// -----
-func.func @block_matmul_transpose_a(
- %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_a(
-// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_a(
- %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
- %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
- outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
- return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_a(
-// CHECK-SAME: %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
-func.func @block_matmul_transpose_b(
- %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_b(
-// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_b(
- %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
- %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
- outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
- return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_b(
-// CHECK-SAME: %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
index d8e92e40739dc..e90247d052b4b 100644
--- a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -157,36 +157,6 @@ module attributes {transform.with_named_sequence} {
// -----
-!type = tensor<2048x2048xf32>
-func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
- %0 = arith.constant dense<1.111111e+00> : !type
- %cst = arith.constant 0.000000e+00 : f32
- %1 = tensor.empty() : !type
- %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
- %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
- %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
- %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
- return %5 : !type
-}
-
-// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
-// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
-// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
-// CHECK-NOT: linalg.add
-// CHECK-NEXT: return %[[RES]]
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.linalg.fold_add_into_dest
- } : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
!type = tensor<2048x2048xf32>
func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
%0 = arith.constant dense<1.111111e+00> : !type
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 412f40d501154..3da8fb950d8f7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1222,84 +1222,6 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
// -----
-// CHECK-LABEL: func @matmul_transpose_a
-// CHECK: linalg.matmul_transpose_a
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @matmul_transpose_a_explicit
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// -----
-
-func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
-// CHECK: return
-// CHECK: }
-
-// -----
-
-func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
-// CHECK: return
-// CHECK: }
-
-// -----
-
func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2)>,
@@ -1478,17 +1400,6 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// -----
-// CHECK-LABEL: func @matmul_transpose_b
-// CHECK: linalg.matmul_transpose_b
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// -----
-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -1806,28 +1717,6 @@ func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: m
// -----
-// CHECK-LABEL: func @batchmatmul_transpose_a
-// CHECK: linalg.batch_matmul_transpose_a
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
- linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @batchmatmul_transpose_b
-// CHECK: linalg.batch_matmul_transpose_b
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
- linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
- return
-}
-
-// -----
-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 43bddb075e649..704576de41960 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -92,38 +92,6 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
// -----
-func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
- // CHECK-LABEL: @singleton_batchmatmul_transpose_a
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
-func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
- // CHECK-LABEL: @singleton_batchmatmul_transpose_b
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK-LABEL: @matmul_to_matvec_tensor
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -226,59 +194,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
// -----
-func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
- // CHECK-LABEL: @matmul_transpose_a_to_vecmat
- // CHECK: collapse_shape {{.*}} into tensor<256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<512xf32>
- // CHECK: linalg.vecmat
- // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
- %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
- return %0 : tensor<1x512xf32>
-}
-
-// -----
-
-func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
- // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
- // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
- // CHECK: linalg.batch_vecmat
- // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
- %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
- return %0 : tensor<64x1x512xf32>
-}
-
-// -----
-
-func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
- // CHECK-LABEL: @matmul_transpose_b_to_matvec
- // CHECK: linalg.matvec
- linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
- return
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
- // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
- // CHECK: linalg.batch_matvec
- // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
- return %0 : tensor<64x128x1xf32>
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
- // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
- // CHECK: linalg.dot
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
- return %0 : tensor<1x1x1xf32>
-}
-
-// -----
-
func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
// CHECK-LABEL: @nonsingleton_batch_matmul
// CHECK-NOT: collapse_shape
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 778d5bb8b9c84..1b0bade728b44 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -504,7 +504,7 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
%c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
%sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index f91eb9c30a51a..51bf4a23406d4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -465,14 +465,14 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[RHS:.*]] = tensor.pad
// CHECK: scf.for
// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
-// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK-DAG: tensor.extract_slice %[[RHS]][%{{.*}}, 0] [32, %{{.*}}]
func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
index f64953bceefe1..bd4c65512dfa6 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -30,66 +30,6 @@ module attributes {transform.with_named_sequence} {
// -----
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %0 = arith.mulf %in, %in_0 : f32
- %1 = arith.addf %out, %0 : f32
- linalg.yield %1 : f32
- }
- return
-}
-
-// CHECK-LABEL: @matmul_transpose_a
-// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.mulf %in, %in_0 : f32
- %2 = arith.addf %out, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: @matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -117,32 +57,3 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
-// -----
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
- %0 = linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.mulf %in, %in_0 : f32
- %2 = arith.addf %out, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
-}
-
-// CHECK-LABEL: @batch_matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
index d2b7e9f7f1992..4ee87fb11d527 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,6 +1,20 @@
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
+// TRANSPOSE-A-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// TRANSPOSE-A-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// TRANSPOSE-A-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-A-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// TRANSPOSE-A-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// TRANSPOSE-A-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// TRANSPOSE-B-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// TRANSPOSE-B-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// TRANSPOSE-B-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-B-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// TRANSPOSE-B-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// TRANSPOSE-B-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
// CHECK-LABEL: func.func @matmul_static(
// CHECK-SAME: %[[A:.*]]: tensor<16x8xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
@@ -9,10 +23,10 @@
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// CHECK: return %[[C]] : tensor<16x16xf32>
// CHECK: }
func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -38,11 +52,11 @@ func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<
// TRANSPOSE-A: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-B: %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return %[[C]] : tensor<?x?xf32>
// CHECK: }
func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -69,10 +83,10 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A: %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B: %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// CHECK: return %[[B0]] : tensor<?x16xf32>
// CHECK: }
func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
@@ -96,10 +110,10 @@ func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// CHECK: return %[[C]] : tensor<2x16x16xf32>
// CHECK: }
func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x16x16xf32>) {
@@ -127,12 +141,12 @@ func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -
// TRANSPOSE-A: %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: return %[[C]] : tensor<?x?x?xf32>
// CHECK: }
func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
@@ -161,10 +175,10 @@ func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) ->
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[B0:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-A: %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[B0:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-B: %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// CHECK: return %[[B0]] : tensor<2x?x16xf32>
// CHECK: }
func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index 06a6e2279b6a7..9d043573091eb 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -9,7 +9,12 @@
// RUN: FileCheck %s
func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
- %res = linalg.matmul_transpose_a ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ %res = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
%xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
@@ -56,7 +61,7 @@ func.func @main() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
- %matmul_transpose_a = transform.structured.match ops{["linalg.matmul_transpose_a"]} in %module
+ %matmul_transpose_a = transform.structured.match ops{["linalg.matmul"]} in %module
: (!transform.any_op) -> !transform.any_op
// Step 1: Tile for size [4] x [4], which corresponds to SVLs x SVLs, where
diff --git a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
index ee76b6d25cae1..bc273bfa06f86 100644
--- a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
+++ b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
@@ -1,7 +1,7 @@
# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops | FileCheck %s
# Just verify that at least one known op is generated.
-# CHECK: name: matmul
+# CHECK: name: copy
# verify some special cases: negf->NegFOp, powf->PowFOp
# CHECK cpp_class_name: NegFOp
diff --git a/mlir/utils/tree-sitter-mlir/dialect/linalg.js b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
index ddde92b2f692b..f4658085ce6f3 100644
--- a/mlir/utils/tree-sitter-mlir/dialect/linalg.js
+++ b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
@@ -4,7 +4,6 @@ module.exports = {
linalg_dialect : $ => prec.right(choice(
seq(choice(
'linalg.batch_matmul',
- 'linalg.batch_matmul_transpose_b',
'linalg.batch_matvec',
'linalg.batch_reduce_matmul', 'linalg.broadcast',
'linalg.conv_1d_ncw_fcw', 'linalg.conv_1d_nwc_wcf',
@@ -27,7 +26,6 @@ module.exports = {
'linalg.dot', 'linalg.elemwise_binary',
'linalg.elemwise_unary', 'linalg.fill',
'linalg.fill_rng_2d', 'linalg.matmul',
- 'linalg.matmul_transpose_b',
'linalg.matmul_unsigned', 'linalg.matvec',
'linalg.mmt4d', 'linalg.pooling_nchw_max',
'linalg.pooling_nchw_sum',
diff --git a/mlir/utils/tree-sitter-mlir/queries/highlights.scm b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
index 4cbea7bbca031..59e280bab414a 100644
--- a/mlir/utils/tree-sitter-mlir/queries/highlights.scm
+++ b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
@@ -213,7 +213,6 @@
"bufferization.to_tensor"
"linalg.batch_matmul"
- "linalg.batch_matmul_transpose_b"
"linalg.batch_matvec"
"linalg.batch_reduce_matmul"
"linalg.broadcast"
@@ -244,7 +243,6 @@
"linalg.fill"
"linalg.fill_rng_2d"
"linalg.matmul"
- "linalg.matmul_transpose_b"
"linalg.matmul_unsigned"
"linalg.matvec"
"linalg.mmt4d"
>From b2460adbe58be1456d092abdcc2bdf558657e6f2 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Thu, 10 Jul 2025 14:34:01 +0100
Subject: [PATCH 2/3] format
---
mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 57f898458516e..b4507a93e0098 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,6 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>>(
- patterns.getContext(), controlFn);
+ BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+ controlFn);
}
>From 8744e49402e73887849c78c7865910f48526bb42 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 11 Jul 2025 12:36:37 +0100
Subject: [PATCH 3/3] Specialization for transpose variants
No further code changes are needed; everything is now handled by the
new specialization classes. Also fixes a bug in the previous matmul
builder, where the affine maps argument was being ignored.
---
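Note for reviewers: a minimal usage sketch (illustration only, not part
of the patch) of the new builder-only specializations, mirroring the
transposeMatmul() change below. `rewriter`, `loc` and `matmulOp` are as
in that function; `transposedLhs` is a placeholder for the result of the
linalg.transpose on A:

  // Builds a plain linalg.matmul whose indexing maps transpose A; the
  // specialization only supplies the maps, the op class stays MatmulOp.
  Operation *newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
      loc, matmulOp.getResultTypes(),
      ValueRange{transposedLhs, matmulOp.getInputs()[1]},
      matmulOp.getOutputs());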
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 118 ++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 173 +++++++++++++++++-
.../Linalg/Transforms/TransposeMatmul.cpp | 68 ++-----
3 files changed, 308 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..d18c1fdcb171e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -144,4 +144,122 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
+namespace mlir::linalg {
+
+/// Specialization of `linalg.matmul` op that has a transpose map on A
+class MatmulTransposeAOp : public MatmulOp {
+  /// Create the affine maps for a transpose-A matmul. Used only in the builders.
+ static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+ using MatmulOp::MatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+ /// Build a transpose A matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.matmul` op that has a transpose map on B
+class MatmulTransposeBOp : public MatmulOp {
+  /// Create the affine maps for a transpose-B matmul. Used only in the builders.
+ static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+ using MatmulOp::MatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+ /// Build a transpose B matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
+class BatchMatmulTransposeAOp : public BatchMatmulOp {
+  /// Create the affine maps for a transpose-A batch_matmul. Used only in the
+ /// builders.
+ static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+ using BatchMatmulOp::BatchMatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+  /// Build a transpose A batch_matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A batch_matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose A batch_matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
+class BatchMatmulTransposeBOp : public BatchMatmulOp {
+  /// Create the affine maps for a transpose-B batch_matmul. Used only in the
+ /// builders.
+ static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
+
+public:
+ using BatchMatmulOp::BatchMatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+  /// Build a transpose B batch_matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B batch_matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+  /// Build a transpose B batch_matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static bool classof(Operation *op);
+};
+
+} // namespace mlir::linalg
+
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..d4b46037267d8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -193,9 +193,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
ArrayRef<AffineMap> indexingMaps) {
// Initialize indexingMaps attribute, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
- indexingMapsAttrVal = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(b.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -3881,6 +3882,172 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+SmallVector<AffineMap> MatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ auto context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+ return affineMaps;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op);
+}
+
+SmallVector<AffineMap> MatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ auto context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+ return affineMaps;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op);
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ auto context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+ return affineMaps;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ auto context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
+ return affineMaps;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
+}
+
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 086f9e5d05e6f..934781d1cab75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,11 +23,10 @@ using namespace mlir::linalg;
///
/// with
///
-/// linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
+/// linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
-/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
@@ -58,31 +57,18 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
- Value newLHS, newRHS;
- AffineMap mapLHS, mapRHS, mapOut;
- AffineExpr d0, d1, d2;
- auto context = rewriter.getContext();
- bindDims(context, d0, d1, d2);
+ Operation *newMatmulOp;
if (transposeLHS) {
- newLHS = transposeOp->getResult(0);
- newRHS = matmulOp.getInputs()[1];
- mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
- mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
- mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+ loc, matmulOp.getResultTypes(),
+ ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+ matmulOp.getOutputs());
} else {
- newLHS = matmulOp.getInputs()[0];
- newRHS = transposeOp->getResult(0);
- mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
- mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
- mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+ loc, matmulOp.getResultTypes(),
+ ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+ matmulOp.getOutputs());
}
- Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
- loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
- matmulOp.getOutputs());
- newMatmulOp->setAttr("indexing_maps",
- rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
- AffineMapAttr::get(mapRHS),
- AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}
@@ -93,11 +79,10 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
///
/// with
///
-/// linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
+/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
-/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
@@ -129,31 +114,18 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
- Value newLHS, newRHS;
- AffineMap mapLHS, mapRHS, mapOut;
- AffineExpr d0, d1, d2, d3;
- auto context = rewriter.getContext();
- bindDims(context, d0, d1, d2, d3);
+ Operation *newMatmulOp;
if (transposeLHS) {
- newLHS = transposeOp->getResult(0);
- newRHS = batchMatmulOp.getInputs()[1];
- mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
- mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
- mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+ loc, batchMatmulOp.getResultTypes(),
+ ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+ batchMatmulOp.getOutputs());
} else {
- newLHS = batchMatmulOp.getInputs()[0];
- newRHS = transposeOp->getResult(0);
- mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
- mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
- mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+ loc, batchMatmulOp.getResultTypes(),
+ ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+ batchMatmulOp.getOutputs());
}
- Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
- loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
- batchMatmulOp.getOutputs());
- newMatmulOp->setAttr("indexing_maps",
- rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
- AffineMapAttr::get(mapRHS),
- AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}