[Mlir-commits] [mlir] [mlir][vector] Prevent folding of OOB values in insert/extract (PR #135498)

Fehr Mathieu llvmlistbot at llvm.org
Thu Apr 17 20:09:17 PDT 2025


https://github.com/math-fehr updated https://github.com/llvm/llvm-project/pull/135498

>From ca143d6a228c0dc28482a7818b5e8eea72263e8d Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Sat, 12 Apr 2025 20:15:50 +0100
Subject: [PATCH 1/3] [mlir][vector] Prevent folding of OOB values in
 insert/extract

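Out-of-bounds position values should not be folded into the static
position of vector.extract and vector.insert operations, as only
in-bounds constants and -1 are valid static positions. This patch
conservatively folds only in-bounds constants; any other constant
(including -1) is left as a dynamic position.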
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 16 +++++++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 38 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..0031608e2c9d5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1996,6 +1996,12 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
   OperandRange dynamicPosition = op.getDynamicPosition();
   ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
+  ArrayRef<int64_t> vectorShape;
+  if constexpr (std::is_same_v<OpType, ExtractOp>) {
+    vectorShape = op.getSourceVectorType().getShape();
+  } else if constexpr (std::is_same_v<OpType, InsertOp>) {
+    vectorShape = op.getDestVectorType().getShape();
+  }
 
   // If the dynamic operands is empty, it is returned directly.
   if (!dynamicPosition.size())
@@ -2012,9 +2018,13 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
     Attribute positionAttr = dynamicPositionAttr[index];
     Value position = dynamicPosition[index++];
     if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
-      staticPosition[i] = attr.getInt();
-      opChange = true;
-      continue;
+      int64_t value = attr.getInt();
+      // Do not fold if the value is out of bounds.
+      if (value >= 0 && value < vectorShape[i]) {
+        staticPosition[i] = attr.getInt();
+        opChange = true;
+        continue;
+      }
     }
     operands.push_back(position);
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..b0f502a0b7c36 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3233,3 +3233,41 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
+
+// -----
+
+// Check that out of bounds indices are not folded for vector.insert
+
+// CHECK-LABEL: @fold_insert_oob
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
+//       CHECK:   %[[OOB1:.*]] = arith.constant -2 : index
+//       CHECK:   %[[OOB2:.*]] = arith.constant 2 : index
+//       CHECK:   %[[VAL:.*]] = arith.constant 1 : i32
+//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
+//       CHECK:   return %[[RES]] : vector<4x1x2xi32>
+func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
+  %0 = arith.constant 0 : index
+  %-2 = arith.constant -2 : index
+  %2 = arith.constant 2 : index
+  %1 = arith.constant 1 : i32
+  %res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32>
+  return %res : vector<4x1x2xi32>
+}
+
+// -----
+
+// Check that out of bounds indices are not folded for vector.extract
+
+// CHECK-LABEL: @fold_extract_oob
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
+//       CHECK:   %[[OOB1:.*]] = arith.constant -2 : index
+//       CHECK:   %[[OOB2:.*]] = arith.constant 2 : index
+//       CHECK:   %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
+//       CHECK:   return %[[RES]] : i32
+func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
+  %0 = arith.constant 0 : index
+  %-2 = arith.constant -2 : index
+  %2 = arith.constant 2 : index
+  %res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32>
+  return %res : i32
+}

>From 8d7718a717e7f9bfe0ffe5d8668f0cf72ef496f4 Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Thu, 17 Apr 2025 05:45:16 +0100
Subject: [PATCH 2/3] Address review comments

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir | 22 +++++++++++-----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0031608e2c9d5..95f9d8a134de4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1999,7 +1999,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   ArrayRef<int64_t> vectorShape;
   if constexpr (std::is_same_v<OpType, ExtractOp>) {
     vectorShape = op.getSourceVectorType().getShape();
-  } else if constexpr (std::is_same_v<OpType, InsertOp>) {
+  } else {
     vectorShape = op.getDestVectorType().getShape();
   }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b0f502a0b7c36..ec2f823f4c701 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3236,7 +3236,7 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
 
 // -----
 
-// Check that out of bounds indices are not folded for vector.insert
+// Check that out of bounds indices are not folded for vector.insert.
 
 // CHECK-LABEL: @fold_insert_oob
 //  CHECK-SAME:   %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
@@ -3246,17 +3246,17 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
 //       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
 //       CHECK:   return %[[RES]] : vector<4x1x2xi32>
 func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
-  %0 = arith.constant 0 : index
-  %-2 = arith.constant -2 : index
-  %2 = arith.constant 2 : index
-  %1 = arith.constant 1 : i32
-  %res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32>
+  %c0 = arith.constant 0 : index
+  %c-2 = arith.constant -2 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : i32
+  %res = vector.insert %c1, %arg[%c0, %c-2, %c2] : i32 into vector<4x1x2xi32>
   return %res : vector<4x1x2xi32>
 }
 
 // -----
 
-// Check that out of bounds indices are not folded for vector.extract
+// Check that out of bounds indices are not folded for vector.extract.
 
 // CHECK-LABEL: @fold_extract_oob
 //  CHECK-SAME:   %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
@@ -3265,9 +3265,9 @@ func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
 //       CHECK:   %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
 //       CHECK:   return %[[RES]] : i32
 func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
-  %0 = arith.constant 0 : index
-  %-2 = arith.constant -2 : index
-  %2 = arith.constant 2 : index
-  %res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32>
+  %c0 = arith.constant 0 : index
+  %c-2 = arith.constant -2 : index
+  %c2 = arith.constant 2 : index
+  %res = vector.extract %arg[%c0, %c-2, %c2] : i32 from vector<4x1x2xi32>
   return %res : i32
 }

>From 34ab7330bb7751c0016b543b49be7e77f855e6b3 Mon Sep 17 00:00:00 2001
From: Fehr Mathieu <mathieu.fehr at gmail.com>
Date: Fri, 18 Apr 2025 05:09:09 +0200
Subject: [PATCH 3/3] Update mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f9d8a134de4..71077d4943aa5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1997,11 +1997,10 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   OperandRange dynamicPosition = op.getDynamicPosition();
   ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
   ArrayRef<int64_t> vectorShape;
-  if constexpr (std::is_same_v<OpType, ExtractOp>) {
+  if constexpr (std::is_same_v<OpType, ExtractOp>)
     vectorShape = op.getSourceVectorType().getShape();
-  } else {
+  else
     vectorShape = op.getDestVectorType().getShape();
-  }
 
   // If the dynamic operands is empty, it is returned directly.
   if (!dynamicPosition.size())



More information about the Mlir-commits mailing list