[Mlir-commits] [mlir] [mlir][vector] Prevent folding of OOB values in insert/extract (PR #135498)
Fehr Mathieu
llvmlistbot at llvm.org
Sat Apr 12 11:23:49 PDT 2025
https://github.com/math-fehr created https://github.com/llvm/llvm-project/pull/135498
Out-of-bounds position values should not be folded into vector.extract and vector.insert operations, as only in-bounds constants and -1 are valid.
>From ca143d6a228c0dc28482a7818b5e8eea72263e8d Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Sat, 12 Apr 2025 20:15:50 +0100
Subject: [PATCH] [mlir][vector] Prevent folding of OOB values in
insert/extract
Out-of-bounds position values should not be folded into
vector.extract and vector.insert operations, as only in-bounds
constants and -1 are valid.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 +++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 38 ++++++++++++++++++++++
2 files changed, 51 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..0031608e2c9d5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1996,6 +1996,12 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
+ ArrayRef<int64_t> vectorShape;
+ if constexpr (std::is_same_v<OpType, ExtractOp>) {
+ vectorShape = op.getSourceVectorType().getShape();
+ } else if constexpr (std::is_same_v<OpType, InsertOp>) {
+ vectorShape = op.getDestVectorType().getShape();
+ }
// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
@@ -2012,9 +2018,13 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
Attribute positionAttr = dynamicPositionAttr[index];
Value position = dynamicPosition[index++];
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
- staticPosition[i] = attr.getInt();
- opChange = true;
- continue;
+ int64_t value = attr.getInt();
+ // Do not fold if the value is out of bounds.
+ if (value >= 0 && value < vectorShape[i]) {
+ staticPosition[i] = attr.getInt();
+ opChange = true;
+ continue;
+ }
}
operands.push_back(position);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..b0f502a0b7c36 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3233,3 +3233,41 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
+
+// -----
+
+// Check that out-of-bounds constant indices (-2 and 2 here) stay dynamic in vector.insert, while the in-bounds index 0 is folded.
+
+// CHECK-LABEL: @fold_insert_oob
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
+// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
+// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
+// CHECK: return %[[RES]] : vector<4x1x2xi32>
+func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
+ %0 = arith.constant 0 : index
+ %-2 = arith.constant -2 : index
+ %2 = arith.constant 2 : index
+ %1 = arith.constant 1 : i32
+ %res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32>
+ return %res : vector<4x1x2xi32>
+}
+
+// -----
+
+// Check that out-of-bounds constant indices (-2 and 2 here) stay dynamic in vector.extract, while the in-bounds index 0 is folded.
+
+// CHECK-LABEL: @fold_extract_oob
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
+// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
+// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
+// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
+// CHECK: return %[[RES]] : i32
+func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
+ %0 = arith.constant 0 : index
+ %-2 = arith.constant -2 : index
+ %2 = arith.constant 2 : index
+ %res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32>
+ return %res : i32
+}
More information about the Mlir-commits
mailing list