[Mlir-commits] [mlir] [MLIR][Linalg] Remove matmul_transpose variants (PR #147961)
llvmlistbot at llvm.org
Thu Jul 10 06:31:12 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Renato Golin (rengolin)
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` whenever appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863), and can be done now that #104783 has merged.
See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245
Issues investigated:
* pad transform tests that could use `matmul` instead were changed to do so.
* the ArmSME test using transpose actually needed it, so it was changed to `matmul` + affine maps.
Arm tests validated by @banach-space (thanks!!).
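To illustrate the replacement, here is a minimal before/after sketch (SSA names and tensor shapes are hypothetical; the indexing maps are the ones from the `matmul_transpose_a` definition removed in the diff below):

```mlir
// Before: dedicated named op with the transpose baked into its definition.
%0 = linalg.matmul_transpose_a
       ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
       outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>

// After: plain matmul carrying the transpose in explicit indexing maps.
%0 = linalg.matmul
       indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,  // A (transposed)
                        affine_map<(d0, d1, d2) -> (d2, d1)>,  // B
                        affine_map<(d0, d1, d2) -> (d0, d1)>]  // C
       ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
       outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>
```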
---
Patch is 75.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147961.diff
20 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-286)
- (modified) mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (+1-5)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (-16)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (-11)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+48-20)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-7)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-93)
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir (-50)
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul.mlir (-144)
- (modified) mlir/test/Dialect/Linalg/fold-add-into-dest.mlir (-30)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (-111)
- (modified) mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir (-85)
- (modified) mlir/test/Dialect/Linalg/tile-to-forall.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+3-3)
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir (-89)
- (modified) mlir/test/Dialect/Linalg/transpose-matmul.mlir (+26-12)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+7-2)
- (modified) mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py (+1-1)
- (modified) mlir/utils/tree-sitter-mlir/dialect/linalg.js (-2)
- (modified) mlir/utils/tree-sitter-mlir/queries/highlights.scm (-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_a
- cpp_class_name: MatmulTransposeAOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_b
- cpp_class_name: MatmulTransposeBOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_a
- cpp_class_name: BatchMatmulTransposeAOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_b
- cpp_class_name: BatchMatmulTransposeBOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+ BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(), controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
///
/// with
///
-/// linalg.matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
- matmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = matmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
} else {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
- matmulOp.getOutputs());
+ newLHS = matmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ matmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
///
/// with
///
-/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2, d3;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2, d3);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
- batchMatmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = batchMatmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
} else {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
- batchMatmulOp.getOutputs());
+ newLHS = batchMatmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ batchMatmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization\n");
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported\n");
return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwc...
[truncated]
``````````
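For reference, a sketch of the IR the updated `transposeBatchMatmul` rewrite in the `TransposeMatmul.cpp` hunk above now produces for the default `transposeLHS=true` case (value names and shapes are hypothetical; the maps mirror the `AffineMap::get` calls in that hunk):

```mlir
// The LHS is materialized transposed, then consumed by a plain
// batch_matmul whose LHS indexing map reads the transposed layout.
%empty = tensor.empty() : tensor<2x5x3xf32>
%at = linalg.transpose ins(%a : tensor<2x3x5xf32>)
        outs(%empty : tensor<2x5x3xf32>) permutation = [0, 2, 1]
%0 = linalg.batch_matmul
       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,  // transposed LHS
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,  // RHS
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]  // output
       ins(%at, %b : tensor<2x5x3xf32>, tensor<2x5x7xf32>)
       outs(%c : tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
```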
https://github.com/llvm/llvm-project/pull/147961