[Mlir-commits] [mlir] 974f1ee - [MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87637)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 4 23:26:00 PDT 2024


Author: Christian Ulmann
Date: 2024-04-05T08:25:36+02:00
New Revision: 974f1ee58da1c51c547eaf5c7007a215fd286c68

URL: https://github.com/llvm/llvm-project/commit/974f1ee58da1c51c547eaf5c7007a215fd286c68
DIFF: https://github.com/llvm/llvm-project/commit/974f1ee58da1c51c547eaf5c7007a215fd286c68.diff

LOG: [MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87637)

This commit relaxes Mem2Reg's type equality requirement for the LLVM
dialect's load and store operations. For now, we only allow loads to be
promoted if the reaching definition can be cast into a value of the
target type.

For stores, the same conversion casting check is applied and we ensure
that their result is properly cast to the type of the memory slot.
This is necessary to satisfy assumptions of the general mem2reg pass, as
it creates block arguments with the types of the memory slot.

This relands https://github.com/llvm/llvm-project/pull/87504

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/test/Dialect/LLVMIR/mem2reg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 563755297c5791..c7ca0b4a5843ad 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -122,8 +122,37 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
+/// Checks that two types are the same or can be cast into one another.
+static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
+  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
+                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
+                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+}
+
+/// Constructs operations that convert `inputValue` into a new value of type
+/// `targetType`. Assumes that this conversion is possible.
+static Value createConversionSequence(RewriterBase &rewriter, Location loc,
+                                      Value inputValue, Type targetType) {
+  if (inputValue.getType() == targetType)
+    return inputValue;
+
+  if (!isa<LLVM::LLVMPointerType>(targetType) &&
+      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+
+  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                      inputValue);
+}
+
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return getValue();
+  return createConversionSequence(rewriter, getLoc(), getValue(),
+                                  slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -138,7 +167,8 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getResult().getType() == slot.elemType && !getVolatile_();
+         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
@@ -146,7 +176,9 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  Value newResult = createConversionSequence(
+      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
 
@@ -161,7 +193,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
+         getValue() != slot.ptr &&
+         areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
          !getVolatile_();
 }
 

diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 90e56c1166edfd..61a3d933ee1510 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -697,3 +697,250 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
   %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
   llvm.return %3 : !llvm.ptr
 }
+
+// -----
+
+// CHECK-LABEL: @load_int_from_float
+llvm.func @load_int_from_float() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @load_float_from_int
+llvm.func @load_float_from_int() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_vector
+llvm.func @load_int_from_vector() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
+
+// CHECK-LABEL: @load_int_from_array
+llvm.func @load_int_from_array() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK-NOT: llvm.bitcast
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_float
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_float(%arg: i32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[ARG]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_float_to_int
+// CHECK-SAME: %[[ARG:.*]]: f32
+llvm.func @store_float_to_int(%arg: f32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_vector
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
+  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
+  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
+  // CHECK: llvm.return %[[BITCAST1]]
+  llvm.return %2 : vector<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_from_int
+llvm.func @load_ptr_from_int() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_ptr
+llvm.func @load_int_from_ptr() -> i64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_addrspace_cast
+llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr<2>
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types
+// CHECK-SAME: %[[ARG0:.*]]: i64
+// CHECK-SAME: %[[ARG1:.*]]: f64
+llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  // CHECK: llvm.br ^[[BB3:.*]](%[[ARG0]]
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f64, !llvm.ptr
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG1]] : f64 to i64
+  // CHECK: llvm.br ^[[BB3]](%[[BITCAST]]
+  llvm.br ^bb3
+// CHECK: ^[[BB3]](%[[BLOCK_ARG:.*]]: i64)
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[BLOCK_ARG]] : i64 to f64
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
+// implementation will be incorrect due to endianness considerations.
+
+// CHECK-LABEL: @stores_with_different_type_sizes
+llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  llvm.br ^bb3
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int
+llvm.func @load_smaller_int() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// This alloca is too small for the load, still, mem2reg should not touch it.
+
+// CHECK-LABEL: @impossible_load
+llvm.func @impossible_load() -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that mem2reg does not introduce address space casts of pointers
+// with different bitsize.
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
+>} {
+
+  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size
+  llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: llvm.alloca
+    %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+    llvm.return %2 : !llvm.ptr<2>
+  }
+}


        


More information about the Mlir-commits mailing list