[Mlir-commits] [mlir] 569f073 - [mlir][LLVM] Add support for arrays in `SplitStores` pattern

Markus Böck llvmlistbot at llvm.org
Mon Jul 10 06:56:55 PDT 2023


Author: Markus Böck
Date: 2023-07-10T15:55:07+02:00
New Revision: 569f07319105500f0b19e6962370439b54a67e80

URL: https://github.com/llvm/llvm-project/commit/569f07319105500f0b19e6962370439b54a67e80
DIFF: https://github.com/llvm/llvm-project/commit/569f07319105500f0b19e6962370439b54a67e80.diff

LOG: [mlir][LLVM] Add support for arrays in `SplitStores` pattern

So far, the pattern only supported splitting stores into struct types and explicitly marked arrays as unsupported. As a result, stores into arrays were not made type-consistent the way stores into structs were, and were therefore also not properly split up by SROA and mem2reg.

This patch adds support for array types by creating a common abstraction over both structs and arrays, essentially treating an array of size n like a struct with n fields of the array's element type.
This gives us immediate feature parity without special-casing either of the two types.
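
To illustrate the effect, here is a sketch distilled from the new test case at the end of this patch (SSA value names are chosen for readability; the pass itself emits numbered values, and it is slightly simplified in that the pass also emits an lshr by 0 for the first element). A previously rejected type-inconsistent i64 store into an alloca of !llvm.array<2 x i32>:

    %c1 = llvm.mlir.constant(1 : i32) : i32
    %alloca = llvm.alloca %c1 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
    llvm.store %arg, %alloca : i64, !llvm.ptr

is now split into one type-consistent i32 store per array element:

    %c32 = llvm.mlir.constant(32 : i64) : i64
    %gep0 = llvm.getelementptr %alloca[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
    %lo = llvm.trunc %arg : i64 to i32
    llvm.store %lo, %gep0 : i32, !llvm.ptr
    %shifted = llvm.lshr %arg, %c32 : i64
    %hi = llvm.trunc %shifted : i64 to i32
    %gep1 = llvm.getelementptr %alloca[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
    llvm.store %hi, %gep1 : i32, !llvm.ptr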

Differential Revision: https://reviews.llvm.org/D154840

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 157f973ca9a32a..9731689e551762 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -348,20 +348,74 @@ CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
   return success();
 }
 
-/// Returns the list of fields of `structType` that are written to by a store
-/// operation writing `storeSize` bytes at `storeOffset` within the struct.
-/// `storeOffset` is required to cleanly point to an immediate field within
-/// the struct.
-/// If the write operation were to write to any padding, write beyond the
-/// struct, partially write to a field, or contains currently unsupported
-/// types, failure is returned.
-static FailureOr<ArrayRef<Type>>
-getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
+namespace {
+/// Class abstracting over both array and struct types, turning each into ranges
+/// of their sub-types.
+class DestructurableTypeRange
+    : public llvm::indexed_accessor_range<DestructurableTypeRange,
+                                          DestructurableTypeInterface, Type,
+                                          Type *, Type> {
+
+  using Base = llvm::indexed_accessor_range<
+      DestructurableTypeRange, DestructurableTypeInterface, Type, Type *, Type>;
+
+public:
+  using Base::Base;
+
+  /// Constructs a DestructurableTypeRange from either a LLVMStructType or
+  /// LLVMArrayType.
+  explicit DestructurableTypeRange(DestructurableTypeInterface base)
+      : Base(base, 0, [&]() -> ptrdiff_t {
+          return TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
+              .Case([](LLVMStructType structType) {
+                return structType.getBody().size();
+              })
+              .Case([](LLVMArrayType arrayType) {
+                return arrayType.getNumElements();
+              })
+              .Default([](auto) -> ptrdiff_t {
+                llvm_unreachable(
+                    "Only LLVMStructType or LLVMArrayType supported");
+              });
+        }()) {}
+
+  /// Returns true if this is a range over a packed struct.
+  bool isPacked() const {
+    if (auto structType = dyn_cast<LLVMStructType>(getBase()))
+      return structType.isPacked();
+    return false;
+  }
+
+private:
+  static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
+    // i32 chosen because the implementations of ArrayType and StructType
+    // specifically expect it to be 32 bit. They will fail otherwise.
+    Type result = base.getTypeAtIndex(
+        IntegerAttr::get(IntegerType::get(base.getContext(), 32), index));
+    assert(result && "Should always succeed");
+    return result;
+  }
+
+  friend Base;
+};
+} // namespace
+
+/// Returns the list of elements of `destructurableType` that are written to by
+/// a store operation writing `storeSize` bytes at `storeOffset`.
+/// `storeOffset` is required to cleanly point to an immediate element within
+/// the type. If the write operation were to write to any padding, write beyond
+/// the aggregate or partially write to a non-aggregate, failure is returned.
+static FailureOr<DestructurableTypeRange>
+getWrittenToFields(const DataLayout &dataLayout,
+                   DestructurableTypeInterface destructurableType,
                    unsigned storeSize, unsigned storeOffset) {
-  ArrayRef<Type> body = structType.getBody();
+  DestructurableTypeRange destructurableTypeRange(destructurableType);
+
   unsigned currentOffset = 0;
-  body = body.drop_until([&](Type type) {
-    if (!structType.isPacked()) {
+  for (; !destructurableTypeRange.empty();
+       destructurableTypeRange = destructurableTypeRange.drop_front()) {
+    Type type = destructurableTypeRange.front();
+    if (!destructurableTypeRange.isPacked()) {
       unsigned alignment = dataLayout.getTypeABIAlignment(type);
       currentOffset = llvm::alignTo(currentOffset, alignment);
     }
@@ -370,40 +424,43 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
     // 0 or stems from a type-consistent GEP indexing into just a single
     // aggregate.
     if (currentOffset == storeOffset)
-      return true;
+      break;
 
     assert(currentOffset < storeOffset &&
            "storeOffset should cleanly point into an immediate field");
 
     currentOffset += dataLayout.getTypeSize(type);
-    return false;
-  });
+  }
 
   size_t exclusiveEnd = 0;
-  for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
-    if (!structType.isPacked()) {
-      unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
+  for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
+       exclusiveEnd++) {
+    if (!destructurableTypeRange.isPacked()) {
+      unsigned alignment =
+          dataLayout.getTypeABIAlignment(destructurableTypeRange[exclusiveEnd]);
       // No padding allowed in between fields at this point in time.
       if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
         return failure();
     }
 
-    unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+    unsigned fieldSize =
+        dataLayout.getTypeSize(destructurableTypeRange[exclusiveEnd]);
     if (fieldSize > storeSize) {
       // Partial writes into an aggregate are okay since subsequent pattern
       // applications can further split these up into writes into the
       // sub-elements.
-      auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
-      if (!subStruct)
+      auto subAggregate = dyn_cast<DestructurableTypeInterface>(
+          destructurableTypeRange[exclusiveEnd]);
+      if (!subAggregate)
         return failure();
 
-      // Avoid splitting redundantly by making sure the store into the struct
-      // can actually be split.
-      if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
+      // Avoid splitting redundantly by making sure the store into the
+      // aggregate can actually be split.
+      if (failed(getWrittenToFields(dataLayout, subAggregate, storeSize,
                                     /*storeOffset=*/0)))
         return failure();
 
-      return body.take_front(exclusiveEnd + 1);
+      return destructurableTypeRange.take_front(exclusiveEnd + 1);
     }
     currentOffset += fieldSize;
     storeSize -= fieldSize;
@@ -413,7 +470,7 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
   // as a whole. Abort.
   if (storeSize > 0)
     return failure();
-  return body.take_front(exclusiveEnd);
+  return destructurableTypeRange.take_front(exclusiveEnd);
 }
 
 /// Splits a store of the vector `value` into `address` at `storeOffset` into
@@ -443,13 +500,13 @@ static void splitVectorStore(const DataLayout &dataLayout, Location loc,
 }
 
 /// Splits a store of the integer `value` into `address` at `storeOffset` into
-/// multiple stores to each 'writtenFields', making each store operation
+/// multiple stores to each 'writtenToFields', making each store operation
 /// type-consistent.
 static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
                               RewriterBase &rewriter, Value address,
                               Value value, unsigned storeSize,
                               unsigned storeOffset,
-                              ArrayRef<Type> writtenToFields) {
+                              DestructurableTypeRange writtenToFields) {
   unsigned currentOffset = storeOffset;
   for (Type type : writtenToFields) {
     unsigned fieldSize = dataLayout.getTypeSize(type);
@@ -527,26 +584,24 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
     }
   }
 
-  auto structType = typeHint.dyn_cast<LLVMStructType>();
-  if (!structType) {
-    // TODO: Handle array types in the future.
+  auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
+  if (!destructurableType)
     return failure();
-  }
 
-  FailureOr<ArrayRef<Type>> writtenToFields =
-      getWrittenToFields(dataLayout, structType, storeSize, offset);
-  if (failed(writtenToFields))
+  FailureOr<DestructurableTypeRange> writtenToElements =
+      getWrittenToFields(dataLayout, destructurableType, storeSize, offset);
+  if (failed(writtenToElements))
     return failure();
 
-  if (writtenToFields->size() <= 1) {
+  if (writtenToElements->size() <= 1) {
     // Other patterns should take care of this case, we are only interested in
-    // splitting field stores.
+    // splitting element stores.
     return failure();
   }
 
   if (isa<IntegerType>(sourceType)) {
     splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
-                      store.getValue(), storeSize, offset, *writtenToFields);
+                      store.getValue(), storeSize, offset, *writtenToElements);
     rewriter.eraseOp(store);
     return success();
   }

diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 8c08c3d89ef40f..1504a98e6f8cca 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -624,3 +624,28 @@ llvm.func @undesirable_overlapping_aggregate_store(%arg: i64) {
 
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_array
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_array(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x i32>
+  %1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.array<2 x i32>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}

More information about the Mlir-commits mailing list