[Mlir-commits] [mlir] 67c351f - [mlir][LLVM] Make `SplitIntegerStores` capable of splitting vectors as well

Markus Böck llvmlistbot at llvm.org
Thu Jul 6 06:54:54 PDT 2023


Author: Markus Böck
Date: 2023-07-06T15:39:41+02:00
New Revision: 67c351f648e09256827b1e826e2bf80083279049

URL: https://github.com/llvm/llvm-project/commit/67c351f648e09256827b1e826e2bf80083279049
DIFF: https://github.com/llvm/llvm-project/commit/67c351f648e09256827b1e826e2bf80083279049.diff

LOG: [mlir][LLVM] Make `SplitIntegerStores` capable of splitting vectors as well

The original plan was to turn this into its own pattern, but one of the difficulties was determining when splitting the vector is required.
`SplitIntegerStores` essentially already did that by checking for field overlap.
Therefore, it was renamed to `SplitStores` and extended to split stores of values of both vector and integer type.

The vector splitting is done in a straightforward manner: `extractelement` extracts each vector element, and a separate store is generated for it. Subsequent pattern applications are responsible for further cleaning up the output and making it type-consistent.
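
As a rough sketch (hypothetical input and output, not taken from the tests in this commit), a store such as

    llvm.store %vec, %ptr : vector<2xi32>, !llvm.ptr

becomes roughly the following, with the `i8`-indexed GEPs left for other patterns to turn into type-consistent GEPs:

    %c0 = llvm.mlir.constant(0 : i32) : i32
    %e0 = llvm.extractelement %vec[%c0 : i32] : vector<2xi32>
    %g0 = llvm.getelementptr %ptr[0] : (!llvm.ptr) -> !llvm.ptr, i8
    llvm.store %e0, %g0 : i32, !llvm.ptr
    %c1 = llvm.mlir.constant(1 : i32) : i32
    %e1 = llvm.extractelement %vec[%c1 : i32] : vector<2xi32>
    %g1 = llvm.getelementptr %ptr[4] : (!llvm.ptr) -> !llvm.ptr, i8
    llvm.store %e1, %g1 : i32, !llvm.ptr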

In the worst case, if the code cannot be transformed into a type-consistent form (e.g. because it explicitly performs partial writes to vector elements), we might perform a needless vector split.
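
To limit the amount of code this can generate, the pass also gains a `max-vector-split-size` option (512 bits by default); stores of vectors larger than this are not split. A minimal invocation sketch, assuming the pass is registered under the `llvm-type-consistency` command-line name:

    mlir-opt --llvm-type-consistency='max-vector-split-size=128' input.mlir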

Differential Revision: https://reviews.llvm.org/D154583

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
    mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index e967cea6ae5186..b7dfc8656fc1fd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -39,6 +39,13 @@ def LLVMTypeConsistency
     their associated pointee type as consistently as possible.
   }];
   let constructor = "::mlir::LLVM::createTypeConsistencyPass()";
+
+  let options = [
+    Option<"maxVectorSplitSize", "max-vector-split-size", "unsigned",
+           /*default=*/"512",
+           "Maximum size in bits of a vector value in a load or store operation"
+           " operating on multiple elements that should still be split">,
+  ];
 }
 
 def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index 7da8b7fe29f180..a8eebcd3e93405 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -53,13 +53,16 @@ class CanonicalizeAlignedGep : public OpRewritePattern<GEPOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Splits stores of integers which write into multiple adjacent stores
-/// of a pointer. The integer is then split and stores are generated for
-/// every field being stored in a type-consistent manner.
-/// This is currently done on a best-effort basis.
-class SplitIntegerStores : public OpRewritePattern<StoreOp> {
+/// Splits stores that write into multiple adjacent elements of an aggregate
+/// through a pointer. Currently, integer and vector values are split, and
+/// stores are generated for every element stored to in a type-consistent manner.
+/// This is done on a best-effort basis.
+class SplitStores : public OpRewritePattern<StoreOp> {
+  unsigned maxVectorSplitSize;
+
 public:
-  using OpRewritePattern::OpRewritePattern;
+  SplitStores(MLIRContext *context, unsigned maxVectorSplitSize)
+      : OpRewritePattern(context), maxVectorSplitSize(maxVectorSplitSize) {}
 
   LogicalResult matchAndRewrite(StoreOp store,
                                 PatternRewriter &rewrite) const override;

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 02eddb47d9dfd5..0a760c7ee48117 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -411,12 +411,78 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
   return body.take_front(exclusiveEnd);
 }
 
-LogicalResult
-SplitIntegerStores::matchAndRewrite(StoreOp store,
-                                    PatternRewriter &rewriter) const {
-  IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType());
-  if (!sourceType) {
-    // We currently only support integer sources.
+/// Splits a store of the vector `value` into `address` at `storeOffset` into
+/// one store per element, with the goal of each generated store becoming
+/// type-consistent through subsequent pattern applications.
+static void splitVectorStore(const DataLayout &dataLayout, Location loc,
+                             RewriterBase &rewriter, Value address,
+                             TypedValue<VectorType> value,
+                             unsigned storeOffset) {
+  VectorType vectorType = value.getType();
+  unsigned elementSize = dataLayout.getTypeSize(vectorType.getElementType());
+
+  // Extract every element of the vector and store it at its computed address.
+  for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
+    auto pos =
+        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(index));
+    auto extractOp = rewriter.create<ExtractElementOp>(loc, value, pos);
+
+    // For convenience, we do indexing by calculating the final byte offset.
+    // Other patterns will turn this into a type-consistent GEP.
+    auto gepOp = rewriter.create<GEPOp>(
+        loc, address.getType(), rewriter.getI8Type(), address,
+        ArrayRef<GEPArg>{storeOffset + index * elementSize});
+
+    rewriter.create<StoreOp>(loc, extractOp, gepOp);
+  }
+}
+
+/// Splits a store of the integer `value` into `address` at `storeOffset` into
+/// multiple stores, one per field in `writtenToFields`, making each store
+/// operation type-consistent.
+static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
+                              RewriterBase &rewriter, Value address,
+                              Value value, unsigned storeOffset,
+                              ArrayRef<Type> writtenToFields) {
+  unsigned currentOffset = storeOffset;
+  for (Type type : writtenToFields) {
+    unsigned fieldSize = dataLayout.getTypeSize(type);
+
+    // Extract the data out of the integer by first shifting right and then
+    // truncating it.
+    auto pos = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(value.getType(),
+                                     (currentOffset - storeOffset) * 8));
+
+    auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
+
+    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+    Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
+    if (fieldIntType != type) {
+      // Bitcast to the right type. `fieldIntType` was explicitly created
+      // to be of the same size as `type` and must currently be a primitive as
+      // well.
+      valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
+    }
+
+    // We create an `i8` indexed GEP here as that is the easiest (offset is
+    // already known). Other patterns turn this into a type-consistent GEP.
+    auto gepOp =
+        rewriter.create<GEPOp>(loc, address.getType(), rewriter.getI8Type(),
+                               address, ArrayRef<GEPArg>{currentOffset});
+    rewriter.create<StoreOp>(loc, valueToStore, gepOp);
+
+    // No need to care about padding here since we already checked previously
+    // that no padding exists in this range.
+    currentOffset += fieldSize;
+  }
+}
+
+LogicalResult SplitStores::matchAndRewrite(StoreOp store,
+                                           PatternRewriter &rewriter) const {
+  Type sourceType = store.getValue().getType();
+  if (!isa<IntegerType, VectorType>(sourceType)) {
+    // We currently only support integer and vector sources.
     return failure();
   }
 
@@ -465,43 +531,30 @@ SplitIntegerStores::matchAndRewrite(StoreOp store,
   if (failed(writtenToFields))
     return failure();
 
-  unsigned currentOffset = offset;
-  for (Type type : *writtenToFields) {
-    unsigned fieldSize = dataLayout.getTypeSize(type);
-
-    // Extract the data out of the integer by first shifting right and then
-    // truncating it.
-    auto pos = rewriter.create<ConstantOp>(
-        store.getLoc(),
-        rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8));
-
-    auto shrOp = rewriter.create<LShrOp>(store.getLoc(), store.getValue(), pos);
-
-    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
-    Value valueToStore =
-        rewriter.create<TruncOp>(store.getLoc(), fieldIntType, shrOp);
-    if (fieldIntType != type) {
-      // Bitcast to the right type. `fieldIntType` was explicitly created
-      // to be of the same size as `type` and must currently be a primitive as
-      // well.
-      valueToStore =
-          rewriter.create<BitcastOp>(store.getLoc(), type, valueToStore);
-    }
-
-    // We create an `i8` indexed GEP here as that is the easiest (offset is
-    // already known). Other patterns turn this into a type-consistent GEP.
-    auto gepOp = rewriter.create<GEPOp>(store.getLoc(), address.getType(),
-                                        rewriter.getI8Type(), address,
-                                        ArrayRef<GEPArg>{currentOffset});
-    rewriter.create<StoreOp>(store.getLoc(), valueToStore, gepOp);
+  if (writtenToFields->size() <= 1) {
+    // Other patterns should take care of this case; we are only interested in
+    // splitting field stores.
+    return failure();
+  }
 
-    // No need to care about padding here since we already checked previously
-    // that no padding exists in this range.
-    currentOffset += fieldSize;
+  if (isa<IntegerType>(sourceType)) {
+    splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
+                      store.getValue(), offset, *writtenToFields);
+    rewriter.eraseOp(store);
+    return success();
   }
 
-  rewriter.eraseOp(store);
+  // Add a reasonable bound to avoid splitting very large vectors that would
+  // end up generating lots of code.
+  if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
+    return failure();
 
+  // Vector types are simply split into their elements, and new stores are
+  // generated with those. Subsequent pattern applications will split these
+  // stores further if required.
+  splitVectorStore(dataLayout, store.getLoc(), rewriter, address,
+                   cast<TypedValue<VectorType>>(store.getValue()), offset);
+  rewriter.eraseOp(store);
   return success();
 }
 
@@ -518,7 +571,7 @@ struct LLVMTypeConsistencyPass
     rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
         &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
-    rewritePatterns.add<SplitIntegerStores>(&getContext());
+    rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))

diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 5cef22aacd0e39..ba477a51812a51 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -300,3 +300,121 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_write_split
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_write_split(%arg: vector<4xi32>) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_write_split_offset
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_write_split_offset(%arg: vector<4xi32>) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+  llvm.store %arg, %2 : vector<4xi32>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// Small test that a split vector store will be further optimized (e.g. by then
+// splitting the resulting integer stores to structs, as shown here).
+
+// CHECK-LABEL: llvm.func @vector_write_split_struct
+// CHECK-SAME: %[[ARG:.*]]: vector<2xi64>
+llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+
+  // CHECK-COUNT-4: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr
+
+  llvm.store %arg, %1 : vector<2xi64>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @type_consistent_vector_store
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
+  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
