[Mlir-commits] [mlir] fd5387f - [mlir][LLVM] Add `BitcastStores` type-consistency pattern

Markus Böck llvmlistbot at llvm.org
Sun Jul 9 23:57:11 PDT 2023


Author: Markus Böck
Date: 2023-07-10T08:56:50+02:00
New Revision: fd5387f44f0f1886a5509baa563e2b8e53e67453

URL: https://github.com/llvm/llvm-project/commit/fd5387f44f0f1886a5509baa563e2b8e53e67453
DIFF: https://github.com/llvm/llvm-project/commit/fd5387f44f0f1886a5509baa563e2b8e53e67453.diff

LOG: [mlir][LLVM] Add `BitcastStores` type-consistency pattern

Current patterns attempt to create type-consistent stores immediately by bitcasting the value to the type hint, leading to needless code duplication.

This patch extracts that case into its own pattern, allowing other patterns to create type-inconsistent stores and rely on subsequent pattern applications to turn them into type-consistent stores.

Differential Revision: https://reviews.llvm.org/D154587

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index a8eebcd3e93405..417fb04a829dcb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -68,6 +68,17 @@ class SplitStores : public OpRewritePattern<StoreOp> {
                                 PatternRewriter &rewrite) const override;
 };
 
+/// Transforms type-inconsistent stores, i.e. stores where the type hint of
+/// the address contradicts the stored value's type, by bitcasting the value
+/// to the type hint when both are non-struct, non-array types of equal size.
+class BitcastStores : public OpRewritePattern<StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StoreOp store,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 0a760c7ee48117..3f581cf9bd7092 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -43,7 +43,7 @@ static Type isElementTypeInconsistent(Value addr, Type expectedType) {
 }
 
 /// Checks that two types are the same or can be bitcast into one another.
-static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) {
+static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
   return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
                         !isa<LLVMStructType, LLVMArrayType>(rhs) &&
                         layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
@@ -104,7 +104,7 @@ LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
   if (!firstType)
     return failure();
   DataLayout layout = DataLayout::closest(load);
-  if (!areCastCompatible(layout, firstType, load.getResult().getType()))
+  if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
     return failure();
 
   insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
@@ -144,20 +144,13 @@ LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
   DataLayout layout = DataLayout::closest(store);
   // Check that the first field has the right type or can at least be bitcast
   // to the right type.
-  if (!areCastCompatible(layout, firstType, store.getValue().getType()))
+  if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
     return failure();
 
   insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
 
-  Value replaceValue = store.getValue();
-  if (firstType != store.getValue().getType()) {
-    rewriter.setInsertionPointAfterValue(store.getValue());
-    replaceValue = rewriter.create<BitcastOp>(store->getLoc(), firstType,
-                                              store.getValue());
-  }
-
   rewriter.updateRootInPlace(
-      store, [&]() { store.getValueMutable().assign(replaceValue); });
+      store, [&]() { store.getValueMutable().assign(store.getValue()); });
 
   return success();
 }
@@ -458,12 +451,6 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
 
     IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
     Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
-    if (fieldIntType != type) {
-      // Bitcast to the right type. `fieldIntType` was explicitly created
-      // to be of the same size as `type` and must currently be a primitive as
-      // well.
-      valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
-    }
 
     // We create an `i8` indexed GEP here as that is the easiest (offset is
     // already known). Other patterns turn this into a type-consistent GEP.
@@ -558,6 +545,26 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
   return success();
 }
 
+LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
+                                             PatternRewriter &rewriter) const {
+  Type sourceType = store.getValue().getType();
+  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
+  if (!typeHint) {
+    // Nothing to do, since it is already consistent.
+    return failure();
+  }
+  // NOTE(review): same-size ptr/int pairs pass; LLVM IR bitcast forbids them.
+  auto dataLayout = DataLayout::closest(store);
+  if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
+    return failure();
+  // Store a bitcast of the value to the hinted type instead of the raw value.
+  auto bitcastOp =
+      rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
+  rewriter.updateRootInPlace(
+      store, [&] { store.getValueMutable().assign(bitcastOp); });
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -572,6 +579,7 @@ struct LLVMTypeConsistencyPass
         &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
     rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
+    rewritePatterns.add<BitcastStores>(&getContext());
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))

diff  --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index ba477a51812a51..489dfcb8a072b7 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_floats(%arg: i64) {
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   llvm.store %arg, %1 : i64, !llvm.ptr
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
@@ -409,12 +409,27 @@ llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
 // CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
 llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+// Storing an i32 into an f32 alloca must bitcast the value before the store.
+// CHECK-LABEL: llvm.func @bitcast_insertion
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @bitcast_insertion(%arg: i32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32
+  // CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i32, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}


        


More information about the Mlir-commits mailing list