[Mlir-commits] [mlir] 6e9ea6e - [MLIR][LLVM][Mem2Reg] Extends support for partial stores (#89740)
llvmlistbot at llvm.org
Wed Apr 24 05:28:20 PDT 2024
Author: Christian Ulmann
Date: 2024-04-24T14:28:15+02:00
New Revision: 6e9ea6ea6897561a9c3bd77b0b93e415fdc7eeb3
URL: https://github.com/llvm/llvm-project/commit/6e9ea6ea6897561a9c3bd77b0b93e415fdc7eeb3
DIFF: https://github.com/llvm/llvm-project/commit/6e9ea6ea6897561a9c3bd77b0b93e415fdc7eeb3.diff
LOG: [MLIR][LLVM][Mem2Reg] Extends support for partial stores (#89740)
This commit enhances the LLVM dialect's Mem2Reg interfaces to support
partial stores to memory slots. To achieve this, the `getStored`
interface method gains a parameter carrying the reaching definition,
which is now required to produce the value resulting from the store.
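As an illustration, here is a minimal sketch of the rewrite this enables, mirroring the little-endian test case added below (function name hypothetical): a store of an i16 into a promoted i32 slot is merged into the reaching definition by zero-extending the stored value, masking the overwritten bits out of the definition, and or-ing the two together.

  // Before promotion: the partial store previously blocked mem2reg.
  llvm.func @partial_store(%arg : i16) {
    %c1 = llvm.mlir.constant(1 : i32) : i32
    %slot = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
    llvm.store %arg, %slot : i16, !llvm.ptr
    llvm.return
  }

  // After promotion (little endian): the alloca is gone and the stored
  // value is folded into the reaching definition at the bit level.
  llvm.func @partial_store(%arg : i16) {
    %def = llvm.mlir.undef : i32
    %zext = llvm.zext %arg : i16 to i32
    %mask = llvm.mlir.constant(-65536 : i32) : i32
    %masked = llvm.and %def, %mask : i32
    %new_def = llvm.or %masked, %zext : i32
    llvm.return
  }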
Added:
Modified:
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/lib/Transforms/Mem2Reg.cpp
mlir/test/Dialect/LLVMIR/mem2reg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 8c642c0ed26aca..764fa6d547b2eb 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter,
+ "::mlir::Value":$reachingDef,
"const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f2ab3eae2c343e..6eeb13ebffb51f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
- const DataLayout &dataLayout) {
+ Value reachingDef, const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -142,9 +142,10 @@ static bool isSupportedTypeForConversion(Type type) {
}
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
-/// truncations.
+/// truncations. Checks for narrowing or widening conversion compatibility
+/// depending on `narrowingConversion`.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
- Type srcType) {
+ Type srcType, bool narrowingConversion) {
if (targetType == srcType)
return true;
@@ -152,13 +153,18 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
!isSupportedTypeForConversion(srcType))
return false;
+ uint64_t targetSize = layout.getTypeSize(targetType);
+ uint64_t srcSize = layout.getTypeSize(srcType);
+
// Pointer casts will only be sane when the bitsize of both pointer types is
// the same.
if (isa<LLVM::LLVMPointerType>(targetType) &&
isa<LLVM::LLVMPointerType>(srcType))
- return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
+ return targetSize == srcSize;
- return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+ if (narrowingConversion)
+ return targetSize <= srcSize;
+ return targetSize >= srcSize;
}
/// Checks if `dataLayout` describes a little endian layout.
@@ -167,22 +173,49 @@ static bool isBigEndian(const DataLayout &dataLayout) {
return endiannessStr && endiannessStr == "big";
}
-/// The size of a byte in bits.
-constexpr const static uint64_t kBitsInByte = 8;
+/// Converts a value to an integer type of the same size.
+/// Assumes that the type can be converted.
+static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+ const DataLayout &dataLayout) {
+ Type type = val.getType();
+ assert(isSupportedTypeForConversion(type) &&
+ "expected value to have a convertible type");
+
+ if (isa<IntegerType>(type))
+ return val;
+
+ uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
+ IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+
+ if (isa<LLVM::LLVMPointerType>(type))
+ return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+}
+
+/// Converts a value with an integer type to `targetType`.
+static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+ Value val, Type targetType) {
+ assert(isa<IntegerType>(val.getType()) &&
+ "expected value to have an integer type");
+ assert(isSupportedTypeForConversion(targetType) &&
+ "expected the target type to be supported for conversions");
+ if (val.getType() == targetType)
+ return val;
+ if (isa<LLVM::LLVMPointerType>(targetType))
+ return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+}
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
-static Value createConversionSequence(RewriterBase &rewriter, Location loc,
- Value srcValue, Type targetType,
- const DataLayout &dataLayout) {
- // Get the types of the source and target values.
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Assumes the types have the same bitsize.
+static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
Type srcType = srcValue.getType();
- assert(areConversionCompatible(dataLayout, targetType, srcType) &&
+ assert(areConversionCompatible(dataLayout, targetType, srcType,
+ /*narrowingConversion=*/true) &&
"expected that the compatibility was checked before");
- uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
- uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
-
// Nothing has to be done if the types are already the same.
if (srcType == targetType)
return srcValue;
@@ -196,48 +229,117 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);
- IntegerType valueSizeInteger =
- rewriter.getIntegerType(srcTypeSize * kBitsInByte);
- Value replacement = srcValue;
+ // For all other castable types, casting through integers is necessary.
+ Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+ return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+}
+
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Performs bit-level extraction if the source type is larger
+/// than the target type. Assumes that this conversion is possible.
+static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
+ // Get the types of the source and target values.
+ Type srcType = srcValue.getType();
+ assert(areConversionCompatible(dataLayout, targetType, srcType,
+ /*narrowingConversion=*/true) &&
+ "expected that the compatibility was checked before");
+
+ uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
+ uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
+ if (srcTypeSize == targetTypeSize)
+ return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
// First, cast the value to a same-sized integer type.
- if (isa<LLVM::LLVMPointerType>(srcType))
- replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
- replacement);
- else if (replacement.getType() != valueSizeInteger)
- replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
- replacement);
+ Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
// Truncate the integer if the size of the target is less than the value.
- if (targetTypeSize != srcTypeSize) {
- if (isBigEndian(dataLayout)) {
- uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
- auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getIntegerAttr(srcType, shiftAmount));
- replacement =
- rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
- }
-
- replacement = rewriter.create<LLVM::TruncOp>(
- loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
- replacement);
+ if (isBigEndian(dataLayout)) {
+ uint64_t shiftAmount = srcTypeSize - targetTypeSize;
+ auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+ replacement =
+ rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
}
+ replacement = rewriter.create<LLVM::TruncOp>(
+ loc, rewriter.getIntegerType(targetTypeSize), replacement);
+
// Now cast the integer to the actual target type if required.
- if (isa<LLVM::LLVMPointerType>(targetType))
- replacement =
- rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
- else if (replacement.getType() != targetType)
- replacement =
- rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
+ return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+}
+
+/// Constructs operations that insert the bits of `srcValue` into the
+/// "beginning" of `reachingDef` (beginning is endianness dependent).
+/// Assumes that this conversion is possible.
+static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+ Value srcValue, Value reachingDef,
+ const DataLayout &dataLayout) {
+
+ assert(areConversionCompatible(dataLayout, reachingDef.getType(),
+ srcValue.getType(),
+ /*narrowingConversion=*/false) &&
+ "expected that the compatibility was checked before");
+ uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
+ uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
+ if (slotTypeSize == valueTypeSize)
+ return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+ dataLayout);
+
+ // In the case where the store only overwrites parts of the memory,
+ // bit fiddling is required to construct the new value.
+
+ // First convert both values to integers of the same size.
+ Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
+ Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+ // Extend the value to the size of the reaching definition.
+ valueAsInt =
+ rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
+ uint64_t sizeDifference = slotTypeSize - valueTypeSize;
+ if (isBigEndian(dataLayout)) {
+ // On big endian systems, a store to the base pointer overwrites the most
+ // significant bits. To accommodate this, the stored value needs to be
+ // shifted into the according position.
+ Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+ valueAsInt =
+ rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
+ }
+
+ // Construct the mask that is used to erase the bits that are overwritten by
+ // the store.
+ APInt maskValue;
+ if (isBigEndian(dataLayout)) {
+ // Build a mask that has the most significant bits set to zero.
+ // Note: This is the same as 2^sizeDifference - 1
+ maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
+ } else {
+ // Build a mask that has the least significant bits set to zero.
+ // Note: This is the same as -(2^valueTypeSize)
+ maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
+ maskValue.flipAllBits();
+ }
+
+ // Mask out the affected bits ...
+ Value mask = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+ Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
+
+ // ... and combine the result with the new value.
+ Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
- return replacement;
+ return castIntValueToSameSizedType(rewriter, loc, combined,
+ reachingDef.getType());
}
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
- return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
- dataLayout);
+ assert(reachingDef && reachingDef.getType() == slot.elemType &&
+ "expected the reaching definition's type to match the slot's type");
+ return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
+ dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -249,11 +351,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
Value blockingUse = (*blockingUses.begin())->get();
// If the blocking use is the slot ptr itself, there will be enough
// context to reconstruct the result of the load at removal time, so it can
- // be removed (provided it loads the exact stored value and is not
- // volatile).
+ // be removed (provided it is not volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areConversionCompatible(dataLayout, getResult().getType(),
- slot.elemType) &&
+ slot.elemType, /*narrowingConversion=*/true) &&
!getVolatile_();
}
@@ -263,9 +364,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- Value newResult =
- createConversionSequence(rewriter, getLoc(), reachingDefinition,
- getResult().getType(), dataLayout);
+ Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
+ getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -283,7 +383,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
- getValue().getType()) &&
+ getValue().getType(),
+ /*narrowingConversion=*/false) &&
!getVolatile_();
}
@@ -838,6 +939,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1251,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1199,7 +1302,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter,
+ RewriterBase &rewriter, Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1252,6 +1355,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index ebbf20f1b76b67..958c5f0c8dbc75 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
}
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return getValue();
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0c1ce70f070852..71ba5bc076f0e6 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -191,13 +191,13 @@ class MemorySlotPromoter {
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
- Value getLazyDefaultValue();
+ Value getOrCreateDefaultValue();
MemorySlot slot;
PromotableAllocationOpInterface allocator;
RewriterBase &rewriter;
- /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
- /// initialize it on demand.
+ /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
+ /// to initialize it on demand.
Value defaultValue;
/// Contains the reaching definition at this operation. Reaching definitions
/// are only computed for promotable memory operations with blocking uses.
@@ -232,7 +232,7 @@ MemorySlotPromoter::MemorySlotPromoter(
#endif // NDEBUG
}
-Value MemorySlotPromoter::getLazyDefaultValue() {
+Value MemorySlotPromoter::getOrCreateDefaultValue() {
if (defaultValue)
return defaultValue;
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter, dataLayout);
+ Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
Value reachingDef) {
+ assert(reachingDef && "expected an initial reaching def to be provided");
if (region->hasOneBlock()) {
computeReachingDefInBlock(&region->front(), reachingDef);
return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
}
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
+ assert(job.reachingDef);
if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
if (info.mergePoints.contains(blockOperand.get())) {
- if (!job.reachingDef)
- job.reachingDef = getLazyDefaultValue();
rewriter.modifyOpInPlace(terminator, [&]() {
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
@@ -567,7 +567,7 @@ void MemorySlotPromoter::removeBlockingUses() {
// If no reaching definition is known, this use is outside the reach of
// the slot. The default value should thus be used.
if (!reachingDef)
- reachingDef = getLazyDefaultValue();
+ reachingDef = getOrCreateDefaultValue();
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
@@ -601,7 +601,8 @@ void MemorySlotPromoter::removeBlockingUses() {
}
void MemorySlotPromoter::promoteSlot() {
- computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+ computeReachingDefInRegion(slot.ptr.getParentRegion(),
+ getOrCreateDefaultValue());
// Now that reaching definitions are known, remove all users.
removeBlockingUses();
@@ -617,7 +618,7 @@ void MemorySlotPromoter::promoteSlot() {
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
rewriter.modifyOpInPlace(
- user, [&]() { succOperands.append(getLazyDefaultValue()); });
+ user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
}
}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 644d30f9f9f133..38c836c139da62 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -856,28 +856,6 @@ llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64
// -----
-// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
-// implementation will be incorrect due to endianness considerations.
-
-// CHECK-LABEL: @stores_with_different_type_sizes
-llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
- %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.cond_br %cond, ^bb1, ^bb2
-^bb1:
- llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
- llvm.br ^bb3
-^bb2:
- llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
- llvm.br ^bb3
-^bb3:
- %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
- llvm.return %2 : f64
-}
-
-// -----
-
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -1047,3 +1025,135 @@ llvm.func @scalable_llvm_vector() -> i16 {
%2 = llvm.load %1 : !llvm.ptr -> i16
llvm.return %2 : i16
}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding
+// CHECK-SAME: %[[ARG:.+]]: i16
+llvm.func @smaller_store_forwarding(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-65536 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_big_endian
+ // CHECK-SAME: %[[ARG:.+]]: i16
+ llvm.func @smaller_store_forwarding_big_endian(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding_type_mix
+// CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-256 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_type_mix
+ // CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+ llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(24 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(16777215 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types_branches
+// CHECK-SAME: %[[ARG0:.+]]: i64
+// CHECK-SAME: %[[ARG1:.+]]: f32
+llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i64
+ %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+ // CHECK: llvm.br ^[[BB3:.+]](%[[ARG0]] : i64)
+ llvm.br ^bb3
+^bb2:
+ llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+ // CHECK: %[[CAST:.+]] = llvm.bitcast %[[ARG1]] : f32 to i32
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CAST]] : i32 to i64
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-4294967296 : i64) : i64
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: llvm.br ^[[BB3]](%[[NEW_DEF]] : i64)
+ llvm.br ^bb3
+^bb3:
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+ llvm.return %2 : f64
+}
+
+// -----
+
+// Verify that mem2reg does not touch stores with undefined semantics.
+
+// CHECK-LABEL: @store_out_of_bounds
+llvm.func @store_out_of_bounds(%arg : i64) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+ llvm.store %arg, %1 : i64, !llvm.ptr
+ llvm.return
+}