[Mlir-commits] [mlir] [mlir][memref] Fix segfault in SROA. (PR #71063)

Théo Degioanni llvmlistbot at llvm.org
Sun Nov 5 12:22:53 PST 2023


https://github.com/Moxinilian updated https://github.com/llvm/llvm-project/pull/71063

>From a6c6a8be3df970d0a1d4fa615979348c41c7d553 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?=
 <30992420+Moxinilian at users.noreply.github.com>
Date: Thu, 2 Nov 2023 15:08:57 +0100
Subject: [PATCH 1/2] [mlir][memref] Fix segfault in SROA.

---
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    | 36 +++++++++++++------
 mlir/test/Dialect/MemRef/sroa.mlir            | 25 +++++++++++--
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 93ec2bcdf58fa4a..dd7c39a01d36bcf 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -187,14 +187,28 @@ DeletionKind memref::LoadOp::removeBlockingUses(
   return DeletionKind::Delete;
 }
 
-/// Returns the index of a memref in attribute form, given its indices.
+/// Returns the index of a memref in attribute form, given its indices. Returns
+/// a null pointer if whether the indices form a valid index for the provided
+/// MemRefType cannot be computed.
 static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
-                                                    ValueRange indices) {
+                                                    ValueRange indices,
+                                                    MemRefType memrefType) {
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  if (indices.size() != memrefType.getShape().size())
+    return {};
   SmallVector<Attribute> index;
-  for (Value coord : indices) {
+  for (size_t i = 0; i < shape.size(); i++) {
+    Value coord = indices[i];
+    int64_t dimSize = shape[i];
     IntegerAttr coordAttr;
     if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
       return {};
+    if (!coordAttr.getType().isIndex())
+      return {};
+    // MemRefType shape dimensions are always positive (checked by verifier).
+    std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
+    if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
+      return {};
     index.push_back(coordAttr);
   }
   return ArrayAttr::get(ctx, index);
@@ -205,8 +219,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
   if (slot.ptr != getMemRef())
     return false;
-  Attribute index =
-      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Attribute index = getAttributeIndexFromIndexOperands(
+      getContext(), getIndices(), getMemRefType());
   if (!index)
     return false;
   usedIndices.insert(index);
@@ -216,8 +230,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     RewriterBase &rewriter) {
-  Attribute index =
-      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Attribute index = getAttributeIndexFromIndexOperands(
+      getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
   rewriter.updateRootInPlace(*this, [&]() {
     setMemRef(memorySlot.ptr);
@@ -258,8 +272,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
                                 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
   if (slot.ptr != getMemRef() || getValue() == slot.ptr)
     return false;
-  Attribute index =
-      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Attribute index = getAttributeIndexFromIndexOperands(
+      getContext(), getIndices(), getMemRefType());
   if (!index || !slot.elementPtrs.contains(index))
     return false;
   usedIndices.insert(index);
@@ -269,8 +283,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                      DenseMap<Attribute, MemorySlot> &subslots,
                                      RewriterBase &rewriter) {
-  Attribute index =
-      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Attribute index = getAttributeIndexFromIndexOperands(
+      getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
   rewriter.updateRootInPlace(*this, [&]() {
     setMemRef(memorySlot.ptr);
diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
index d78053d8ea777e7..40ab9b3483b833a 100644
--- a/mlir/test/Dialect/MemRef/sroa.mlir
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -132,9 +132,9 @@ func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
 
 // -----
 
-// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-LABEL: func.func @no_out_of_bound_write
 // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
-func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+func.func @no_out_of_bound_write(%arg0: i32, %arg1: i32) -> i32 {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK: %[[C100:.*]] = arith.constant 100 : index
@@ -152,3 +152,24 @@ func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
   // CHECK: return %[[RES]] : i32
   return %res : i32
 }
+
+// -----
+
+// CHECK-LABEL: func.func @no_out_of_bound_load
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bound_load(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C100:.*]] = arith.constant 100 : index
+  %c100 = arith.constant 100 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C100]]]
+  %res = memref.load %alloca[%c100] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}

>From e5e879548ed420755b6105f771789e3715655f46 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?=
 <30992420+Moxinilian at users.noreply.github.com>
Date: Sun, 5 Nov 2023 21:22:39 +0100
Subject: [PATCH 2/2] Address review comments

---
 mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dd7c39a01d36bcf..be301c191d5139e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -189,22 +189,16 @@ DeletionKind memref::LoadOp::removeBlockingUses(
 
 /// Returns the index of a memref in attribute form, given its indices. Returns
 /// a null pointer if whether the indices form a valid index for the provided
-/// MemRefType cannot be computed.
+/// MemRefType cannot be computed. The indices must come from a valid memref
+/// StoreOp or LoadOp.
 static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
                                                     ValueRange indices,
                                                     MemRefType memrefType) {
-  ArrayRef<int64_t> shape = memrefType.getShape();
-  if (indices.size() != memrefType.getShape().size())
-    return {};
   SmallVector<Attribute> index;
-  for (size_t i = 0; i < shape.size(); i++) {
-    Value coord = indices[i];
-    int64_t dimSize = shape[i];
+  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
     IntegerAttr coordAttr;
     if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
       return {};
-    if (!coordAttr.getType().isIndex())
-      return {};
     // MemRefType shape dimensions are always positive (checked by verifier).
     std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
     if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))



More information about the Mlir-commits mailing list