[Mlir-commits] [mlir] [MLIR][LLVM][SROA] Make GEP handling type agnostic (PR #86950)

Tobias Gysi llvmlistbot at llvm.org
Thu Mar 28 07:49:23 PDT 2024


================
@@ -431,29 +433,163 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
   return DeletionKind::Delete;
 }
 
-static bool isFirstIndexZero(LLVM::GEPOp gep) {
-  IntegerAttr index =
-      llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
-  return index && index.getInt() == 0;
+/// Computes the constant byte offset that `gep` applies to its base pointer.
+///
+/// The first index is scaled by the size of the GEP element type. Every
+/// further index steps into the current aggregate: an array element, or a
+/// struct field (honoring ABI alignment unless the struct is packed).
+///
+/// Returns nullopt if the GEP has no indices, if any index is not a constant
+/// integer attribute or is negative, or if an index steps into a type that is
+/// neither an LLVM array nor an LLVM struct.
+static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
+                                               LLVM::GEPOp gep) {
+  // Collect all indices, bailing out on the first unsupported one.
+  SmallVector<uint64_t> indices;
+  for (auto index : gep.getIndices()) {
+    auto constIndex = dyn_cast<IntegerAttr>(index);
+    // Only constant indices can be folded into a constant offset.
+    if (!constIndex)
+      return std::nullopt;
+    int64_t gepIndex = constIndex.getInt();
+    // Negative indices are not supported.
+    if (gepIndex < 0)
+      return std::nullopt;
+    indices.push_back(gepIndex);
+  }
+
+  // Without a leading index there is nothing to scale; bail out rather than
+  // reading the first element of an empty vector.
+  if (indices.empty())
+    return std::nullopt;
+
+  Type currentType = gep.getElemType();
+  uint64_t offset = indices.front() * dataLayout.getTypeSize(currentType);
+
+  for (uint64_t index : llvm::drop_begin(indices)) {
+    // Each case updates `offset` and `currentType` and reports whether the
+    // computation has to be abandoned.
+    bool shouldCancel =
+        TypeSwitch<Type, bool>(currentType)
+            .Case([&](LLVM::LLVMArrayType arrayType) {
+              offset +=
+                  index * dataLayout.getTypeSize(arrayType.getElementType());
+              currentType = arrayType.getElementType();
+              return false;
+            })
+            .Case([&](LLVM::LLVMStructType structType) {
+              ArrayRef<Type> body = structType.getBody();
+              assert(index < body.size() && "expected valid struct indexing");
+              // Skip over every field that precedes the accessed one.
+              for (uint64_t i : llvm::seq(index)) {
+                if (!structType.isPacked())
+                  offset = llvm::alignTo(
+                      offset, dataLayout.getTypeABIAlignment(body[i]));
+                offset += dataLayout.getTypeSize(body[i]);
+              }
+
+              // Align for the current type as well.
+              if (!structType.isPacked())
+                offset = llvm::alignTo(
+                    offset, dataLayout.getTypeABIAlignment(body[index]));
+              currentType = body[index];
+              return false;
+            })
+            .Default([&](Type type) {
+              LLVM_DEBUG(llvm::dbgs()
+                         << "[sroa] Unsupported type for offset computations: "
+                         << type << "\n");
+              return true;
+            });
+
+    if (shouldCancel)
+      return std::nullopt;
+  }
+
+  return offset;
+}
+
+namespace {
+/// Result of mapping a byte offset into a slot onto one of its subslots.
+struct SubslotAccessInfo {
+  /// Index of the parent slot's element (subslot) the access falls into.
+  uint32_t index;
+  /// Byte offset of the access relative to the start of that subslot.
+  uint64_t subslotOffset;
+};
+} // namespace
+
+/// Computes subslot access information for the access that `gep` performs
+/// into `slot`.
+/// Returns nullopt when the GEP offset is not a supported constant, when the
+/// offset is out-of-bounds, when the access is into the padding of `slot`,
+/// when the subslot index does not fit the constant GEP index representation,
+/// or when the slot's element type is neither an LLVM array nor an LLVM
+/// struct.
+static std::optional<SubslotAccessInfo>
+getSubslotAccessInfo(const DestructurableMemorySlot &slot,
+                     const DataLayout &dataLayout, LLVM::GEPOp gep) {
+  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
+  if (!offset)
+    return std::nullopt;
+
+  // Helper to check whether a constant index exceeds the bounds of the GEP
+  // index representation. An index equal to 2^kGEPConstantBitWidth already
+  // needs one bit more than that width, so the comparison is inclusive.
+  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
+    return index >= (uint64_t{1} << LLVM::kGEPConstantBitWidth);
+  };
+
+  Type type = slot.elemType;
+  // Accesses starting at or past the end of the slot are rejected.
+  if (*offset >= dataLayout.getTypeSize(type))
+    return std::nullopt;
+  return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
+      .Case([&](LLVM::LLVMArrayType arrayType)
+                -> std::optional<SubslotAccessInfo> {
+        // Find which element of the array contains the offset.
+        uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
+        uint64_t index = *offset / elemSize;
+        if (isOutOfBoundsGEPIndex(index))
+          return std::nullopt;
+        return SubslotAccessInfo{static_cast<uint32_t>(index),
+                                 *offset - (index * elemSize)};
+      })
+      .Case([&](LLVM::LLVMStructType structType)
+                -> std::optional<SubslotAccessInfo> {
+        uint64_t distanceToStart = 0;
+        // Walk over the elements of the struct to find in which of
+        // them the offset is.
+        for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
+          uint64_t elemSize = dataLayout.getTypeSize(elem);
+          if (!structType.isPacked()) {
+            distanceToStart = llvm::alignTo(
+                distanceToStart, dataLayout.getTypeABIAlignment(elem));
+            // If the offset is in padding, cancel the rewrite.
+            if (*offset < distanceToStart)
+              return std::nullopt;
+          }
+
+          if (*offset < distanceToStart + elemSize) {
+            if (isOutOfBoundsGEPIndex(index))
+              return std::nullopt;
+            // The offset is within this element, stop iterating the
+            // struct and return the index.
+            return SubslotAccessInfo{static_cast<uint32_t>(index),
+                                     *offset - distanceToStart};
+          }
+
+          // The offset is not within this element, continue walking
+          // over the struct.
+          distanceToStart += elemSize;
+        }
+
+        // No element contains the offset.
+        return std::nullopt;
+      })
+      .Default([](Type) -> std::optional<SubslotAccessInfo> {
+        // Only arrays and structs are mapped onto subslots here; report any
+        // other slot type as unsupported instead of falling off the switch.
+        return std::nullopt;
+      });
+}
+
+/// Constructs an array type of `size` 8-bit integer elements, i.e. a byte
+/// array of the given size. `size` is 64 bits wide so that byte counts derived
+/// from `uint64_t` type sizes are passed through without truncation.
+static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
+                                            uint64_t size) {
+  auto byteType = IntegerType::get(context, 8);
+  return LLVM::LLVMArrayType::get(context, byteType, size);
 }
 
 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
     const DataLayout &dataLayout) {
   if (getBase() != slot.ptr)
     return success();
-  if (slot.elemType != getElemType())
-    return failure();
-  if (!isFirstIndexZero(*this))
+  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
+  if (!gepOffset)
     return failure();
-  // Dynamic indices can be out-of-bounds (even negative), so an access with
-  // dynamic indices can never be considered safe.
-  if (!getDynamicIndices().empty())
+  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
+  // Check that the access is strictly inside the slot.
+  if (*gepOffset >= slotSize)
     return failure();
-  Type reachedType = getResultPtrElementType();
-  if (!reachedType)
-    return failure();
-  mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+  // Every access that remains in bounds of the remaining slot is considered
+  // legal.
+  mustBeSafelyUsed.emplace_back<MemorySlot>(
+      {getBase(), getByteArrayType(getContext(), slotSize - *gepOffset)});
----------------
gysit wrote:

Should this be `getRes()` as well, rather than `getBase()`? A test for a hierarchical case might also make sense.

https://github.com/llvm/llvm-project/pull/86950


More information about the Mlir-commits mailing list