[Mlir-commits] [mlir] fd5387f - [mlir][LLVM] Add `BitcastStores` type-consistency pattern
Markus Böck
llvmlistbot at llvm.org
Sun Jul 9 23:57:11 PDT 2023
Author: Markus Böck
Date: 2023-07-10T08:56:50+02:00
New Revision: fd5387f44f0f1886a5509baa563e2b8e53e67453
URL: https://github.com/llvm/llvm-project/commit/fd5387f44f0f1886a5509baa563e2b8e53e67453
DIFF: https://github.com/llvm/llvm-project/commit/fd5387f44f0f1886a5509baa563e2b8e53e67453.diff
LOG: [mlir][LLVM] Add `BitcastStores` type-consistency pattern
Current patterns attempt to immediately create type-consistent stores by bitcasting the value to the type hint, which leads to needless code duplication.
This patch extracts that case into its own pattern, allowing other patterns to create type-inconsistent stores and have subsequent pattern applications turn them into type-consistent stores.
Differential Revision: https://reviews.llvm.org/D154587
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
mlir/test/Dialect/LLVMIR/type-consistency.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index a8eebcd3e93405..417fb04a829dcb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -68,6 +68,17 @@ class SplitStores : public OpRewritePattern<StoreOp> {
PatternRewriter &rewrite) const override;
};
+/// Transforms type-inconsistent stores, aka stores where the type hint of
+/// the address contradicts the value stored, by inserting a bitcast if
+/// possible.
+class BitcastStores : public OpRewritePattern<StoreOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(StoreOp store,
+ PatternRewriter &rewriter) const override;
+};
+
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 0a760c7ee48117..3f581cf9bd7092 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -43,7 +43,7 @@ static Type isElementTypeInconsistent(Value addr, Type expectedType) {
}
/// Checks that two types are the same or can be bitcast into one another.
-static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) {
+static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
!isa<LLVMStructType, LLVMArrayType>(rhs) &&
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
@@ -104,7 +104,7 @@ LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
if (!firstType)
return failure();
DataLayout layout = DataLayout::closest(load);
- if (!areCastCompatible(layout, firstType, load.getResult().getType()))
+ if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
return failure();
insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
@@ -144,20 +144,13 @@ LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
DataLayout layout = DataLayout::closest(store);
// Check that the first field has the right type or can at least be bitcast
// to the right type.
- if (!areCastCompatible(layout, firstType, store.getValue().getType()))
+ if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
return failure();
insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
- Value replaceValue = store.getValue();
- if (firstType != store.getValue().getType()) {
- rewriter.setInsertionPointAfterValue(store.getValue());
- replaceValue = rewriter.create<BitcastOp>(store->getLoc(), firstType,
- store.getValue());
- }
-
rewriter.updateRootInPlace(
- store, [&]() { store.getValueMutable().assign(replaceValue); });
+ store, [&]() { store.getValueMutable().assign(store.getValue()); });
return success();
}
@@ -458,12 +451,6 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
- if (fieldIntType != type) {
- // Bitcast to the right type. `fieldIntType` was explicitly created
- // to be of the same size as `type` and must currently be a primitive as
- // well.
- valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
- }
// We create an `i8` indexed GEP here as that is the easiest (offset is
// already known). Other patterns turn this into a type-consistent GEP.
@@ -558,6 +545,26 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
return success();
}
+LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
+ PatternRewriter &rewriter) const {
+ Type sourceType = store.getValue().getType();
+ Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
+ if (!typeHint) {
+ // Nothing to do, since it is already consistent.
+ return failure();
+ }
+
+ auto dataLayout = DataLayout::closest(store);
+ if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
+ return failure();
+
+ auto bitcastOp =
+ rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
+ rewriter.updateRootInPlace(
+ store, [&] { store.getValueMutable().assign(bitcastOp); });
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Type consistency pass
//===----------------------------------------------------------------------===//
@@ -572,6 +579,7 @@ struct LLVMTypeConsistencyPass
&getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
+ rewritePatterns.add<BitcastStores>(&getContext());
FrozenRewritePatternSet frozen(std::move(rewritePatterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index ba477a51812a51..489dfcb8a072b7 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_floats(%arg: i64) {
// CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
- // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+ // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
// CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
llvm.store %arg, %1 : i64, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
@@ -409,12 +409,27 @@ llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
+ // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
// CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @bitcast_insertion
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @bitcast_insertion(%arg: i32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+ // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32
+ // CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]]
+ llvm.store %arg, %1 : i32, !llvm.ptr
+ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.return
+}
More information about the Mlir-commits
mailing list