[Mlir-commits] [mlir] [MLIR][LLVM][Mem2Reg] Extends support for partial stores (PR #89740)

Christian Ulmann llvmlistbot at llvm.org
Wed Apr 24 00:30:47 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/89740

>From ab238b599e0fc0e27e9c8bc16d06a912310136a4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 23 Apr 2024 11:32:41 +0000
Subject: [PATCH 1/2] [MLIR][LLVM][Mem2Reg] Extends support for partial stores

This commit enhances the LLVM dialect's Mem2Reg interfaces to support
partial stores to memory slots, i.e., stores that overwrite only part of
a slot. To achieve this, the `getStored` interface method is extended
with a `reachingDef` parameter: the reaching definition is needed to
produce the slot's full value after such a store, as the bits that are
not overwritten by the store must be preserved.
---
 .../mlir/Interfaces/MemorySlotInterfaces.td   |   1 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 122 +++++++++++----
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |   2 +
 mlir/lib/Transforms/Mem2Reg.cpp               |   8 +-
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         | 141 +++++++++++++++---
 5 files changed, 223 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 8c642c0ed26aca..764fa6d547b2eb 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
            "::mlir::RewriterBase &":$rewriter,
+           "::mlir::Value":$reachingDef,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f2ab3eae2c343e..230c7fe8001bc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
 Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
-                              const DataLayout &dataLayout) {
+                              Value reachingDef, const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -144,7 +144,7 @@ static bool isSupportedTypeForConversion(Type type) {
 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
 /// truncations.
 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
-                                    Type srcType) {
+                                    Type srcType, bool allowWidening = false) {
   if (targetType == srcType)
     return true;
 
@@ -158,7 +158,8 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
       isa<LLVM::LLVMPointerType>(srcType))
     return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
 
-  return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+  return allowWidening ||
+         layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
 }
 
 /// Checks if `dataLayout` describes a little endian layout.
@@ -170,6 +171,35 @@ static bool isBigEndian(const DataLayout &dataLayout) {
 /// The size of a byte in bits.
 constexpr const static uint64_t kBitsInByte = 8;
 
+/// Converts a value to an integer type of the same size.
+/// Assumes that the type can be converted.
+static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
+                               const DataLayout &dataLayout) {
+  Type type = val.getType();
+  assert(isSupportedTypeForConversion(type));
+
+  if (isa<IntegerType>(type))
+    return val;
+
+  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
+  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+
+  if (isa<LLVM::LLVMPointerType>(type))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+}
+
+/// Converts an value with an integer type to `targetType`.
+static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
+                                   Value val, Type targetType) {
+  assert(isa<IntegerType>(val.getType()));
+  if (val.getType() == targetType)
+    return val;
+  if (isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+}
+
 /// Constructs operations that convert `inputValue` into a new value of type
 /// `targetType`. Assumes that this conversion is possible.
 static Value createConversionSequence(RewriterBase &rewriter, Location loc,
@@ -196,17 +226,8 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
     return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                         srcValue);
 
-  IntegerType valueSizeInteger =
-      rewriter.getIntegerType(srcTypeSize * kBitsInByte);
-  Value replacement = srcValue;
-
   // First, cast the value to a same-sized integer type.
-  if (isa<LLVM::LLVMPointerType>(srcType))
-    replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
-                                                          replacement);
-  else if (replacement.getType() != valueSizeInteger)
-    replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
-                                                         replacement);
+  Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
   if (targetTypeSize != srcTypeSize) {
@@ -224,20 +245,67 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
   }
 
   // Now cast the integer to the actual target type if required.
-  if (isa<LLVM::LLVMPointerType>(targetType))
-    replacement =
-        rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
-  else if (replacement.getType() != targetType)
-    replacement =
-        rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
-
-  return replacement;
+  return convertIntValueToType(rewriter, loc, replacement, targetType);
 }
 
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                               Value reachingDef,
                                const DataLayout &dataLayout) {
-  return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
-                                  dataLayout);
+  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
+  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
+  if (slotTypeSize <= valueTypeSize)
+    return createConversionSequence(rewriter, getLoc(), getValue(),
+                                    slot.elemType, dataLayout);
+
+  assert(reachingDef && reachingDef.getType() == slot.elemType &&
+         "expected the reaching definition's type to slot's type");
+
+  // In the case where the store only overwrites parts of the memory,
+  // bit fiddling is required to construct the new value.
+
+  // First convert both values to integers of the same size.
+  Value defAsInt =
+      convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
+  Value valueAsInt =
+      convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+  // Extend the value to the size of the reaching definition.
+  valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
+                                                   valueAsInt);
+  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
+  if (isBigEndian(dataLayout)) {
+    // On big endian systems, a store to the base pointer overwrites the most
+    // significant bits. To accomodate for this, the stored value needs to be
+    // shifted into the according position.
+    Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
+        getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+    valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
+                                                    bigEndianShift);
+  }
+
+  // Construct the mask that is used to erase the bits that are overwritten by
+  // the store.
+  APInt maskValue;
+  if (isBigEndian(dataLayout)) {
+    // Build a mask that has the most significant bits set to zero.
+    // Note: This is the same as 2^sizeDifference - 1
+    maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
+  } else {
+    // Build a mask that has the least significant bits set to zero.
+    // Note: This is the same as -(2^valueTypeSize)
+    maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
+    maskValue.flipAllBits();
+  }
+
+  // Mask out the affected bits ...
+  Value mask = rewriter.create<LLVM::ConstantOp>(
+      getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+  Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+
+  // ... and combine the result with the new value.
+  Value combined =
+      rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+
+  return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -283,7 +351,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
          getValue() != slot.ptr &&
          areConversionCompatible(dataLayout, slot.elemType,
-                                 getValue().getType()) &&
+                                 getValue().getType(),
+                                 /*allowWidening=*/true) &&
          !getVolatile_();
 }
 
@@ -838,6 +907,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1219,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
@@ -1199,7 +1270,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter,
+                                      RewriterBase &rewriter, Value reachingDef,
                                       const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
@@ -1252,6 +1323,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 Value reachingDef,
                                  const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index ebbf20f1b76b67..958c5f0c8dbc75 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
 Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
 }
 
 Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 Value reachingDef,
                                  const DataLayout &dataLayout) {
   return getValue();
 }
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0c1ce70f070852..d6881b600aea7b 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
       if (memOp.storesTo(slot)) {
         rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter, dataLayout);
+        Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
 void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                     Value reachingDef) {
+  assert(reachingDef && "expected an initial reaching def to be provided");
   if (region->hasOneBlock()) {
     computeReachingDefInBlock(&region->front(), reachingDef);
     return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     }
 
     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
+    assert(job.reachingDef);
 
     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
         if (info.mergePoints.contains(blockOperand.get())) {
-          if (!job.reachingDef)
-            job.reachingDef = getLazyDefaultValue();
           rewriter.modifyOpInPlace(terminator, [&]() {
             terminator.getSuccessorOperands(blockOperand.getOperandNumber())
                 .append(job.reachingDef);
@@ -601,7 +601,7 @@ void MemorySlotPromoter::removeBlockingUses() {
 }
 
 void MemorySlotPromoter::promoteSlot() {
-  computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+  computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
 
   // Now that reaching definitions are known, remove all users.
   removeBlockingUses();
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 644d30f9f9f133..130a8fce2def14 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -856,28 +856,6 @@ llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64
 
 // -----
 
-// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
-// implementation will be incorrect due to endianness considerations.
-
-// CHECK-LABEL: @stores_with_different_type_sizes
-llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.cond_br %cond, ^bb1, ^bb2
-^bb1:
-  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
-  llvm.br ^bb3
-^bb2:
-  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
-  llvm.br ^bb3
-^bb3:
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
-  llvm.return %2 : f64
-}
-
-// -----
-
 // CHECK-LABEL: @load_smaller_int
 llvm.func @load_smaller_int() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -1047,3 +1025,122 @@ llvm.func @scalable_llvm_vector() -> i16 {
   %2 = llvm.load %1 : !llvm.ptr -> i16
   llvm.return %2 : i16
 }
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding
+// CHECK-SAME: %[[ARG:.+]]: i16
+llvm.func @smaller_store_forwarding(%arg : i16) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+  %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-65536 : i32) : i32
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  llvm.store %arg, %1 : i16, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+  // CHECK-LABEL: @smaller_store_forwarding_big_endian
+  // CHECK-SAME: %[[ARG:.+]]: i16
+  llvm.func @smaller_store_forwarding_big_endian(%arg : i16) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NOT: llvm.alloca
+    // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+    %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+    // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+    // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+    // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) : i32
+    // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+    // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+    llvm.store %arg, %1 : i16, !llvm.ptr
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding_type_mix
+// CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+  // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+  // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-256 : i32) : i32
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+  llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+  // CHECK-LABEL: @smaller_store_forwarding_type_mix_big_endian
+  // CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+  llvm.func @smaller_store_forwarding_type_mix_big_endian(%arg : vector<1xi8>) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NOT: llvm.alloca
+    // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+    %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+    // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+    // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+    // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+    // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(24 : i32) : i32
+    // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+    // CHECK: %[[MASK:.+]] = llvm.mlir.constant(16777215 : i32) : i32
+    // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+    // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+    // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+    llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types_branches
+// CHECK-SAME: %[[ARG0:.+]]: i64
+// CHECK-SAME: %[[ARG1:.+]]: f32
+llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i64
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  // CHECK: llvm.br ^[[BB3:.+]](%[[ARG0]] : i64)
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  // CHECK: %[[CAST:.+]] = llvm.bitcast %[[ARG1]] : f32 to i32
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CAST]] : i32 to i64
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-4294967296 : i64) : i64
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  // CHECK: llvm.br ^[[BB3]](%[[NEW_DEF]] : i64)
+  llvm.br ^bb3
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}

>From bf375cf4899f720f055041e56f5ca14b33cf92a1 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 24 Apr 2024 07:30:34 +0000
Subject: [PATCH 2/2] address review comments

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 173 +++++++++++-------
 mlir/lib/Transforms/Mem2Reg.cpp               |  11 +-
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         |  13 ++
 3 files changed, 122 insertions(+), 75 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 230c7fe8001bc1..f3502bc6da1ca3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -142,9 +142,10 @@ static bool isSupportedTypeForConversion(Type type) {
 }
 
 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
-/// truncations.
+/// truncations. Checks for narrowing or widening conversion compatibility
+/// depending on `narrowingConversion`.
 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
-                                    Type srcType, bool allowWidening = false) {
+                                    Type srcType, bool narrowingConversion) {
   if (targetType == srcType)
     return true;
 
@@ -152,14 +153,18 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
       !isSupportedTypeForConversion(srcType))
     return false;
 
+  uint64_t targetSize = layout.getTypeSize(targetType);
+  uint64_t srcSize = layout.getTypeSize(srcType);
+
   // Pointer casts will only be sane when the bitsize of both pointer types is
   // the same.
   if (isa<LLVM::LLVMPointerType>(targetType) &&
       isa<LLVM::LLVMPointerType>(srcType))
-    return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
+    return targetSize == srcSize;
 
-  return allowWidening ||
-         layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+  if (narrowingConversion)
+    return targetSize <= srcSize;
+  return targetSize >= srcSize;
 }
 
 /// Checks if `dataLayout` describes a little endian layout.
@@ -168,15 +173,13 @@ static bool isBigEndian(const DataLayout &dataLayout) {
   return endiannessStr && endiannessStr == "big";
 }
 
-/// The size of a byte in bits.
-constexpr const static uint64_t kBitsInByte = 8;
-
 /// Converts a value to an integer type of the same size.
 /// Assumes that the type can be converted.
-static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
-                               const DataLayout &dataLayout) {
+static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+                                const DataLayout &dataLayout) {
   Type type = val.getType();
-  assert(isSupportedTypeForConversion(type));
+  assert(isSupportedTypeForConversion(type) &&
+         "expected value to have a convertible type");
 
   if (isa<IntegerType>(type))
     return val;
@@ -189,10 +192,13 @@ static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
   return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
 }
 
-/// Converts an value with an integer type to `targetType`.
-static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
-                                   Value val, Type targetType) {
-  assert(isa<IntegerType>(val.getType()));
+/// Converts a value with an integer type to `targetType`.
+static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+                                         Value val, Type targetType) {
+  assert(isa<IntegerType>(val.getType()) &&
+         "expected value to have an integer type");
+  assert(isSupportedTypeForConversion(targetType) &&
+         "expected the target type to be supported for conversions");
   if (val.getType() == targetType)
     return val;
   if (isa<LLVM::LLVMPointerType>(targetType))
@@ -200,19 +206,16 @@ static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
   return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
 }
 
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
-static Value createConversionSequence(RewriterBase &rewriter, Location loc,
-                                      Value srcValue, Type targetType,
-                                      const DataLayout &dataLayout) {
-  // Get the types of the source and target values.
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Assumes the types have the same bitsize.
+static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+                                Value srcValue, Type targetType,
+                                const DataLayout &dataLayout) {
   Type srcType = srcValue.getType();
-  assert(areConversionCompatible(dataLayout, targetType, srcType) &&
+  assert(areConversionCompatible(dataLayout, targetType, srcType,
+                                 /*narrowingConversion=*/true) &&
          "expected that the compatibility was checked before");
 
-  uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
-  uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
-
   // Nothing has to be done if the types are already the same.
   if (srcType == targetType)
     return srcValue;
@@ -226,60 +229,83 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
     return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                         srcValue);
 
+  // For all other castable types, casting through integers is necessary.
+  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+}
+
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Performs bitlevel extraction if the source type is larger than
+/// the target type.
+/// Assumes that this conversion is possible.
+static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+                                  Value srcValue, Type targetType,
+                                  const DataLayout &dataLayout) {
+  // Get the types of the source and target values.
+  Type srcType = srcValue.getType();
+  assert(areConversionCompatible(dataLayout, targetType, srcType,
+                                 /*narrowingConversion=*/true) &&
+         "expected that the compatibility was checked before");
+
+  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
+  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
+  if (srcTypeSize == targetTypeSize)
+    return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
+
   // First, cast the value to a same-sized integer type.
-  Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
+  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
-  if (targetTypeSize != srcTypeSize) {
-    if (isBigEndian(dataLayout)) {
-      uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
-      auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getIntegerAttr(srcType, shiftAmount));
-      replacement =
-          rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
-    }
-
-    replacement = rewriter.create<LLVM::TruncOp>(
-        loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
-        replacement);
+  if (isBigEndian(dataLayout)) {
+    uint64_t shiftAmount = srcTypeSize - targetTypeSize;
+    auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getIntegerAttr(replacement.getType(), shiftAmount));
+    replacement =
+        rewriter.createOrFold<LLVM::LShrOp>(loc, replacement, shiftConstant);
   }
 
+  replacement = rewriter.create<LLVM::TruncOp>(
+      loc, rewriter.getIntegerType(targetTypeSize), replacement);
+
   // Now cast the integer to the actual target type if required.
-  return convertIntValueToType(rewriter, loc, replacement, targetType);
+  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
 }
 
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
-                               Value reachingDef,
-                               const DataLayout &dataLayout) {
-  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
-  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
-  if (slotTypeSize <= valueTypeSize)
-    return createConversionSequence(rewriter, getLoc(), getValue(),
-                                    slot.elemType, dataLayout);
+/// Constructs operations that insert the bits of `srcValue` into the
+/// "beginning" of `reachingDef` (beginning is endianness dependent).
+/// Assumes that this conversion is possible.
+static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+                                 Value srcValue, Value reachingDef,
+                                 const DataLayout &dataLayout) {
 
-  assert(reachingDef && reachingDef.getType() == slot.elemType &&
-         "expected the reaching definition's type to slot's type");
+  assert(areConversionCompatible(dataLayout, reachingDef.getType(),
+                                 srcValue.getType(),
+                                 /*narrowingConversion=*/false) &&
+         "expected that the compatibility was checked before");
+  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
+  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
+  if (slotTypeSize == valueTypeSize)
+    return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+                              dataLayout);
 
   // In the case where the store only overwrites parts of the memory,
   // bit fiddling is required to construct the new value.
 
   // First convert both values to integers of the same size.
-  Value defAsInt =
-      convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
-  Value valueAsInt =
-      convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
+  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
   // Extend the value to the size of the reaching definition.
-  valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
-                                                   valueAsInt);
+  valueAsInt =
+      rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
   uint64_t sizeDifference = slotTypeSize - valueTypeSize;
   if (isBigEndian(dataLayout)) {
     // On big endian systems, a store to the base pointer overwrites the most
     // significant bits. To accomodate for this, the stored value needs to be
     // shifted into the according position.
     Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
-        getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
-    valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
-                                                    bigEndianShift);
+        loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+    valueAsInt =
+        rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
   }
 
   // Construct the mask that is used to erase the bits that are overwritten by
@@ -298,14 +324,23 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
 
   // Mask out the affected bits ...
   Value mask = rewriter.create<LLVM::ConstantOp>(
-      getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
-  Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+      loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+  Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
 
   // ... and combine the result with the new value.
-  Value combined =
-      rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+  Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
+
+  return castIntValueToSameSizedType(rewriter, loc, combined,
+                                     reachingDef.getType());
+}
 
-  return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                               Value reachingDef,
+                               const DataLayout &dataLayout) {
+  assert(reachingDef && reachingDef.getType() == slot.elemType &&
+         "expected the reaching definition's type to match the slot's type");
+  return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
+                             dataLayout);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -317,11 +352,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   Value blockingUse = (*blockingUses.begin())->get();
   // If the blocking use is the slot ptr itself, there will be enough
   // context to reconstruct the result of the load at removal time, so it can
-  // be removed (provided it loads the exact stored value and is not
-  // volatile).
+  // be removed (provided it is not volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
          areConversionCompatible(dataLayout, getResult().getType(),
-                                 slot.elemType) &&
+                                 slot.elemType, /*narrowingConversion=*/true) &&
          !getVolatile_();
 }
 
@@ -331,9 +365,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
     const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult =
-      createConversionSequence(rewriter, getLoc(), reachingDefinition,
-                               getResult().getType(), dataLayout);
+  Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
+                                         getResult().getType(), dataLayout);
   rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
@@ -352,7 +385,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
          getValue() != slot.ptr &&
          areConversionCompatible(dataLayout, slot.elemType,
                                  getValue().getType(),
-                                 /*allowWidening=*/true) &&
+                                 /*narrowingConversion=*/false) &&
          !getVolatile_();
 }
 
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index d6881b600aea7b..927d72c6477220 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -191,7 +191,7 @@ class MemorySlotPromoter {
 
   /// Lazily-constructed default value representing the content of the slot when
   /// no store has been executed. This function may mutate IR.
-  Value getLazyDefaultValue();
+  Value getOrCreateDefaultValue();
 
   MemorySlot slot;
   PromotableAllocationOpInterface allocator;
@@ -232,7 +232,7 @@ MemorySlotPromoter::MemorySlotPromoter(
 #endif // NDEBUG
 }
 
-Value MemorySlotPromoter::getLazyDefaultValue() {
+Value MemorySlotPromoter::getOrCreateDefaultValue() {
   if (defaultValue)
     return defaultValue;
 
@@ -567,7 +567,7 @@ void MemorySlotPromoter::removeBlockingUses() {
       // If no reaching definition is known, this use is outside the reach of
       // the slot. The default value should thus be used.
       if (!reachingDef)
-        reachingDef = getLazyDefaultValue();
+        reachingDef = getOrCreateDefaultValue();
 
       rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
@@ -601,7 +601,8 @@ void MemorySlotPromoter::removeBlockingUses() {
 }
 
 void MemorySlotPromoter::promoteSlot() {
-  computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
+  computeReachingDefInRegion(slot.ptr.getParentRegion(),
+                             getOrCreateDefaultValue());
 
   // Now that reaching definitions are known, remove all users.
   removeBlockingUses();
@@ -617,7 +618,7 @@ void MemorySlotPromoter::promoteSlot() {
              succOperands.size() + 1 == mergePoint->getNumArguments());
       if (succOperands.size() + 1 == mergePoint->getNumArguments())
         rewriter.modifyOpInPlace(
-            user, [&]() { succOperands.append(getLazyDefaultValue()); });
+            user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
     }
   }
 
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 130a8fce2def14..38c836c139da62 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -1144,3 +1144,16 @@ llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
   llvm.return %2 : f64
 }
+
+// -----
+
+// Verify that mem2reg does not touch stores with undefined semantics (here, a store wider than the slot).
+
+// CHECK-LABEL: @store_out_of_bounds
+llvm.func @store_out_of_bounds(%arg : i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}



More information about the Mlir-commits mailing list