[Mlir-commits] [mlir] [MLIR][LLVM][SROA] Make GEP handling type agnostic (PR #86950)
Tobias Gysi
llvmlistbot at llvm.org
Thu Mar 28 07:49:23 PDT 2024
================
@@ -431,29 +433,163 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
return DeletionKind::Delete;
}
-static bool isFirstIndexZero(LLVM::GEPOp gep) {
- IntegerAttr index =
- llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
- return index && index.getInt() == 0;
+/// Returns the number of bytes the provided GEP elements will offset the
+/// pointer by. Returns nullopt if no constant offset could be computed: when
+/// an index is dynamic or negative, or when an index steps into a type that
+/// is neither an LLVM array nor an LLVM struct.
+static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
+                                               LLVM::GEPOp gep) {
+  // Collect all indices, bailing out on dynamic or negative ones.
+  SmallVector<uint64_t> indices;
+  for (auto index : gep.getIndices()) {
+    auto constIndex = dyn_cast<IntegerAttr>(index);
+    if (!constIndex)
+      return std::nullopt;
+    int64_t gepIndex = constIndex.getInt();
+    // Negative indices are not supported.
+    if (gepIndex < 0)
+      return std::nullopt;
+    indices.push_back(gepIndex);
+  }
+
+  // Guard the `indices[0]` access below: a GEP without indices leaves the
+  // pointer unchanged.
+  if (indices.empty())
+    return 0;
+
+  // The first index steps over whole elements of the GEP's element type.
+  Type currentType = gep.getElemType();
+  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
+
+  // Every following index descends one level into an aggregate.
+  for (uint64_t index : llvm::drop_begin(indices)) {
+    bool shouldCancel =
+        TypeSwitch<Type, bool>(currentType)
+            .Case([&](LLVM::LLVMArrayType arrayType) {
+              offset +=
+                  index * dataLayout.getTypeSize(arrayType.getElementType());
+              currentType = arrayType.getElementType();
+              return false;
+            })
+            .Case([&](LLVM::LLVMStructType structType) {
+              ArrayRef<Type> body = structType.getBody();
+              assert(index < body.size() && "expected valid struct indexing");
+              // Compute the field offset relative to the start of the struct
+              // first. Padding is defined relative to the struct itself,
+              // which may sit at an unaligned offset (e.g. as a member of a
+              // packed struct). Aligning the absolute running offset would
+              // then place the field at the wrong byte.
+              uint64_t fieldOffset = 0;
+              for (uint64_t i : llvm::seq(index)) {
+                if (!structType.isPacked())
+                  fieldOffset = llvm::alignTo(
+                      fieldOffset, dataLayout.getTypeABIAlignment(body[i]));
+                fieldOffset += dataLayout.getTypeSize(body[i]);
+              }
+
+              // Align for the indexed field as well.
+              if (!structType.isPacked())
+                fieldOffset = llvm::alignTo(
+                    fieldOffset, dataLayout.getTypeABIAlignment(body[index]));
+              offset += fieldOffset;
+              currentType = body[index];
+              return false;
+            })
+            .Default([&](Type type) {
+              LLVM_DEBUG(llvm::dbgs()
+                         << "[sroa] Unsupported type for offset computations: "
+                         << type << "\n");
+              return true;
+            });
+
+    if (shouldCancel)
+      return std::nullopt;
+  }
+
+  return offset;
+}
+
+namespace {
+/// Helper that contains information about accesses into a subslot.
+/// Built by `getSubslotAccessInfo` with aggregate initialization in field
+/// order (`{index, subslotOffset}`), so the field order is significant.
+struct SubslotAccessInfo {
+ /// The parent slot's index that the access falls into: the array element
+ /// index or the struct field position.
+ uint32_t index;
+ /// The offset into the subslot of the access, in bytes relative to the
+ /// start of the element/field selected by `index`.
+ uint64_t subslotOffset;
+};
+} // namespace
+
+/// Computes subslot access information for an access into `slot` with the
+/// byte offset of `gep`.
+/// Returns nullopt when:
+///  - the offset cannot be computed,
+///  - the offset is out of bounds,
+///  - the access lands in the padding of `slot`,
+///  - the subslot index does not fit the constant GEP index representation,
+///  - `slot`'s element type is neither an LLVM array nor an LLVM struct.
+static std::optional<SubslotAccessInfo>
+getSubslotAccessInfo(const DestructurableMemorySlot &slot,
+                     const DataLayout &dataLayout, LLVM::GEPOp gep) {
+  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
+  if (!offset)
+    return std::nullopt;
+
+  // Helper to check that a constant index is within the bounds of the GEP
+  // index representation, which is limited to `kGEPConstantBitWidth` bits.
+  // A value equal to 2^kGEPConstantBitWidth is itself not representable,
+  // hence `>=`.
+  // NOTE(review): if the constant index is stored as a *signed*
+  // PointerEmbeddedInt, the non-negative range is only
+  // [0, 2^(kGEPConstantBitWidth - 1)) -- confirm and tighten if so.
+  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
+    return index >= (uint64_t{1} << LLVM::kGEPConstantBitWidth);
+  };
+
+  Type type = slot.elemType;
+  if (*offset >= dataLayout.getTypeSize(type))
+    return std::nullopt;
+  return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
+      .Case([&](LLVM::LLVMArrayType arrayType)
+                -> std::optional<SubslotAccessInfo> {
+        // Find which element of the array contains the offset.
+        uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
+        uint64_t index = *offset / elemSize;
+        if (isOutOfBoundsGEPIndex(index))
+          return std::nullopt;
+        return SubslotAccessInfo{static_cast<uint32_t>(index),
+                                 *offset - (index * elemSize)};
+      })
+      .Case([&](LLVM::LLVMStructType structType)
+                -> std::optional<SubslotAccessInfo> {
+        uint64_t distanceToStart = 0;
+        // Walk over the elements of the struct to find the one that
+        // contains the offset.
+        for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
+          uint64_t elemSize = dataLayout.getTypeSize(elem);
+          if (!structType.isPacked()) {
+            distanceToStart = llvm::alignTo(
+                distanceToStart, dataLayout.getTypeABIAlignment(elem));
+            // If the offset is in padding, cancel the rewrite.
+            if (*offset < distanceToStart)
+              return std::nullopt;
+          }
+
+          if (*offset < distanceToStart + elemSize) {
+            if (isOutOfBoundsGEPIndex(index))
+              return std::nullopt;
+            // The offset is within this element, stop iterating the
+            // struct and return the index.
+            return SubslotAccessInfo{static_cast<uint32_t>(index),
+                                     *offset - distanceToStart};
+          }
+
+          // The offset is not within this element, continue walking
+          // over the struct.
+          distanceToStart += elemSize;
+        }
+
+        // The offset lies past the last field, i.e. in trailing padding.
+        return std::nullopt;
+      })
+      .Default([](Type) -> std::optional<SubslotAccessInfo> {
+        // Other slot element types are not handled by this helper. Without
+        // this case the TypeSwitch would fall off its end.
+        return std::nullopt;
+      });
+}
+
+/// Constructs an LLVM array type of `size` elements of 8-bit integers, i.e. a
+/// byte array of the given size.
+/// `size` is taken as `uint64_t` because callers pass byte counts derived from
+/// `DataLayout::getTypeSize` arithmetic. A narrower parameter would silently
+/// truncate large slot sizes.
+static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
+                                            uint64_t size) {
+  auto byteType = IntegerType::get(context, 8);
+  return LLVM::LLVMArrayType::get(context, byteType, size);
}
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (getBase() != slot.ptr)
return success();
- if (slot.elemType != getElemType())
- return failure();
- if (!isFirstIndexZero(*this))
+ std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
+ if (!gepOffset)
return failure();
- // Dynamic indices can be out-of-bounds (even negative), so an access with
- // dynamic indices can never be considered safe.
- if (!getDynamicIndices().empty())
+ uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
+ // Check that the access is strictly inside the slot.
+ if (*gepOffset >= slotSize)
return failure();
- Type reachedType = getResultPtrElementType();
- if (!reachedType)
- return failure();
- mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ // Every access that remains in bounds of the remaining slot is considered
+ // legal.
+ mustBeSafelyUsed.emplace_back<MemorySlot>(
+ {getBase(), getByteArrayType(getContext(), slotSize - *gepOffset)});
----------------
gysit wrote:
Should this be getRes() as well? A test for a hierarchical case might also make sense.
https://github.com/llvm/llvm-project/pull/86950
More information about the Mlir-commits
mailing list