[Mlir-commits] [mlir] [mlir][memref] Fix segfault in SROA. (PR #71063)
Théo Degioanni
llvmlistbot at llvm.org
Sun Nov 5 12:22:53 PST 2023
https://github.com/Moxinilian updated https://github.com/llvm/llvm-project/pull/71063
>From a6c6a8be3df970d0a1d4fa615979348c41c7d553 Mon Sep 17 00:00:00 2001
From: Théo Degioanni
<30992420+Moxinilian at users.noreply.github.com>
Date: Thu, 2 Nov 2023 15:08:57 +0100
Subject: [PATCH 1/2] [mlir][memref] Fix segfault in SROA.
---
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 36 +++++++++++++------
mlir/test/Dialect/MemRef/sroa.mlir | 25 +++++++++++--
2 files changed, 48 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 93ec2bcdf58fa4a..dd7c39a01d36bcf 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -187,14 +187,28 @@ DeletionKind memref::LoadOp::removeBlockingUses(
return DeletionKind::Delete;
}
-/// Returns the index of a memref in attribute form, given its indices.
+/// Returns the index of a memref in attribute form, given its indices. Returns
+/// a null pointer if whether the indices form a valid index for the provided
+/// MemRefType cannot be computed.
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
- ValueRange indices) {
+ ValueRange indices,
+ MemRefType memrefType) {
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ if (indices.size() != memrefType.getShape().size())
+ return {};
SmallVector<Attribute> index;
- for (Value coord : indices) {
+ for (size_t i = 0; i < shape.size(); i++) {
+ Value coord = indices[i];
+ int64_t dimSize = shape[i];
IntegerAttr coordAttr;
if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
return {};
+ if (!coordAttr.getType().isIndex())
+ return {};
+ // MemRefType shape dimensions are always positive (checked by verifier).
+ std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
+ if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
+ return {};
index.push_back(coordAttr);
}
return ArrayAttr::get(ctx, index);
@@ -205,8 +219,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
if (slot.ptr != getMemRef())
return false;
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
if (!index)
return false;
usedIndices.insert(index);
@@ -216,8 +230,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
rewriter.updateRootInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
@@ -258,8 +272,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
if (slot.ptr != getMemRef() || getValue() == slot.ptr)
return false;
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
if (!index || !slot.elementPtrs.contains(index))
return false;
usedIndices.insert(index);
@@ -269,8 +283,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- Attribute index =
- getAttributeIndexFromIndexOperands(getContext(), getIndices());
+ Attribute index = getAttributeIndexFromIndexOperands(
+ getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
rewriter.updateRootInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
index d78053d8ea777e7..40ab9b3483b833a 100644
--- a/mlir/test/Dialect/MemRef/sroa.mlir
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -132,9 +132,9 @@ func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
// -----
-// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-LABEL: func.func @no_out_of_bound_write
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
-func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+func.func @no_out_of_bound_write(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[C100:.*]] = arith.constant 100 : index
@@ -152,3 +152,24 @@ func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: return %[[RES]] : i32
return %res : i32
}
+
+// -----
+
+// CHECK-LABEL: func.func @no_out_of_bound_load
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bound_load(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C100:.*]] = arith.constant 100 : index
+ %c100 = arith.constant 100 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C100]]]
+ %res = memref.load %alloca[%c100] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
>From e5e879548ed420755b6105f771789e3715655f46 Mon Sep 17 00:00:00 2001
From: Théo Degioanni
<30992420+Moxinilian at users.noreply.github.com>
Date: Sun, 5 Nov 2023 21:22:39 +0100
Subject: [PATCH 2/2] Address review comments
---
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 12 +++---------
1 file changed, 3 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dd7c39a01d36bcf..be301c191d5139e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -189,22 +189,16 @@ DeletionKind memref::LoadOp::removeBlockingUses(
/// Returns the index of a memref in attribute form, given its indices. Returns
/// a null pointer if whether the indices form a valid index for the provided
-/// MemRefType cannot be computed.
+/// MemRefType cannot be computed. The indices must come from a valid memref
+/// StoreOp or LoadOp.
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
ValueRange indices,
MemRefType memrefType) {
- ArrayRef<int64_t> shape = memrefType.getShape();
- if (indices.size() != memrefType.getShape().size())
- return {};
SmallVector<Attribute> index;
- for (size_t i = 0; i < shape.size(); i++) {
- Value coord = indices[i];
- int64_t dimSize = shape[i];
+ for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
IntegerAttr coordAttr;
if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
return {};
- if (!coordAttr.getType().isIndex())
- return {};
// MemRefType shape dimensions are always positive (checked by verifier).
std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
More information about the Mlir-commits
mailing list