[Mlir-commits] [mlir] 67c351f - [mlir][LLVM] Make `SplitIntegerStores` capable of splitting vectors as well
Markus Böck
llvmlistbot at llvm.org
Thu Jul 6 06:54:54 PDT 2023
Author: Markus Böck
Date: 2023-07-06T15:39:41+02:00
New Revision: 67c351f648e09256827b1e826e2bf80083279049
URL: https://github.com/llvm/llvm-project/commit/67c351f648e09256827b1e826e2bf80083279049
DIFF: https://github.com/llvm/llvm-project/commit/67c351f648e09256827b1e826e2bf80083279049.diff
LOG: [mlir][LLVM] Make `SplitIntegerStores` capable of splitting vectors as well
The original plan was to turn this into its own pattern, but one of the difficulties was determining when splitting the vector is required.
`SplitIntegerStores` essentially already did that by checking for field overlap.
Therefore, it was renamed to `SplitStores` and extended to splitting stores with values of vector and integer type.
The vector splitting is done in a simple manner by using `extractelement` to get each vector element. Subsequent pattern applications are responsible for further cleaning up the output and making it type-consistent.
Worst case, if the code cannot be transformed into a type-consistent form (due to e.g. the code explicitly doing partial writes to elements or similar), we might needlessly do a vector split.
Differential Revision: https://reviews.llvm.org/D154583
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
mlir/test/Dialect/LLVMIR/type-consistency.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index e967cea6ae5186..b7dfc8656fc1fd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -39,6 +39,13 @@ def LLVMTypeConsistency
their associated pointee type as consistently as possible.
}];
let constructor = "::mlir::LLVM::createTypeConsistencyPass()";
+
+ let options = [
+ Option<"maxVectorSplitSize", "max-vector-split-size", "unsigned",
+ /*default=*/"512",
+ "Maximum size in bits of a vector value in a load or store operation"
+ " operating on multiple elements that should still be split">,
+ ];
}
def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index 7da8b7fe29f180..a8eebcd3e93405 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -53,13 +53,16 @@ class CanonicalizeAlignedGep : public OpRewritePattern<GEPOp> {
PatternRewriter &rewriter) const override;
};
-/// Splits stores of integers which write into multiple adjacent stores
-/// of a pointer. The integer is then split and stores are generated for
-/// every field being stored in a type-consistent manner.
-/// This is currently done on a best-effort basis.
-class SplitIntegerStores : public OpRewritePattern<StoreOp> {
+/// Splits stores which write into multiple adjacent elements of an aggregate
+/// through a pointer. Currently, integers and vector are split and stores
+/// are generated for every element being stored to in a type-consistent manner.
+/// This is done on a best-effort basis.
+class SplitStores : public OpRewritePattern<StoreOp> {
+ unsigned maxVectorSplitSize;
+
public:
- using OpRewritePattern::OpRewritePattern;
+ SplitStores(MLIRContext *context, unsigned maxVectorSplitSize)
+ : OpRewritePattern(context), maxVectorSplitSize(maxVectorSplitSize) {}
LogicalResult matchAndRewrite(StoreOp store,
PatternRewriter &rewrite) const override;
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 02eddb47d9dfd5..0a760c7ee48117 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -411,12 +411,78 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
return body.take_front(exclusiveEnd);
}
-LogicalResult
-SplitIntegerStores::matchAndRewrite(StoreOp store,
- PatternRewriter &rewriter) const {
- IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType());
- if (!sourceType) {
- // We currently only support integer sources.
+/// Splits a store of the vector `value` into `address` at `storeOffset` into
+/// multiple stores of each element with the goal of each generated store
+/// becoming type-consistent through subsequent pattern applications.
+static void splitVectorStore(const DataLayout &dataLayout, Location loc,
+ RewriterBase &rewriter, Value address,
+ TypedValue<VectorType> value,
+ unsigned storeOffset) {
+ VectorType vectorType = value.getType();
+ unsigned elementSize = dataLayout.getTypeSize(vectorType.getElementType());
+
+ // Extract every element in the vector and store it in the given address.
+ for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
+ auto pos =
+ rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(index));
+ auto extractOp = rewriter.create<ExtractElementOp>(loc, value, pos);
+
+ // For convenience, we do indexing by calculating the final byte offset.
+ // Other patterns will turn this into a type-consistent GEP.
+ auto gepOp = rewriter.create<GEPOp>(
+ loc, address.getType(), rewriter.getI8Type(), address,
+ ArrayRef<GEPArg>{storeOffset + index * elementSize});
+
+ rewriter.create<StoreOp>(loc, extractOp, gepOp);
+ }
+}
+
+/// Splits a store of the integer `value` into `address` at `storeOffset` into
+/// multiple stores to each 'writtenFields', making each store operation
+/// type-consistent.
+static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
+ RewriterBase &rewriter, Value address,
+ Value value, unsigned storeOffset,
+ ArrayRef<Type> writtenToFields) {
+ unsigned currentOffset = storeOffset;
+ for (Type type : writtenToFields) {
+ unsigned fieldSize = dataLayout.getTypeSize(type);
+
+ // Extract the data out of the integer by first shifting right and then
+ // truncating it.
+ auto pos = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(value.getType(),
+ (currentOffset - storeOffset) * 8));
+
+ auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
+
+ IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+ Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
+ if (fieldIntType != type) {
+ // Bitcast to the right type. `fieldIntType` was explicitly created
+ // to be of the same size as `type` and must currently be a primitive as
+ // well.
+ valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
+ }
+
+ // We create an `i8` indexed GEP here as that is the easiest (offset is
+ // already known). Other patterns turn this into a type-consistent GEP.
+ auto gepOp =
+ rewriter.create<GEPOp>(loc, address.getType(), rewriter.getI8Type(),
+ address, ArrayRef<GEPArg>{currentOffset});
+ rewriter.create<StoreOp>(loc, valueToStore, gepOp);
+
+ // No need to care about padding here since we already checked previously
+ // that no padding exists in this range.
+ currentOffset += fieldSize;
+ }
+}
+
+LogicalResult SplitStores::matchAndRewrite(StoreOp store,
+ PatternRewriter &rewriter) const {
+ Type sourceType = store.getValue().getType();
+ if (!isa<IntegerType, VectorType>(sourceType)) {
+ // We currently only support integer and vector sources.
return failure();
}
@@ -465,43 +531,30 @@ SplitIntegerStores::matchAndRewrite(StoreOp store,
if (failed(writtenToFields))
return failure();
- unsigned currentOffset = offset;
- for (Type type : *writtenToFields) {
- unsigned fieldSize = dataLayout.getTypeSize(type);
-
- // Extract the data out of the integer by first shifting right and then
- // truncating it.
- auto pos = rewriter.create<ConstantOp>(
- store.getLoc(),
- rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8));
-
- auto shrOp = rewriter.create<LShrOp>(store.getLoc(), store.getValue(), pos);
-
- IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
- Value valueToStore =
- rewriter.create<TruncOp>(store.getLoc(), fieldIntType, shrOp);
- if (fieldIntType != type) {
- // Bitcast to the right type. `fieldIntType` was explicitly created
- // to be of the same size as `type` and must currently be a primitive as
- // well.
- valueToStore =
- rewriter.create<BitcastOp>(store.getLoc(), type, valueToStore);
- }
-
- // We create an `i8` indexed GEP here as that is the easiest (offset is
- // already known). Other patterns turn this into a type-consistent GEP.
- auto gepOp = rewriter.create<GEPOp>(store.getLoc(), address.getType(),
- rewriter.getI8Type(), address,
- ArrayRef<GEPArg>{currentOffset});
- rewriter.create<StoreOp>(store.getLoc(), valueToStore, gepOp);
+ if (writtenToFields->size() <= 1) {
+ // Other patterns should take care of this case, we are only interested in
+ // splitting field stores.
+ return failure();
+ }
- // No need to care about padding here since we already checked previously
- // that no padding exists in this range.
- currentOffset += fieldSize;
+ if (isa<IntegerType>(sourceType)) {
+ splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
+ store.getValue(), offset, *writtenToFields);
+ rewriter.eraseOp(store);
+ return success();
}
- rewriter.eraseOp(store);
+ // Add a reasonable bound to not split very large vectors that would end up
+ // generating lots of code.
+ if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
+ return failure();
+ // Vector types are simply split into its elements and new stores generated
+ // with those. Subsequent pattern applications will split these stores further
+ // if required.
+ splitVectorStore(dataLayout, store.getLoc(), rewriter, address,
+ cast<TypedValue<VectorType>>(store.getValue()), offset);
+ rewriter.eraseOp(store);
return success();
}
@@ -518,7 +571,7 @@ struct LLVMTypeConsistencyPass
rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
&getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
- rewritePatterns.add<SplitIntegerStores>(&getContext());
+ rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
FrozenRewritePatternSet frozen(std::move(rewritePatterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 5cef22aacd0e39..ba477a51812a51 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -300,3 +300,121 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_write_split
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_write_split(%arg: vector<4xi32>) {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_write_split_offset
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_write_split_offset(%arg: vector<4xi32>) {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32>
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)>
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+
+ llvm.store %arg, %2 : vector<4xi32>, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}
+
+// -----
+
+// Small test that a split vector store will be further optimized (to than e.g.
+// split integer loads to structs as shown here)
+
+// CHECK-LABEL: llvm.func @vector_write_split_struct
+// CHECK-SAME: %[[ARG:.*]]: vector<2xi64>
+llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
+
+ // CHECK-COUNT-4: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr
+
+ llvm.store %arg, %1 : vector<2xi64>, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @type_consistent_vector_store
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)>
+ // CHECK: llvm.store %[[ARG]], %[[GEP]]
+ llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
+ // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+ llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}
More information about the Mlir-commits
mailing list