[Mlir-commits] [mlir] [mlir][vector] Fix extractelement/insertelement folder crash on poison attr (PR #71333)
Ivan Butygin
llvmlistbot at llvm.org
Sun Nov 5 13:49:49 PST 2023
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/71333
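Fix extractelement/insertelement folder crash on poison attr.
The types of the incoming constant attributes weren't properly checked before being cast, so the folders crashed when an operand was a poison attribute.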
>From 02f76ffc1d6f714d10461d13b72b9c5cdf16baf7 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 5 Nov 2023 22:40:23 +0100
Subject: [PATCH] [mlir][vector] Fix extractelement/insertelement folder crash
on poison attr
Types of incoming attributes weren't properly checked.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26 +++++----
mlir/test/Dialect/Vector/canonicalize.mlir | 66 +++++++++++++++++++++-
2 files changed, 79 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 60416f550ee619d..69cbdcd3f536f98 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1188,9 +1188,6 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getPosition())
return {};
- Attribute src = adaptor.getVector();
- Attribute pos = adaptor.getPosition();
-
// Fold extractelement (splat X) -> X.
if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
return splat.getInput();
@@ -1200,13 +1197,16 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
return broadcast.getSource();
+ auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
+ auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
if (!pos || !src)
return {};
- auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<Attribute>();
+ auto srcElements = src.getValues<Attribute>();
- auto attr = llvm::dyn_cast<IntegerAttr>(pos);
- uint64_t posIdx = attr.getInt();
+ uint64_t posIdx = pos.getInt();
+ if (posIdx >= srcElements.size())
+ return {};
return srcElements[posIdx];
}
@@ -2511,18 +2511,20 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getPosition())
return {};
- Attribute src = adaptor.getSource();
- Attribute dst = adaptor.getDest();
- Attribute pos = adaptor.getPosition();
+ auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
+ auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
+ auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
if (!src || !dst || !pos)
return {};
- auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<Attribute>();
+ if (src.getType() != getDestVectorType().getElementType())
+ return {};
+
+ auto dstElements = dst.getValues<Attribute>();
SmallVector<Attribute> results(dstElements);
- auto attr = llvm::dyn_cast<IntegerAttr>(pos);
- uint64_t posIdx = attr.getInt();
+ uint64_t posIdx = pos.getInt();
if (posIdx >= results.size())
return {};
results[posIdx] = src;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f6bb42b1b249153..163fdd67b0cfd34 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2027,6 +2027,46 @@ func.func @insert_element_invalid_fold() -> vector<1xf32> {
return %46 : vector<1xf32>
}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold1
+// CHECK: vector.insertelement
+func.func @insert_poison_fold1() -> vector<4xi32> {
+ %v = ub.poison : vector<4xi32>
+ %s = arith.constant 7 : i32
+ %i = arith.constant 2 : i32
+ %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+ return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold2
+// CHECK: vector.insertelement
+func.func @insert_poison_fold2() -> vector<4xi32> {
+ %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+ %s = ub.poison : i32
+ %i = arith.constant 2 : i32
+ %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+ return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold3
+// CHECK: vector.insertelement
+func.func @insert_poison_fold3() -> vector<4xi32> {
+ %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+ %s = arith.constant 7 : i32
+ %i = ub.poison : i32
+ %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+ return %1 : vector<4xi32>
+}
+
// -----
// CHECK-LABEL: func @extract_element_fold
@@ -2051,6 +2091,30 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
// -----
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold1
+// CHECK: vector.extractelement
+func.func @extract_element_poison_fold1() -> i32 {
+ %v = ub.poison : vector<4xi32>
+ %i = arith.constant 2 : i32
+ %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+ return %1 : i32
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold2
+// CHECK: vector.extractelement
+func.func @extract_element_poison_fold2() -> i32 {
+ %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+ %i = ub.poison : i32
+ %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+ return %1 : i32
+}
+
+// -----
+
// CHECK-LABEL: func @reduce_one_element_vector_extract
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2436,4 +2500,4 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
-}
\ No newline at end of file
+}
More information about the Mlir-commits
mailing list