[Mlir-commits] [mlir] [MLIR][Mem2Reg][LLVM] Enhance partial load support (PR #89094)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 09:00:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Christian Ulmann (Dinistro)
<details>
<summary>Changes</summary>
This commit improves the LLVM dialect's Mem2Reg interfaces to support promoting partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with an additional data layout parameter. The data layout is required to determine type sizes, which are needed to produce correct conversion sequences.
Note: Additional follow-ups will introduce similar functionality for stores, and there are plans to support accesses into the middle of memory slots.
---
Patch is 24.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89094.diff
5 Files Affected:
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+4-2)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+136-38)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+8-6)
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+10-7)
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+112-17)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9db89361c78002..8c642c0ed26aca 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
- "::mlir::RewriterBase &":$rewriter)
+ "::mlir::RewriterBase &":$rewriter,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter,
- "::mlir::Value":$reachingDefinition)
+ "::mlir::Value":$reachingDefinition,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index c7ca0b4a5843ad..0c4d019f5654ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
- layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+/// Checks if `type` can be used in any kind of conversion sequences.
+static bool isSupportedTypeForConversion(Type type) {
+  // Aggregate types are not bitcastable.
+  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
+    return false;
+
+  // LLVM vector types are only used for either pointers or target specific
+  // types. These types cannot be casted in the general case, thus the memory
+  // optimizations do not support them.
+  if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+    return false;
+
+  // Scalable types are not supported.
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    return !vectorType.isScalable();
+  return true;
+}
+
+/// Checks if the data layout size of `type` differs from the size of the
+/// value it holds. This can happen for vectors, whose layout size may include
+/// padding (e.g. when the element count is rounded up to a power of two).
+/// Such types cannot be bitcast to an integer of their data layout size.
+static bool hasPaddingBits(const DataLayout &layout, Type type) {
+  auto vectorType = dyn_cast<VectorType>(type);
+  if (!vectorType)
+    return false;
+  uint64_t valueBits = vectorType.getNumElements() *
+                       layout.getTypeSizeInBits(vectorType.getElementType());
+  return layout.getTypeSizeInBits(vectorType) != valueBits;
+}
+
+/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
+/// truncations. Sizes are compared in bits: byte sizes are rounded up, which
+/// would make e.g. `i1` and `i8` look equal-sized even though they cannot be
+/// bitcast into one another.
+static bool areConversionCompatible(const DataLayout &layout, Type lhs,
+                                    Type rhs) {
+  if (lhs == rhs)
+    return true;
+
+  // Aggregates, LLVM vector types and scalable vectors cannot be converted.
+  if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
+    return false;
+
+  // The conversion goes through an integer of the data layout size, which
+  // requires that size to match the size of the value.
+  if (hasPaddingBits(layout, lhs) || hasPaddingBits(layout, rhs))
+    return false;
+
+  uint64_t lhsSize = layout.getTypeSizeInBits(lhs);
+  uint64_t rhsSize = layout.getTypeSizeInBits(rhs);
+
+  // Pointer-to-pointer conversions are address space casts, which
+  // `createConversionSequence` only emits for pointers of the same size.
+  if (isa<LLVM::LLVMPointerType>(lhs) && isa<LLVM::LLVMPointerType>(rhs))
+    return lhsSize == rhsSize;
+
+  return lhsSize <= rhsSize;
}
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
+/// Checks if `dataLayout` describes a little endian layout. A layout without
+/// a string endianness entry is treated as little endian.
+static bool isLittleEndian(const DataLayout &dataLayout) {
+  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
+  return !endiannessStr || endiannessStr == "little";
+}
+
+/// Constructs operations that convert `srcValue` into a new value of type
+/// `targetType`. Callers must first check that
+/// `areConversionCompatible(dataLayout, targetType, srcValue.getType())`
+/// holds. Returns a null value for aggregate types and for pointers of
+/// different sizes.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
- Value inputValue, Type targetType) {
- if (inputValue.getType() == targetType)
- return inputValue;
+                                      Value srcValue, Type targetType,
+                                      const DataLayout &dataLayout) {
+  // Get the types of the source and destination values.
+  Type srcType = srcValue.getType();
+
+  // Nothing has to be done if the types are already the same.
+  if (srcType == targetType)
+    return srcValue;
+
+  // The code below is currently not capable of handling aggregate types as it
+  // makes use of bitcasts. Aggregates cannot be bitcast.
+  // TODO: We should have a `LLVMAggregateType` base class to easily perform
+  // this `isa`.
+  if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
+      isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
+    return nullptr;
+
+  // Work with bit sizes: byte sizes are rounded up, so e.g. `i1` and `i8`
+  // would otherwise be treated as equal-sized and bitcast into one another.
+  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
+  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
+
+  // In the special case of casting one pointer to another, we want to generate
+  // an address space cast. Bitcasts of pointers are not allowed and using
+  // pointer to integer conversions are not equivalent due to the loss of
+  // provenance.
+  if (isa<LLVM::LLVMPointerType>(targetType) &&
+      isa<LLVM::LLVMPointerType>(srcType)) {
+    // Abort the conversion if the pointers have different bitwidths.
+    if (srcTypeSize != targetTypeSize)
+      return nullptr;
+    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                        srcValue);
+  }
- if (!isa<LLVM::LLVMPointerType>(targetType) &&
- !isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+  IntegerType valueSizeInteger = rewriter.getIntegerType(srcTypeSize);
+  Value replacement = srcValue;
+
+  // First, cast the value to a same-sized integer type.
+  if (isa<LLVM::LLVMPointerType>(srcType))
+    replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
+                                                          replacement);
+  else if (replacement.getType() != valueSizeInteger)
+    replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
+                                                         replacement);
+
+  // Truncate the integer if the size of the read is less than the value.
+  if (targetTypeSize != srcTypeSize) {
+    // On big endian targets, the narrower read covers the most significant
+    // bits of the value. Shift the integer computed above (not `srcValue`,
+    // which may be a float, vector or pointer) so that these bits survive the
+    // truncation.
+    if (!isLittleEndian(dataLayout)) {
+      uint64_t shiftAmount = srcTypeSize - targetTypeSize;
+      auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getIntegerAttr(valueSizeInteger, shiftAmount));
+      replacement = rewriter.createOrFold<LLVM::LShrOp>(loc, replacement,
+                                                        shiftConstant);
+    }
- if (!isa<LLVM::LLVMPointerType>(targetType))
- return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+    replacement = rewriter.create<LLVM::TruncOp>(
+        loc, rewriter.getIntegerType(targetTypeSize), replacement);
+  }
- if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+  // Now cast the integer to the actual destination type if required.
+  if (isa<LLVM::LLVMPointerType>(targetType))
+    replacement =
+        rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
+  else if (replacement.getType() != targetType)
+    replacement =
+        rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
- return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
- inputValue);
+  return replacement;
}
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
- return createConversionSequence(rewriter, getLoc(), getValue(),
- slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
+ dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
- areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+ areConversionCompatible(dataLayout, getResult().getType(),
+ slot.elemType) &&
!getVolatile_();
}
DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- Value newResult = createConversionSequence(
- rewriter, getLoc(), reachingDefinition, getResult().getType());
+ Value newResult =
+ createConversionSequence(rewriter, getLoc(), reachingDefinition,
+ getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
- areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
+ areConversionCompatible(dataLayout, slot.elemType,
+ getValue().getType()) &&
!getVolatile_();
}
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value memref::LoadOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
DeletionKind memref::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}
-Value memref::StoreOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return getValue();
}
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
public:
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info,
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics);
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
+ const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
};
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics)
: slot(slot), allocator(allocator), rewriter(rewriter),
- dominance(dominance), info(std::move(info)), statistics(statistics) {
+ dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+ statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter);
+ Value stored = memOp.getStored(slot, rewriter, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], rewriter,
- reachingDef) == DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+ dataLayout) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
- MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
std::move(*info), statistics)
.promoteSlot();
promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
// -----
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- %1 = llvm.mlir.constant(0 : i32) : i32
- // CHECK: = llvm.alloca
- %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
- %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
- llvm.return %3 : i16
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @merge_point_cycle
llvm.func @merge_point_cycle() {
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
+ // CHECK-NOT: llvm.alloca
%1 = ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/89094
More information about the Mlir-commits
mailing list