[Mlir-commits] [mlir] [MLIR][LLVM][Mem2Reg] Extends support for partial stores (PR #89740)

Christian Ulmann llvmlistbot at llvm.org
Tue Apr 23 04:40:19 PDT 2024


https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/89740

This commit enhances the LLVM dialect's Mem2Reg interfaces to support partial stores to memory slots. To achieve this, the `getStored` interface method is extended with a parameter for the reaching definition, which is now needed to compute the slot's value after the store.

>From ab238b599e0fc0e27e9c8bc16d06a912310136a4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 23 Apr 2024 11:32:41 +0000
Subject: [PATCH] [MLIR][LLVM][Mem2Reg] Extends support for partial stores

This commit enhances the LLVM dialect's Mem2Reg interfaces to support
partial stores to memory slots. To achieve this, the `getStored`
interface method is extended with a parameter for the reaching
definition, which is now needed to compute the slot's value after the
store.
---
 .../mlir/Interfaces/MemorySlotInterfaces.td   |   1 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 122 +++++++++++----
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |   2 +
 mlir/lib/Transforms/Mem2Reg.cpp               |   8 +-
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         | 141 +++++++++++++++---
 5 files changed, 223 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 8c642c0ed26aca..764fa6d547b2eb 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -128,6 +128,9 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
            "::mlir::RewriterBase &":$rewriter,
+           // Definition of the slot reaching this operation. Operations that
+           // only overwrite part of the slot can use it to build its new value.
+           "::mlir::Value":$reachingDef,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f2ab3eae2c343e..230c7fe8001bc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
 Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
-                              const DataLayout &dataLayout) {
+                              Value reachingDef, const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -144,7 +144,7 @@ static bool isSupportedTypeForConversion(Type type) {
 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
 /// truncations.
 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
-                                    Type srcType) {
+                                    Type srcType, bool allowWidening = false) {
   if (targetType == srcType)
     return true;
 
@@ -158,7 +158,8 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
       isa<LLVM::LLVMPointerType>(srcType))
     return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
 
-  return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+  return allowWidening ||
+         layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
 }
 
 /// Checks if `dataLayout` describes a little endian layout.
@@ -170,6 +171,35 @@ static bool isBigEndian(const DataLayout &dataLayout) {
 /// The size of a byte in bits.
 constexpr const static uint64_t kBitsInByte = 8;
 
+/// Converts a value to an integer type of the same size.
+/// Assumes that the type can be converted.
+static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
+                               const DataLayout &dataLayout) {
+  Type type = val.getType();
+  assert(isSupportedTypeForConversion(type));
+
+  if (isa<IntegerType>(type))
+    return val;
+
+  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
+  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+
+  if (isa<LLVM::LLVMPointerType>(type))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+}
+
+/// Converts a value with an integer type to `targetType`.
+static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
+                                   Value val, Type targetType) {
+  assert(isa<IntegerType>(val.getType()));
+  if (val.getType() == targetType)
+    return val;
+  if (isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+}
+
 /// Constructs operations that convert `inputValue` into a new value of type
 /// `targetType`. Assumes that this conversion is possible.
 static Value createConversionSequence(RewriterBase &rewriter, Location loc,
@@ -196,17 +226,8 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
     return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                         srcValue);
 
-  IntegerType valueSizeInteger =
-      rewriter.getIntegerType(srcTypeSize * kBitsInByte);
-  Value replacement = srcValue;
-
   // First, cast the value to a same-sized integer type.
-  if (isa<LLVM::LLVMPointerType>(srcType))
-    replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
-                                                          replacement);
-  else if (replacement.getType() != valueSizeInteger)
-    replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
-                                                         replacement);
+  Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
   if (targetTypeSize != srcTypeSize) {
@@ -224,20 +245,67 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
   }
 
   // Now cast the integer to the actual target type if required.
-  if (isa<LLVM::LLVMPointerType>(targetType))
-    replacement =
-        rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
-  else if (replacement.getType() != targetType)
-    replacement =
-        rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
-
-  return replacement;
+  return convertIntValueToType(rewriter, loc, replacement, targetType);
 }
 
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                               Value reachingDef,
                                const DataLayout &dataLayout) {
-  return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
-                                  dataLayout);
+  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
+  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
+  if (slotTypeSize <= valueTypeSize)
+    return createConversionSequence(rewriter, getLoc(), getValue(),
+                                    slot.elemType, dataLayout);
+
+  assert(reachingDef && reachingDef.getType() == slot.elemType &&
+         "expected the reaching definition's type to match the slot's type");
+
+  // In the case where the store only overwrites parts of the memory,
+  // bit fiddling is required to construct the new value.
+
+  // First convert both values to integers of the same size.
+  Value defAsInt =
+      convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
+  Value valueAsInt =
+      convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+  // Extend the value to the size of the reaching definition.
+  valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
+                                                   valueAsInt);
+  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
+  if (isBigEndian(dataLayout)) {
+    // On big endian systems, a store to the base pointer overwrites the most
+    // significant bits. To accommodate this, the stored value needs to be
+    // shifted into the corresponding position.
+    Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
+        getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+    valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
+                                                    bigEndianShift);
+  }
+
+  // Construct the mask that is used to erase the bits that are overwritten by
+  // the store.
+  APInt maskValue;
+  if (isBigEndian(dataLayout)) {
+    // Build a mask that has the most significant bits set to zero.
+    // Note: This is the same as 2^sizeDifference - 1
+    maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
+  } else {
+    // Build a mask that has the least significant bits set to zero.
+    // Note: This is the same as -(2^valueTypeSize)
+    maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
+    maskValue.flipAllBits();
+  }
+
+  // Mask out the affected bits ...
+  Value mask = rewriter.create<LLVM::ConstantOp>(
+      getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+  Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+
+  // ... and combine the result with the new value.
+  Value combined =
+      rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+
+  return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -283,7 +351,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
          getValue() != slot.ptr &&
          areConversionCompatible(dataLayout, slot.elemType,
-                                 getValue().getType()) &&
+                                 getValue().getType(),
+                                 /*allowWidening=*/true) &&
          !getVolatile_();
 }
 
@@ -838,6 +907,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1219,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
@@ -1199,7 +1270,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter,
+                                      RewriterBase &rewriter, Value reachingDef,
                                       const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
@@ -1252,6 +1323,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 Value reachingDef,
                                  const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index ebbf20f1b76b67..958c5f0c8dbc75 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
 Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                Value reachingDef,
                                 const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
@@ -242,6 +243,9 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
 }
 
 Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 Value reachingDef,
                                  const DataLayout &dataLayout) {
+  // The stored value is returned as the slot's entire new definition, so
+  // `reachingDef` is not needed here.
   return getValue();
 }
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0c1ce70f070852..d6881b600aea7b 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
       if (memOp.storesTo(slot)) {
         rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter, dataLayout);
+        Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
 void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                     Value reachingDef) {
+  assert(reachingDef && "expected an initial reaching def to be provided");
   if (region->hasOneBlock()) {
     computeReachingDefInBlock(&region->front(), reachingDef);
     return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     }
 
     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
+    assert(job.reachingDef && "expected a reaching definition");
 
     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
         if (info.mergePoints.contains(blockOperand.get())) {
-          if (!job.reachingDef)
-            job.reachingDef = getLazyDefaultValue();
           rewriter.modifyOpInPlace(terminator, [&]() {
             terminator.getSuccessorOperands(blockOperand.getOperandNumber())
                 .append(job.reachingDef);
@@ -601,7 +601,7 @@ void MemorySlotPromoter::removeBlockingUses() {
 }
 
 void MemorySlotPromoter::promoteSlot() {
-  computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+  computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
 
   // Now that reaching definitions are known, remove all users.
   removeBlockingUses();
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 644d30f9f9f133..130a8fce2def14 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -856,28 +856,6 @@ llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64
 
 // -----
 
-// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
-// implementation will be incorrect due to endianness considerations.
-
-// CHECK-LABEL: @stores_with_different_type_sizes
-llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.cond_br %cond, ^bb1, ^bb2
-^bb1:
-  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
-  llvm.br ^bb3
-^bb2:
-  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
-  llvm.br ^bb3
-^bb3:
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
-  llvm.return %2 : f64
-}
-
-// -----
-
 // CHECK-LABEL: @load_smaller_int
 llvm.func @load_smaller_int() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -1047,3 +1025,122 @@ llvm.func @scalable_llvm_vector() -> i16 {
   %2 = llvm.load %1 : !llvm.ptr -> i16
   llvm.return %2 : i16
 }
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding
+// CHECK-SAME: %[[ARG:.+]]: i16
+llvm.func @smaller_store_forwarding(%arg : i16) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+  %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-65536 : i32) : i32
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  llvm.store %arg, %1 : i16, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+  // CHECK-LABEL: @smaller_store_forwarding_big_endian
+  // CHECK-SAME: %[[ARG:.+]]: i16
+  llvm.func @smaller_store_forwarding_big_endian(%arg : i16) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NOT: llvm.alloca
+    // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+    %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+    // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+    // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+    // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) : i32
+    // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+    // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+    llvm.store %arg, %1 : i16, !llvm.ptr
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding_type_mix
+// CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+  // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+  // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-256 : i32) : i32
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+  llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+  // CHECK-LABEL: @smaller_store_forwarding_type_mix_big_endian
+  // CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+  llvm.func @smaller_store_forwarding_type_mix_big_endian(%arg : vector<1xi8>) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NOT: llvm.alloca
+    // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+    %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+    // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+    // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+    // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+    // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(24 : i32) : i32
+    // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+    // CHECK: %[[MASK:.+]] = llvm.mlir.constant(16777215 : i32) : i32
+    // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+    // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+    // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+    llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+    llvm.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types_branches
+// CHECK-SAME: %[[ARG0:.+]]: i64
+// CHECK-SAME: %[[ARG1:.+]]: f32
+llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i64
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  // CHECK: llvm.br ^[[BB3:.+]](%[[ARG0]] : i64)
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  // CHECK: %[[CAST:.+]] = llvm.bitcast %[[ARG1]] : f32 to i32
+  // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CAST]] : i32 to i64
+  // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-4294967296 : i64) : i64
+  // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+  // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+  // CHECK: llvm.br ^[[BB3]](%[[NEW_DEF]] : i64)
+  llvm.br ^bb3
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}



More information about the Mlir-commits mailing list