[Mlir-commits] [mlir] [MLIR][LLVM][Mem2Reg] Extends support for partial stores (PR #89740)
Christian Ulmann
llvmlistbot at llvm.org
Wed Apr 24 00:30:47 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/89740
>From ab238b599e0fc0e27e9c8bc16d06a912310136a4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 23 Apr 2024 11:32:41 +0000
Subject: [PATCH 1/2] [MLIR][LLVM][Mem2Reg] Extends support for partial stores
This commit enhances the LLVM dialect's Mem2Reg interfaces to support
partial stores to memory slots. To achieve this support, the
`getStored` interface method has to be extended with a parameter of the
reaching definition, which is now necessary to produce the resulting
value after this store.
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 1 +
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 122 +++++++++++----
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 2 +
mlir/lib/Transforms/Mem2Reg.cpp | 8 +-
mlir/test/Dialect/LLVMIR/mem2reg.mlir | 141 +++++++++++++++---
5 files changed, 223 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 8c642c0ed26aca..764fa6d547b2eb 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter,
+ "::mlir::Value":$reachingDef,
"const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f2ab3eae2c343e..230c7fe8001bc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
- const DataLayout &dataLayout) {
+ Value reachingDef, const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -144,7 +144,7 @@ static bool isSupportedTypeForConversion(Type type) {
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
/// truncations.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
- Type srcType) {
+ Type srcType, bool allowWidening = false) {
if (targetType == srcType)
return true;
@@ -158,7 +158,8 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
isa<LLVM::LLVMPointerType>(srcType))
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
- return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+ return allowWidening ||
+ layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
}
/// Checks if `dataLayout` describes a little endian layout.
@@ -170,6 +171,35 @@ static bool isBigEndian(const DataLayout &dataLayout) {
/// The size of a byte in bits.
constexpr const static uint64_t kBitsInByte = 8;
+/// Converts a value to an integer type of the same size.
+/// Assumes that the type can be converted.
+static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
+ const DataLayout &dataLayout) {
+ Type type = val.getType();
+ assert(isSupportedTypeForConversion(type));
+
+ if (isa<IntegerType>(type))
+ return val;
+
+ uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
+ IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+
+ if (isa<LLVM::LLVMPointerType>(type))
+ return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+}
+
+/// Converts an value with an integer type to `targetType`.
+static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
+ Value val, Type targetType) {
+ assert(isa<IntegerType>(val.getType()));
+ if (val.getType() == targetType)
+ return val;
+ if (isa<LLVM::LLVMPointerType>(targetType))
+ return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+}
+
/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
@@ -196,17 +226,8 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);
- IntegerType valueSizeInteger =
- rewriter.getIntegerType(srcTypeSize * kBitsInByte);
- Value replacement = srcValue;
-
// First, cast the value to a same-sized integer type.
- if (isa<LLVM::LLVMPointerType>(srcType))
- replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
- replacement);
- else if (replacement.getType() != valueSizeInteger)
- replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
- replacement);
+ Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
// Truncate the integer if the size of the target is less than the value.
if (targetTypeSize != srcTypeSize) {
@@ -224,20 +245,67 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
}
// Now cast the integer to the actual target type if required.
- if (isa<LLVM::LLVMPointerType>(targetType))
- replacement =
- rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
- else if (replacement.getType() != targetType)
- replacement =
- rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
-
- return replacement;
+ return convertIntValueToType(rewriter, loc, replacement, targetType);
}
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
- return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
- dataLayout);
+ uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
+ uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
+ if (slotTypeSize <= valueTypeSize)
+ return createConversionSequence(rewriter, getLoc(), getValue(),
+ slot.elemType, dataLayout);
+
+ assert(reachingDef && reachingDef.getType() == slot.elemType &&
+ "expected the reaching definition's type to slot's type");
+
+ // In the case where the store only overwrites parts of the memory,
+ // bit fiddling is required to construct the new value.
+
+ // First convert both values to integers of the same size.
+ Value defAsInt =
+ convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
+ Value valueAsInt =
+ convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+ // Extend the value to the size of the reaching definition.
+ valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
+ valueAsInt);
+ uint64_t sizeDifference = slotTypeSize - valueTypeSize;
+ if (isBigEndian(dataLayout)) {
+ // On big endian systems, a store to the base pointer overwrites the most
+  // significant bits. To accommodate for this, the stored value needs to be
+ // shifted into the according position.
+ Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
+ getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+ valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
+ bigEndianShift);
+ }
+
+ // Construct the mask that is used to erase the bits that are overwritten by
+ // the store.
+ APInt maskValue;
+ if (isBigEndian(dataLayout)) {
+ // Build a mask that has the most significant bits set to zero.
+ // Note: This is the same as 2^sizeDifference - 1
+ maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
+ } else {
+ // Build a mask that has the least significant bits set to zero.
+ // Note: This is the same as -(2^valueTypeSize)
+ maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
+ maskValue.flipAllBits();
+ }
+
+ // Mask out the affected bits ...
+ Value mask = rewriter.create<LLVM::ConstantOp>(
+ getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+ Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+
+ // ... and combine the result with the new value.
+ Value combined =
+ rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+
+ return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -283,7 +351,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
- getValue().getType()) &&
+ getValue().getType(),
+ /*allowWidening=*/true) &&
!getVolatile_();
}
@@ -838,6 +907,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1219,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1199,7 +1270,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter,
+ RewriterBase &rewriter, Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1252,6 +1323,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index ebbf20f1b76b67..958c5f0c8dbc75 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
}
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return getValue();
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0c1ce70f070852..d6881b600aea7b 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter, dataLayout);
+ Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
Value reachingDef) {
+ assert(reachingDef && "expected an initial reaching def to be provided");
if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
}
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
+ assert(job.reachingDef);
if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
if (info.mergePoints.contains(blockOperand.get())) {
- if (!job.reachingDef)
- job.reachingDef = getLazyDefaultValue();
rewriter.modifyOpInPlace(terminator, [&]() {
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
@@ -601,7 +601,7 @@ void MemorySlotPromoter::removeBlockingUses() {
}
void MemorySlotPromoter::promoteSlot() {
- computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+ computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
// Now that reaching definitions are known, remove all users.
removeBlockingUses();
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 644d30f9f9f133..130a8fce2def14 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -856,28 +856,6 @@ llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64
// -----
-// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
-// implementation will be incorrect due to endianness considerations.
-
-// CHECK-LABEL: @stores_with_different_type_sizes
-llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
- %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.cond_br %cond, ^bb1, ^bb2
-^bb1:
- llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
- llvm.br ^bb3
-^bb2:
- llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
- llvm.br ^bb3
-^bb3:
- %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
- llvm.return %2 : f64
-}
-
-// -----
-
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -1047,3 +1025,122 @@ llvm.func @scalable_llvm_vector() -> i16 {
%2 = llvm.load %1 : !llvm.ptr -> i16
llvm.return %2 : i16
}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding
+// CHECK-SAME: %[[ARG:.+]]: i16
+llvm.func @smaller_store_forwarding(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-65536 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_big_endian
+ // CHECK-SAME: %[[ARG:.+]]: i16
+ llvm.func @smaller_store_forwarding_big_endian(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding_type_mix
+// CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-256 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_type_mix
+ // CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+ llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(24 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(16777215 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types_branches
+// CHECK-SAME: %[[ARG0:.+]]: i64
+// CHECK-SAME: %[[ARG1:.+]]: f32
+llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i64
+ %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+ // CHECK: llvm.br ^[[BB3:.+]](%[[ARG0]] : i64)
+ llvm.br ^bb3
+^bb2:
+ llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+ // CHECK: %[[CAST:.+]] = llvm.bitcast %[[ARG1]] : f32 to i32
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CAST]] : i32 to i64
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-4294967296 : i64) : i64
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: llvm.br ^[[BB3]](%[[NEW_DEF]] : i64)
+ llvm.br ^bb3
+^bb3:
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+ llvm.return %2 : f64
+}
>From bf375cf4899f720f055041e56f5ca14b33cf92a1 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 24 Apr 2024 07:30:34 +0000
Subject: [PATCH 2/2] address review comments
---
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 173 +++++++++++-------
mlir/lib/Transforms/Mem2Reg.cpp | 11 +-
mlir/test/Dialect/LLVMIR/mem2reg.mlir | 13 ++
3 files changed, 122 insertions(+), 75 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 230c7fe8001bc1..f3502bc6da1ca3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -142,9 +142,10 @@ static bool isSupportedTypeForConversion(Type type) {
}
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
-/// truncations.
+/// truncations. Checks for narrowing or widening conversion compatibility
+/// depending on `narrowingConversion`.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
- Type srcType, bool allowWidening = false) {
+ Type srcType, bool narrowingConversion) {
if (targetType == srcType)
return true;
@@ -152,14 +153,18 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
!isSupportedTypeForConversion(srcType))
return false;
+ uint64_t targetSize = layout.getTypeSize(targetType);
+ uint64_t srcSize = layout.getTypeSize(srcType);
+
// Pointer casts will only be sane when the bitsize of both pointer types is
// the same.
if (isa<LLVM::LLVMPointerType>(targetType) &&
isa<LLVM::LLVMPointerType>(srcType))
- return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
+ return targetSize == srcSize;
- return allowWidening ||
- layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+ if (narrowingConversion)
+ return targetSize <= srcSize;
+ return targetSize >= srcSize;
}
/// Checks if `dataLayout` describes a little endian layout.
@@ -168,15 +173,13 @@ static bool isBigEndian(const DataLayout &dataLayout) {
return endiannessStr && endiannessStr == "big";
}
-/// The size of a byte in bits.
-constexpr const static uint64_t kBitsInByte = 8;
-
/// Converts a value to an integer type of the same size.
/// Assumes that the type can be converted.
-static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
- const DataLayout &dataLayout) {
+static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+ const DataLayout &dataLayout) {
Type type = val.getType();
- assert(isSupportedTypeForConversion(type));
+ assert(isSupportedTypeForConversion(type) &&
+ "expected value to have a convertible type");
if (isa<IntegerType>(type))
return val;
@@ -189,10 +192,13 @@ static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
}
-/// Converts an value with an integer type to `targetType`.
-static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
- Value val, Type targetType) {
- assert(isa<IntegerType>(val.getType()));
+/// Converts a value with an integer type to `targetType`.
+static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+ Value val, Type targetType) {
+ assert(isa<IntegerType>(val.getType()) &&
+ "expected value to have an integer type");
+ assert(isSupportedTypeForConversion(targetType) &&
+ "expected the target type to be supported for conversions");
if (val.getType() == targetType)
return val;
if (isa<LLVM::LLVMPointerType>(targetType))
@@ -200,19 +206,16 @@ static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
}
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
-static Value createConversionSequence(RewriterBase &rewriter, Location loc,
- Value srcValue, Type targetType,
- const DataLayout &dataLayout) {
- // Get the types of the source and target values.
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Assumes the types have the same bitsize.
+static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
Type srcType = srcValue.getType();
- assert(areConversionCompatible(dataLayout, targetType, srcType) &&
+ assert(areConversionCompatible(dataLayout, targetType, srcType,
+ /*narrowingConversion=*/true) &&
"expected that the compatibility was checked before");
- uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
- uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
-
// Nothing has to be done if the types are already the same.
if (srcType == targetType)
return srcValue;
@@ -226,60 +229,83 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);
+ // For all other castable types, casting through integers is necessary.
+ Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+ return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+}
+
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Performs bitlevel extraction if the source type is larger than
+/// the target type.
+/// Assumes that this conversion is possible.
+static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
+ // Get the types of the source and target values.
+ Type srcType = srcValue.getType();
+ assert(areConversionCompatible(dataLayout, targetType, srcType,
+ /*narrowingConversion=*/true) &&
+ "expected that the compatibility was checked before");
+
+ uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
+ uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
+ if (srcTypeSize == targetTypeSize)
+ return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
+
// First, cast the value to a same-sized integer type.
- Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
+ Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
// Truncate the integer if the size of the target is less than the value.
- if (targetTypeSize != srcTypeSize) {
- if (isBigEndian(dataLayout)) {
- uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
- auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getIntegerAttr(srcType, shiftAmount));
- replacement =
- rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
- }
-
- replacement = rewriter.create<LLVM::TruncOp>(
- loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
- replacement);
+ if (isBigEndian(dataLayout)) {
+ uint64_t shiftAmount = srcTypeSize - targetTypeSize;
+ auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+ replacement =
+ rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
}
+ replacement = rewriter.create<LLVM::TruncOp>(
+ loc, rewriter.getIntegerType(targetTypeSize), replacement);
+
// Now cast the integer to the actual target type if required.
- return convertIntValueToType(rewriter, loc, replacement, targetType);
+ return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
}
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
- Value reachingDef,
- const DataLayout &dataLayout) {
- uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
- uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
- if (slotTypeSize <= valueTypeSize)
- return createConversionSequence(rewriter, getLoc(), getValue(),
- slot.elemType, dataLayout);
+/// Constructs operations that insert the bits of `srcValue` into the
+/// "beginning" of `reachingDef` (beginning is endianness dependent).
+/// Assumes that this conversion is possible.
+static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+ Value srcValue, Value reachingDef,
+ const DataLayout &dataLayout) {
- assert(reachingDef && reachingDef.getType() == slot.elemType &&
- "expected the reaching definition's type to slot's type");
+ assert(areConversionCompatible(dataLayout, reachingDef.getType(),
+ srcValue.getType(),
+ /*narrowingConversion=*/false) &&
+ "expected that the compatibility was checked before");
+ uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
+ uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
+ if (slotTypeSize == valueTypeSize)
+ return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+ dataLayout);
// In the case where the store only overwrites parts of the memory,
// bit fiddling is required to construct the new value.
// First convert both values to integers of the same size.
- Value defAsInt =
- convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
- Value valueAsInt =
- convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+ Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
+ Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
// Extend the value to the size of the reaching definition.
- valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
- valueAsInt);
+ valueAsInt =
+ rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
uint64_t sizeDifference = slotTypeSize - valueTypeSize;
if (isBigEndian(dataLayout)) {
// On big endian systems, a store to the base pointer overwrites the most
// significant bits. To accommodate for this, the stored value needs to be
// shifted into the according position.
Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
- getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
- valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
- bigEndianShift);
+ loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+ valueAsInt =
+ rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
}
// Construct the mask that is used to erase the bits that are overwritten by
@@ -298,14 +324,23 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
// Mask out the affected bits ...
Value mask = rewriter.create<LLVM::ConstantOp>(
- getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
- Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+ loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+ Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
// ... and combine the result with the new value.
- Value combined =
- rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+ Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
+
+ return castIntValueToSameSizedType(rewriter, loc, combined,
+ reachingDef.getType());
+}
- return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
+ const DataLayout &dataLayout) {
+ assert(reachingDef && reachingDef.getType() == slot.elemType &&
+ "expected the reaching definition's type to match the slot's type");
+ return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
+ dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -317,11 +352,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
Value blockingUse = (*blockingUses.begin())->get();
// If the blocking use is the slot ptr itself, there will be enough
// context to reconstruct the result of the load at removal time, so it can
- // be removed (provided it loads the exact stored value and is not
- // volatile).
+ // be removed (provided it is not volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areConversionCompatible(dataLayout, getResult().getType(),
- slot.elemType) &&
+ slot.elemType, /*narrowingConversion=*/true) &&
!getVolatile_();
}
@@ -331,9 +365,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- Value newResult =
- createConversionSequence(rewriter, getLoc(), reachingDefinition,
- getResult().getType(), dataLayout);
+ Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
+ getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -352,7 +385,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
getValue().getType(),
- /*allowWidening=*/true) &&
+ /*narrowingConversion=*/false) &&
!getVolatile_();
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index d6881b600aea7b..927d72c6477220 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -191,7 +191,7 @@ class MemorySlotPromoter {
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
- Value getLazyDefaultValue();
+ Value getOrCreateDefaultValue();
MemorySlot slot;
PromotableAllocationOpInterface allocator;
@@ -232,7 +232,7 @@ MemorySlotPromoter::MemorySlotPromoter(
#endif // NDEBUG
}
-Value MemorySlotPromoter::getLazyDefaultValue() {
+Value MemorySlotPromoter::getOrCreateDefaultValue() {
if (defaultValue)
return defaultValue;
@@ -567,7 +567,7 @@ void MemorySlotPromoter::removeBlockingUses() {
// If no reaching definition is known, this use is outside the reach of
// the slot. The default value should thus be used.
if (!reachingDef)
- reachingDef = getLazyDefaultValue();
+ reachingDef = getOrCreateDefaultValue();
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
@@ -601,7 +601,8 @@ void MemorySlotPromoter::removeBlockingUses() {
}
void MemorySlotPromoter::promoteSlot() {
- computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
+ computeReachingDefInRegion(slot.ptr.getParentRegion(),
+ getOrCreateDefaultValue());
// Now that reaching definitions are known, remove all users.
removeBlockingUses();
@@ -617,7 +618,7 @@ void MemorySlotPromoter::promoteSlot() {
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
rewriter.modifyOpInPlace(
- user, [&]() { succOperands.append(getLazyDefaultValue()); });
+ user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
}
}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 130a8fce2def14..38c836c139da62 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -1144,3 +1144,16 @@ llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
llvm.return %2 : f64
}
+
+// -----
+
+// Verify that mem2reg does not touch stores with undefined semantics.
+
+// CHECK-LABEL: @store_out_of_bounds
+llvm.func @store_out_of_bounds(%arg : i64) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+ llvm.store %arg, %1 : i64, !llvm.ptr
+ llvm.return
+}
More information about the Mlir-commits
mailing list