[Mlir-commits] [mlir] [mlir][memref] Fix segfault in SROA. (PR #71063)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 2 07:29:40 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Théo Degioanni (Moxinilian)
<details>
<summary>Changes</summary>
Fixes #<!-- -->70902.
The out-of-bounds test for the MemRef SROA implementation was not actually testing anything, because it only exercised a store op, which does not trigger the faulty code path on its own. The case is now genuinely tested (with an out-of-bounds load), and the underlying bug is fixed.
I also checked the LLVM implementation just in case; this should not happen there, because out-of-bounds checks are performed by the GEP verifier.
---
Full diff: https://github.com/llvm/llvm-project/pull/71063.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+25-11)
- (modified) mlir/test/Dialect/MemRef/sroa.mlir (+23-2)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 93ec2bcdf58fa4a..dd7c39a01d36bcf 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -187,14 +187,28 @@ DeletionKind memref::LoadOp::removeBlockingUses(
return DeletionKind::Delete;
}
-/// Returns the index of a memref in attribute form, given its indices.
+/// Returns the index of a memref in attribute form, given its indices. Returns
+/// a null attribute if the indices cannot be proven to form a valid, in-bounds
+/// index for the provided MemRefType.
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
- ValueRange indices) {
+ ValueRange indices,
+ MemRefType memrefType) {
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ if (indices.size() != memrefType.getShape().size())
+ return {};
SmallVector<Attribute> index;
- for (Value coord : indices) {
+ for (size_t i = 0; i < shape.size(); i++) {
+ Value coord = indices[i];
+ int64_t dimSize = shape[i];
IntegerAttr coordAttr;
if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
return {};
+ if (!coordAttr.getType().isIndex())
+ return {};
+ // MemRefType shape dimensions are always positive (checked by verifier).
+ std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
+ if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
+ return {};
index.push_back(coordAttr);
}
return ArrayAttr::get(ctx, index);
@@ -205,8 +219,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
if (slot.ptr != getMemRef())
return false;
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
if (!index)
return false;
usedIndices.insert(index);
@@ -216,8 +230,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
rewriter.updateRootInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
@@ -258,8 +272,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
if (slot.ptr != getMemRef() || getValue() == slot.ptr)
return false;
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
if (!index || !slot.elementPtrs.contains(index))
return false;
usedIndices.insert(index);
@@ -269,8 +283,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
rewriter.updateRootInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
index d78053d8ea777e7..40ab9b3483b833a 100644
--- a/mlir/test/Dialect/MemRef/sroa.mlir
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -132,9 +132,9 @@ func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
// -----
-// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-LABEL: func.func @no_out_of_bound_write
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
-func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+func.func @no_out_of_bound_write(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[C100:.*]] = arith.constant 100 : index
@@ -152,3 +152,24 @@ func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: return %[[RES]] : i32
return %res : i32
}
+
+// -----
+
+// CHECK-LABEL: func.func @no_out_of_bound_load
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bound_load(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C100:.*]] = arith.constant 100 : index
+ %c100 = arith.constant 100 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C100]]]
+ %res = memref.load %alloca[%c100] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/71063
More information about the Mlir-commits
mailing list