[Mlir-commits] [mlir] 569f073 - [mlir][LLVM] Add support for arrays in `SplitStores` pattern
Markus Böck
llvmlistbot at llvm.org
Mon Jul 10 06:56:55 PDT 2023
Author: Markus Böck
Date: 2023-07-10T15:55:07+02:00
New Revision: 569f07319105500f0b19e6962370439b54a67e80
URL: https://github.com/llvm/llvm-project/commit/569f07319105500f0b19e6962370439b54a67e80
DIFF: https://github.com/llvm/llvm-project/commit/569f07319105500f0b19e6962370439b54a67e80.diff
LOG: [mlir][LLVM] Add support for arrays in `SplitStores` pattern
So far, the pattern has only supported splitting stores into struct types and explicitly marked arrays as unsupported. As a result, stores into arrays were not made type-consistent the way stores into structs were, and were therefore also not properly split up by SROA and mem2reg.
This patch adds support for array types by introducing a common abstraction over structs and arrays, essentially treating an array of size n like a struct with n fields of the array's element type.
This gives us immediate feature parity without special-casing either of the two types.
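As a quick illustration (mirroring the new test case added at the bottom of the diff), a type-inconsistent i64 store into an !llvm.array<2 x i32> alloca such as

  llvm.func @coalesced_store_ints_array(%arg: i64) {
    %0 = llvm.mlir.constant(1 : i32) : i32
    %1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
    // Stores 8 bytes into storage typed as two i32 array elements.
    llvm.store %arg, %1 : i64, !llvm.ptr
    llvm.return
  }

is now split into two i32 stores, one per array element: each half of %arg is extracted with llvm.lshr + llvm.trunc and stored through an llvm.getelementptr into the corresponding element, just as the pattern already did for struct fields.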
Differential Revision: https://reviews.llvm.org/D154840
Added:
Modified:
mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
mlir/test/Dialect/LLVMIR/type-consistency.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 157f973ca9a32a..9731689e551762 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -348,20 +348,74 @@ CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
return success();
}
-/// Returns the list of fields of `structType` that are written to by a store
-/// operation writing `storeSize` bytes at `storeOffset` within the struct.
-/// `storeOffset` is required to cleanly point to an immediate field within
-/// the struct.
-/// If the write operation were to write to any padding, write beyond the
-/// struct, partially write to a field, or contains currently unsupported
-/// types, failure is returned.
-static FailureOr<ArrayRef<Type>>
-getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
+namespace {
+/// Class abstracting over both array and struct types, turning each into ranges
+/// of their sub-types.
+class DestructurableTypeRange
+ : public llvm::indexed_accessor_range<DestructurableTypeRange,
+ DestructurableTypeInterface, Type,
+ Type *, Type> {
+
+ using Base = llvm::indexed_accessor_range<
+ DestructurableTypeRange, DestructurableTypeInterface, Type, Type *, Type>;
+
+public:
+ using Base::Base;
+
+ /// Constructs a DestructurableTypeRange from either a LLVMStructType or
+ /// LLVMArrayType.
+ explicit DestructurableTypeRange(DestructurableTypeInterface base)
+ : Base(base, 0, [&]() -> ptrdiff_t {
+ return TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
+ .Case([](LLVMStructType structType) {
+ return structType.getBody().size();
+ })
+ .Case([](LLVMArrayType arrayType) {
+ return arrayType.getNumElements();
+ })
+ .Default([](auto) -> ptrdiff_t {
+ llvm_unreachable(
+ "Only LLVMStructType or LLVMArrayType supported");
+ });
+ }()) {}
+
+ /// Returns true if this is a range over a packed struct.
+ bool isPacked() const {
+ if (auto structType = dyn_cast<LLVMStructType>(getBase()))
+ return structType.isPacked();
+ return false;
+ }
+
+private:
+ static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
+ // i32 chosen because the implementations of ArrayType and StructType
+ // specifically expect it to be 32 bit. They will fail otherwise.
+ Type result = base.getTypeAtIndex(
+ IntegerAttr::get(IntegerType::get(base.getContext(), 32), index));
+ assert(result && "Should always succeed");
+ return result;
+ }
+
+ friend Base;
+};
+} // namespace
+
+/// Returns the list of elements of `destructurableType` that are written to by
+/// a store operation writing `storeSize` bytes at `storeOffset`.
+/// `storeOffset` is required to cleanly point to an immediate element within
+/// the type. If the write operation were to write to any padding, write beyond
+/// the aggregate or partially write to a non-aggregate, failure is returned.
+static FailureOr<DestructurableTypeRange>
+getWrittenToFields(const DataLayout &dataLayout,
+ DestructurableTypeInterface destructurableType,
unsigned storeSize, unsigned storeOffset) {
- ArrayRef<Type> body = structType.getBody();
+ DestructurableTypeRange destructurableTypeRange(destructurableType);
+
unsigned currentOffset = 0;
- body = body.drop_until([&](Type type) {
- if (!structType.isPacked()) {
+ for (; !destructurableTypeRange.empty();
+ destructurableTypeRange = destructurableTypeRange.drop_front()) {
+ Type type = destructurableTypeRange.front();
+ if (!destructurableTypeRange.isPacked()) {
unsigned alignment = dataLayout.getTypeABIAlignment(type);
currentOffset = llvm::alignTo(currentOffset, alignment);
}
@@ -370,40 +424,43 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
// 0 or stems from a type-consistent GEP indexing into just a single
// aggregate.
if (currentOffset == storeOffset)
- return true;
+ break;
assert(currentOffset < storeOffset &&
"storeOffset should cleanly point into an immediate field");
currentOffset += dataLayout.getTypeSize(type);
- return false;
- });
+ }
size_t exclusiveEnd = 0;
- for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
- if (!structType.isPacked()) {
- unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
+ for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
+ exclusiveEnd++) {
+ if (!destructurableTypeRange.isPacked()) {
+ unsigned alignment =
+ dataLayout.getTypeABIAlignment(destructurableTypeRange[exclusiveEnd]);
// No padding allowed inbetween fields at this point in time.
if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
return failure();
}
- unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+ unsigned fieldSize =
+ dataLayout.getTypeSize(destructurableTypeRange[exclusiveEnd]);
if (fieldSize > storeSize) {
// Partial writes into an aggregate are okay since subsequent pattern
// applications can further split these up into writes into the
// sub-elements.
- auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
- if (!subStruct)
+ auto subAggregate = dyn_cast<DestructurableTypeInterface>(
+ destructurableTypeRange[exclusiveEnd]);
+ if (!subAggregate)
return failure();
- // Avoid splitting redundantly by making sure the store into the struct
- // can actually be split.
- if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
+ // Avoid splitting redundantly by making sure the store into the
+ // aggregate can actually be split.
+ if (failed(getWrittenToFields(dataLayout, subAggregate, storeSize,
/*storeOffset=*/0)))
return failure();
- return body.take_front(exclusiveEnd + 1);
+ return destructurableTypeRange.take_front(exclusiveEnd + 1);
}
currentOffset += fieldSize;
storeSize -= fieldSize;
@@ -413,7 +470,7 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
// as a whole. Abort.
if (storeSize > 0)
return failure();
- return body.take_front(exclusiveEnd);
+ return destructurableTypeRange.take_front(exclusiveEnd);
}
/// Splits a store of the vector `value` into `address` at `storeOffset` into
@@ -443,13 +500,13 @@ static void splitVectorStore(const DataLayout &dataLayout, Location loc,
}
/// Splits a store of the integer `value` into `address` at `storeOffset` into
-/// multiple stores to each 'writtenFields', making each store operation
+/// multiple stores to each 'writtenToFields', making each store operation
/// type-consistent.
static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
RewriterBase &rewriter, Value address,
Value value, unsigned storeSize,
unsigned storeOffset,
- ArrayRef<Type> writtenToFields) {
+ DestructurableTypeRange writtenToFields) {
unsigned currentOffset = storeOffset;
for (Type type : writtenToFields) {
unsigned fieldSize = dataLayout.getTypeSize(type);
@@ -527,26 +584,24 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
}
}
- auto structType = typeHint.dyn_cast<LLVMStructType>();
- if (!structType) {
- // TODO: Handle array types in the future.
+ auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
+ if (!destructurableType)
return failure();
- }
- FailureOr<ArrayRef<Type>> writtenToFields =
- getWrittenToFields(dataLayout, structType, storeSize, offset);
- if (failed(writtenToFields))
+ FailureOr<DestructurableTypeRange> writtenToElements =
+ getWrittenToFields(dataLayout, destructurableType, storeSize, offset);
+ if (failed(writtenToElements))
return failure();
- if (writtenToFields->size() <= 1) {
+ if (writtenToElements->size() <= 1) {
// Other patterns should take care of this case, we are only interested in
- // splitting field stores.
+ // splitting element stores.
return failure();
}
if (isa<IntegerType>(sourceType)) {
splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
- store.getValue(), storeSize, offset, *writtenToFields);
+ store.getValue(), storeSize, offset, *writtenToElements);
rewriter.eraseOp(store);
return success();
}
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 8c08c3d89ef40f..1504a98e6f8cca 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -624,3 +624,28 @@ llvm.func @undesirable_overlapping_aggregate_store(%arg: i64) {
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_array
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_array(%arg: i64) {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x i32>
+ %1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
+
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
+ // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+ // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+ // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+ // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
+ // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ llvm.store %arg, %1 : i64, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}