[Mlir-commits] [mlir] 7786449 - [mlir][LLVM] Make `SplitStores` pattern capable of writing to sub-aggregates

Markus Böck llvmlistbot at llvm.org
Mon Jul 10 06:28:26 PDT 2023


Author: Markus Böck
Date: 2023-07-10T15:27:47+02:00
New Revision: 7786449334d8e6ccda1362fef7006bfab86333b7

URL: https://github.com/llvm/llvm-project/commit/7786449334d8e6ccda1362fef7006bfab86333b7
DIFF: https://github.com/llvm/llvm-project/commit/7786449334d8e6ccda1362fef7006bfab86333b7.diff

LOG: [mlir][LLVM] Make `SplitStores` pattern capable of writing to sub-aggregates

Previously, the pattern could only store into struct fields of primitive type. If the struct contained a nested struct, the pattern rewrite was aborted immediately.

This patch adds the ability to recursively split stores into sub-structs as well. This is achieved by extracting an aggregate-sized integer from the original store argument and letting repeated pattern applications split it further into per-field stores.

The pattern can now also handle partial writes into aggregates, which clang may generate as well. Special care was taken to ensure that the rewrite never creates stores to memory that the original code did not write.

Differential Revision: https://reviews.llvm.org/D154707

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index aad6e19f9f9c24..157f973ca9a32a 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -357,7 +357,7 @@ CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
 /// types, failure is returned.
 static FailureOr<ArrayRef<Type>>
 getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
-                   int storeSize, unsigned storeOffset) {
+                   unsigned storeSize, unsigned storeOffset) {
   ArrayRef<Type> body = structType.getBody();
   unsigned currentOffset = 0;
   body = body.drop_until([&](Type type) {
@@ -381,10 +381,6 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
 
   size_t exclusiveEnd = 0;
   for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
-    // Not yet recursively handling aggregates, only primitives.
-    if (!isa<IntegerType, FloatType>(body[exclusiveEnd]))
-      return failure();
-
     if (!structType.isPacked()) {
       unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
       // No padding allowed inbetween fields at this point in time.
@@ -393,13 +389,29 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
     }
 
     unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+    if (fieldSize > storeSize) {
+      // Partial writes into an aggregate are okay since subsequent pattern
+      // applications can further split these up into writes into the
+      // sub-elements.
+      auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
+      if (!subStruct)
+        return failure();
+
+      // Avoid splitting redundantly by making sure the store into the struct
+      // can actually be split.
+      if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
+                                    /*storeOffset=*/0)))
+        return failure();
+
+      return body.take_front(exclusiveEnd + 1);
+    }
     currentOffset += fieldSize;
     storeSize -= fieldSize;
   }
 
-  // If the storeSize is not 0 at this point we are either partially writing
-  // into a field or writing past the aggregate as a whole. Abort.
-  if (storeSize != 0)
+  // If the storeSize is not 0 at this point we are  writing past the aggregate
+  // as a whole. Abort.
+  if (storeSize > 0)
     return failure();
   return body.take_front(exclusiveEnd);
 }
@@ -435,7 +447,8 @@ static void splitVectorStore(const DataLayout &dataLayout, Location loc,
 /// type-consistent.
 static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
                               RewriterBase &rewriter, Value address,
-                              Value value, unsigned storeOffset,
+                              Value value, unsigned storeSize,
+                              unsigned storeOffset,
                               ArrayRef<Type> writtenToFields) {
   unsigned currentOffset = storeOffset;
   for (Type type : writtenToFields) {
@@ -449,7 +462,12 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
 
     auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
 
-    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+    // If we are doing a partial write into a direct field the remaining
+    // `storeSize` will be less than the size of the field. We have to truncate
+    // to the `storeSize` to avoid creating a store that wasn't in the original
+    // code.
+    IntegerType fieldIntType =
+        rewriter.getIntegerType(std::min(fieldSize, storeSize) * 8);
     Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
 
     // We create an `i8` indexed GEP here as that is the easiest (offset is
@@ -462,6 +480,7 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
     // No need to care about padding here since we already checked previously
     // that no padding exists in this range.
     currentOffset += fieldSize;
+    storeSize -= fieldSize;
   }
 }
 
@@ -481,28 +500,31 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
 
   auto dataLayout = DataLayout::closest(store);
 
+  unsigned storeSize = dataLayout.getTypeSize(sourceType);
   unsigned offset = 0;
   Value address = store.getAddr();
   if (auto gepOp = address.getDefiningOp<GEPOp>()) {
     // Currently only handle canonical GEPs with exactly two indices,
     // indexing a single aggregate deep.
-    // Recursing into sub-structs is left as a future exercise.
     // If the GEP is not canonical we have to fail, otherwise we would not
     // create type-consistent IR.
     if (gepOp.getIndices().size() != 2 ||
         succeeded(getRequiredConsistentGEPType(gepOp)))
       return failure();
 
-    // A GEP might point somewhere into the middle of an aggregate with the
-    // store storing into multiple adjacent elements. Destructure into
-    // the base address with an offset.
-    std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
-    if (!byteOffset)
-      return failure();
+    // If the size of the element indexed by the  GEP is smaller than the store
+    // size, it is pointing into the middle of an aggregate with the store
+    // storing into multiple adjacent elements. Destructure into the base
+    // address of the aggregate with a store offset.
+    if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
+      std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
+      if (!byteOffset)
+        return failure();
 
-    offset = *byteOffset;
-    typeHint = gepOp.getSourceElementType();
-    address = gepOp.getBase();
+      offset = *byteOffset;
+      typeHint = gepOp.getSourceElementType();
+      address = gepOp.getBase();
+    }
   }
 
   auto structType = typeHint.dyn_cast<LLVMStructType>();
@@ -512,9 +534,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
   }
 
   FailureOr<ArrayRef<Type>> writtenToFields =
-      getWrittenToFields(dataLayout, structType,
-                         /*storeSize=*/dataLayout.getTypeSize(sourceType),
-                         /*storeOffset=*/offset);
+      getWrittenToFields(dataLayout, structType, storeSize, offset);
   if (failed(writtenToFields))
     return failure();
 
@@ -526,7 +546,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
 
   if (isa<IntegerType>(sourceType)) {
     splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
-                      store.getValue(), offset, *writtenToFields);
+                      store.getValue(), storeSize, offset, *writtenToFields);
     rewriter.eraseOp(store);
     return success();
   }

diff  --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 7c345d454f264c..8c08c3d89ef40f 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -493,3 +493,134 @@ llvm.func @gep_result_ptr_type_dynamic(%arg: i64) {
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @overlapping_int_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @overlapping_int_aggregate_store(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+
+  // The i48 value is then split further into stores to the three i16 sub-struct fields:
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @overlapping_vector_aggregate_store
+// CHECK-SAME: %[[VEC:.*]]: vector<4xi16>
+llvm.func @overlapping_vector_aggregate_store(%arg: vector<4 x i16>) {
+  // CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[IDX2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[IDX3:.*]] = llvm.mlir.constant(3 : i32) : i32
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[BASE:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[HEAD_PTR:.*]] = llvm.getelementptr %[[BASE]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[LANE:.*]] = llvm.extractelement %[[VEC]][%[[IDX0]] : i32]
+  // CHECK: llvm.store %[[LANE]], %[[HEAD_PTR]]
+
+  // CHECK: %[[LANE:.*]] = llvm.extractelement %[[VEC]][%[[IDX1]] : i32]
+  // CHECK: %[[OUTER:.*]] = llvm.getelementptr %[[BASE]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[INNER:.*]] = llvm.getelementptr %[[OUTER]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[LANE]], %[[INNER]]
+
+  // CHECK: %[[LANE:.*]] = llvm.extractelement %[[VEC]][%[[IDX2]] : i32]
+  // CHECK: %[[OUTER:.*]] = llvm.getelementptr %[[BASE]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[INNER:.*]] = llvm.getelementptr %[[OUTER]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[LANE]], %[[INNER]]
+
+  // CHECK: %[[LANE:.*]] = llvm.extractelement %[[VEC]][%[[IDX3]] : i32]
+  // CHECK: %[[OUTER:.*]] = llvm.getelementptr %[[BASE]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[INNER:.*]] = llvm.getelementptr %[[OUTER]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[LANE]], %[[INNER]]
+
+  llvm.store %arg, %1 : vector<4 x i16>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[VEC]], %[[BASE]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @partially_overlapping_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+
+  // The i48 value is then split further into stores to the first three i16 sub-struct fields:
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+
+  // It is important that there are no more stores at this point.
+  // Specifically a store into the fourth field of %[[TOP_GEP]] would
+  // incorrectly change the semantics of the code.
+  // CHECK-NOT: llvm.store %{{.*}}, %{{.*}}
+
+  llvm.store %arg, %1 : i64, !llvm.ptr
+
+  llvm.return
+}
+
+// -----
+
+// No split is expected: after field 1 (i32), the remaining 4 bytes of the i64 store would only partially cover the nested struct's leading i64 field.
+
+// CHECK-LABEL: llvm.func @undesirable_overlapping_aggregate_store
+// CHECK-SAME: %[[VAL:.*]]: i64
+llvm.func @undesirable_overlapping_aggregate_store(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[BASE:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)> : (i32) -> !llvm.ptr
+  // CHECK: %[[FIELD1_PTR:.*]] = llvm.getelementptr %[[BASE]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  // CHECK: llvm.store %[[VAL]], %[[FIELD1_PTR]]
+  llvm.store %arg, %2 : i64, !llvm.ptr
+
+  llvm.return
+}


        


More information about the Mlir-commits mailing list