[Mlir-commits] [mlir] 0a22a80 - [mlir][vector] Fix extractelement/insertelement folder crash on poison attr (#71333)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 6 05:23:00 PST 2023


Author: Ivan Butygin
Date: 2023-11-06T16:22:56+03:00
New Revision: 0a22a80c1b83996a4424c94a3597d8f974ecb444

URL: https://github.com/llvm/llvm-project/commit/0a22a80c1b83996a4424c94a3597d8f974ecb444
DIFF: https://github.com/llvm/llvm-project/commit/0a22a80c1b83996a4424c94a3597d8f974ecb444.diff

LOG: [mlir][vector] Fix extractelement/insertelement folder crash on poison attr (#71333)

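The extractelement/insertelement folders did not check the kinds of incoming constant attributes before casting them (e.g. to `DenseElementsAttr` or `IntegerAttr`), so a `ub.poison` operand crashed the fold. The folders now bail out when an operand attribute has an unexpected kind or element type, or when the position is out of range.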

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 60416f550ee619d..69cbdcd3f536f98 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1188,9 +1188,6 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getPosition())
     return {};
 
-  Attribute src = adaptor.getVector();
-  Attribute pos = adaptor.getPosition();
-
   // Fold extractelement (splat X) -> X.
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
     return splat.getInput();
@@ -1200,13 +1197,16 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
     if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
       return broadcast.getSource();
 
+  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
+  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
   if (!pos || !src)
     return {};
 
-  auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<Attribute>();
+  auto srcElements = src.getValues<Attribute>();
 
-  auto attr = llvm::dyn_cast<IntegerAttr>(pos);
-  uint64_t posIdx = attr.getInt();
+  uint64_t posIdx = pos.getInt();
+  if (posIdx >= srcElements.size())
+    return {};
 
   return srcElements[posIdx];
 }
@@ -2511,18 +2511,20 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getPosition())
     return {};
 
-  Attribute src = adaptor.getSource();
-  Attribute dst = adaptor.getDest();
-  Attribute pos = adaptor.getPosition();
+  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
+  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
+  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
   if (!src || !dst || !pos)
     return {};
 
-  auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<Attribute>();
+  if (src.getType() != getDestVectorType().getElementType())
+    return {};
+
+  auto dstElements = dst.getValues<Attribute>();
 
   SmallVector<Attribute> results(dstElements);
 
-  auto attr = llvm::dyn_cast<IntegerAttr>(pos);
-  uint64_t posIdx = attr.getInt();
+  uint64_t posIdx = pos.getInt();
   if (posIdx >= results.size())
     return {};
   results[posIdx] = src;

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f6bb42b1b249153..163fdd67b0cfd34 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2027,6 +2027,46 @@ func.func @insert_element_invalid_fold() -> vector<1xf32> {
   return %46 : vector<1xf32>
 }
 
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold1
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold1() -> vector<4xi32> {
+  %v = ub.poison : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold2
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold2() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = ub.poison : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @insert_poison_fold3
+//       CHECK:   vector.insertelement
+func.func @insert_poison_fold3() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = ub.poison : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
 // -----
 
 // CHECK-LABEL: func @extract_element_fold
@@ -2051,6 +2091,30 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
 
 // -----
 
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold1
+//       CHECK:   vector.extractelement
+func.func @extract_element_poison_fold1() -> i32 {
+  %v = ub.poison : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// -----
+
+// Do not crash on poison
+// CHECK-LABEL: func @extract_element_poison_fold2
+//       CHECK:   vector.extractelement
+func.func @extract_element_poison_fold2() -> i32 {
+  %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %i = ub.poison : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_one_element_vector_extract
 //  CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
 //       CHECK:   %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2436,4 +2500,4 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
-}
\ No newline at end of file
+}


        


More information about the Mlir-commits mailing list