[Mlir-commits] [mlir] [mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize (PR #93590)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Jun 5 08:51:47 PDT 2024
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/93590
>From 38689e21143e76228b6db435e2e5f4752d0d089c Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 5 Jun 2024 08:51:12 -0700
Subject: [PATCH] [mlir][vector] Use notifyMatchFailure instead of assert in
VectorLinearize
---
.../Vector/Transforms/VectorLinearize.cpp | 23 +++++++++++--------
mlir/test/Dialect/Vector/linearize.mlir | 14 +++++++++--
2 files changed, 26 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf742f6297..d12686a26094c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstType).isScalable()) &&
- "scalable vectors are not supported.");
+ if (extractOp.getVector().getType().isScalable() ||
+ cast<VectorType>(dstType).isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -265,6 +266,8 @@ struct LinearizeVectorShuffle final
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
+ // An assert (not notifyMatchFailure) is used here because vector.shuffle
+ // does not support scalable vectors.
assert(!(shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
cast<VectorType>(dstType).isScalable()) &&
@@ -336,9 +339,10 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ if (extractOp.getVector().getType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -395,9 +399,10 @@ struct LinearizeVectorInsert final
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
- assert(!(insertOp.getDestVectorType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ if (insertOp.getDestVectorType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable())
+ return rewriter.notifyMatchFailure(insertOp,
+ "scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
targetVectorBitWidth))
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 31a59b809a74b..70334613be808 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -129,8 +129,8 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
// -----
// ALL-LABEL: func.func @test_scalable_no_linearize(
-// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
-func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
@@ -140,6 +140,15 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
// ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
%2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+ // ALL: %[[EXTRACTSLICE:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [1, 0], sizes = [1, 2], strides = [1, 1]} : vector<2x[2]xf32> to vector<1x[2]xf32>
+ %3 = vector.extract_strided_slice %arg1 { sizes = [1, 2], strides = [1, 1], offsets = [1, 0] } : vector<2x[2]xf32> to vector<1x[2]xf32>
+
+ // ALL: %[[EXTRACT:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<2x[2]xf32>
+ %4 = vector.extract %arg1[0, 0]: f32 from vector<2x[2]xf32>
+
+ // ALL: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[VAL_1]] [0, 0] : f32 into vector<2x[2]xf32>
+ %5 = vector.insert %4, %arg1[0, 0]: f32 into vector<2x[2]xf32>
+
// ALL: return %[[RES]] : vector<[2]x[2]xf32>
return %2 : vector<[2]x[2]xf32>
}
@@ -274,3 +283,4 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
return %0 : vector<2x8x4xf32>
}
+
More information about the Mlir-commits
mailing list