[Mlir-commits] [mlir] [mlir][vector] Constrain `ContractionOpToMatmulOpLowering` (PR #102225)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Aug 7 02:13:41 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/102225
>From 31b1868f6a3b8cd3753773c92b4fa2b24040c47d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 6 Aug 2024 22:06:53 +0100
Subject: [PATCH 1/2] [nlir][vector] Constrain
`ContractionOpToMatmulOpLowering`
Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This
pattern is meant to enable lowering to `llvm.matrix.multiply`, and I'm not
aware of any use of that intrinsic with scalable vectors.
---
.../Vector/Transforms/LowerVectorContract.cpp | 13 ++++++++++---
...r-contract-to-matrix-intrinsics-transforms.mlir | 14 ++++++++++++--
2 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 3a799ce8e0bce3..b8ebbd8382be82 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1283,6 +1283,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rew) const {
@@ -1302,13 +1304,18 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
!isReductionIterator(iteratorTypes[2]))
return failure();
+ Type opResType = op.getType();
+ VectorType vecType = dyn_cast<VectorType>(opResType);
+ if (vecType && vecType.isScalable()) {
+ // This should be sufficient to reject all cases with scalable vectors.
+ return failure();
+ }
+
Type elementType = op.getLhsType().getElementType();
if (!elementType.isIntOrFloat())
return failure();
- Type dstElementType = op.getType();
- if (auto vecType = dyn_cast<VectorType>(dstElementType))
- dstElementType = vecType.getElementType();
+ Type dstElementType = vecType ? vecType.getElementType() : opResType;
if (elementType != dstElementType)
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 78cf82e1ab6c1a..4867a416e5d144 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -36,13 +36,23 @@
// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
func.func @matmul(%arg0: vector<2x4xf32>,
- %arg1: vector<4x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %arg1: vector<4x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-NOT: vector.matrix_multiply
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 9060cdf64a5bcaa169579193280cf7ccebaf674b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 7 Aug 2024 10:13:22 +0100
Subject: [PATCH 2/2] fixup! [nlir][vector] Constrain
`ContractionOpToMatmulOpLowering`
Remove ambiguity in comment
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index b8ebbd8382be82..21261478f0648f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1307,7 +1307,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Type opResType = op.getType();
VectorType vecType = dyn_cast<VectorType>(opResType);
if (vecType && vecType.isScalable()) {
- // This should be sufficient to reject all cases with scalable vectors.
+ // Note - this is sufficient to reject all cases with scalable vectors.
return failure();
}
More information about the Mlir-commits
mailing list