[Mlir-commits] [mlir] [mlir][vector] Bugfix of linearize `vector.extract` (PR #106836)
Longsheng Mou
llvmlistbot at llvm.org
Wed Sep 4 00:26:00 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/106836
>From af153c82a877f804f5f8e7297a04f1ddcbb08335 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Sat, 31 Aug 2024 16:35:35 +0800
Subject: [PATCH] [mlir][vector] Bugfix of linearize `vector.extract`
This patch adds a check for `vector.extract` ops whose result is a scalar
type: such ops are not supported when linearizing `vector.extract`, so the
pattern now reports a match failure instead of proceeding.
---
.../Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++++
mlir/test/Dialect/Vector/linearize.mlir | 12 ++++++++++++
2 files changed, 16 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 868397f2daaae4..11917ac1e40226 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -337,6 +337,10 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(extractOp,
+ "expected n-D vector type.");
+
if (extractOp.getVector().getType().isScalable() ||
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 916e3e5fd2529d..543e76b5b26e0c 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -306,3 +306,15 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// ALL: return %[[RES]] : vector<2x8x[4]xf32>
return %0 : vector<2x8x[4]xf32>
}
+
+// -----
+
+// ALL-LABEL: test_vector_extract_scalar
+func.func @test_vector_extract_scalar() {
+ %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
+ // ALL-NOT: vector.shuffle
+ // ALL: vector.extract
+ // ALL-NOT: vector.shuffle
+ %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+ return
+}
More information about the Mlir-commits
mailing list