[Mlir-commits] [mlir] ac39fa7 - [MLIR][Mem2Reg][LLVM] Enhance partial load support (#89094)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 18 04:09:20 PDT 2024
Author: Christian Ulmann
Date: 2024-04-18T13:09:16+02:00
New Revision: ac39fa740b067f6197dca1caecc97c0da91ebf3d
URL: https://github.com/llvm/llvm-project/commit/ac39fa740b067f6197dca1caecc97c0da91ebf3d
DIFF: https://github.com/llvm/llvm-project/commit/ac39fa740b067f6197dca1caecc97c0da91ebf3d.diff
LOG: [MLIR][Mem2Reg][LLVM] Enhance partial load support (#89094)
This commit improves LLVM dialect's Mem2Reg interfaces to support
promotions of partial loads from larger memory slots. To support this,
the Mem2Reg interface methods are extended with additional data layout
parameters. The data layout is required to determine type sizes to
produce correct conversion sequences.
Note: There will be additional followups that introduce a similar
functionality for stores, and there are plans to support accesses into
the middle of memory slots.
Added:
Modified:
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/lib/Transforms/Mem2Reg.cpp
mlir/test/Dialect/LLVMIR/mem2reg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9db89361c78002..8c642c0ed26aca 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
- "::mlir::RewriterBase &":$rewriter)
+ "::mlir::RewriterBase &":$rewriter,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter,
- "::mlir::Value":$reachingDefinition)
+ "::mlir::Value":$reachingDefinition,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index c7ca0b4a5843ad..93901477b58204 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -122,37 +123,121 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
- layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+/// Checks if `type` can be used in any kind of conversion sequences.
+static bool isSupportedTypeForConversion(Type type) {
+ // Aggregate types are not bitcastable.
+ if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
+ return false;
+
+ // LLVM vector types are only used for either pointers or target specific
+ // types. These types cannot be casted in the general case, thus the memory
+ // optimizations do not support them.
+ if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+ return false;
+
+ // Scalable types are not supported.
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return !vectorType.isScalable();
+ return true;
}
+/// Checks that a value of type `srcType` can be converted to `targetType` by
+/// a sequence of casts and truncations.
+static bool areConversionCompatible(const DataLayout &layout, Type targetType,
+ Type srcType) {
+ if (targetType == srcType)
+ return true;
+
+ if (!isSupportedTypeForConversion(targetType) ||
+ !isSupportedTypeForConversion(srcType))
+ return false;
+
+ // Pointer casts will only be sane when the bitsize of both pointer types is
+ // the same.
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
+ isa<LLVM::LLVMPointerType>(srcType))
+ return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
+
+ return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+}
+
+/// Checks if `dataLayout` describes a big endian layout.
+static bool isBigEndian(const DataLayout &dataLayout) {
+ auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
+ return endiannessStr && endiannessStr == "big";
+}
+
+/// The size of a byte in bits.
+constexpr const static uint64_t kBitsInByte = 8;
+
/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
- Value inputValue, Type targetType) {
- if (inputValue.getType() == targetType)
- return inputValue;
-
- if (!isa<LLVM::LLVMPointerType>(targetType) &&
- !isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
+ // Get the types of the source and target values.
+ Type srcType = srcValue.getType();
+ assert(areConversionCompatible(dataLayout, targetType, srcType) &&
+ "expected that the compatibility was checked before");
+
+ uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
+ uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
+
+ // Nothing has to be done if the types are already the same.
+ if (srcType == targetType)
+ return srcValue;
+
+ // In the special case of casting one pointer to another, we want to generate
+ // an address space cast. Bitcasts of pointers are not allowed and using
+ // pointer to integer conversions are not equivalent due to the loss of
+ // provenance.
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
+ isa<LLVM::LLVMPointerType>(srcType))
+ return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+ srcValue);
+
+ IntegerType valueSizeInteger =
+ rewriter.getIntegerType(srcTypeSize * kBitsInByte);
+ Value replacement = srcValue;
+
+ // First, cast the value to a same-sized integer type.
+ if (isa<LLVM::LLVMPointerType>(srcType))
+ replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
+ replacement);
+ else if (replacement.getType() != valueSizeInteger)
+ replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
+ replacement);
+
+ // Truncate the integer if the size of the target is less than the value.
+ if (targetTypeSize != srcTypeSize) {
+ if (isBigEndian(dataLayout)) {
+ uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
+ auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+ replacement =
+ rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+ }
- if (!isa<LLVM::LLVMPointerType>(targetType))
- return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+ replacement = rewriter.create<LLVM::TruncOp>(
+ loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
+ replacement);
+ }
- if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+ // Now cast the integer to the actual target type if required.
+ if (isa<LLVM::LLVMPointerType>(targetType))
+ replacement =
+ rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
+ else if (replacement.getType() != targetType)
+ replacement =
+ rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
- return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
- inputValue);
+ return replacement;
}
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
- return createConversionSequence(rewriter, getLoc(), getValue(),
- slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
+ dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +252,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
- areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+ areConversionCompatible(dataLayout, getResult().getType(),
+ slot.elemType) &&
!getVolatile_();
}
DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- Value newResult = createConversionSequence(
- rewriter, getLoc(), reachingDefinition, getResult().getType());
+ Value newResult =
+ createConversionSequence(rewriter, getLoc(), reachingDefinition,
+ getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -194,13 +282,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
- areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
+ areConversionCompatible(dataLayout, slot.elemType,
+ getValue().getType()) &&
!getVolatile_();
}
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -747,8 +837,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
@@ -802,7 +892,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -1059,8 +1150,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1074,7 +1165,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1109,7 +1201,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1123,7 +1216,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1159,8 +1253,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1174,7 +1268,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value memref::LoadOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
DeletionKind memref::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}
-Value memref::StoreOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return getValue();
}
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
public:
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info,
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics);
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
+ const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
};
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics)
: slot(slot), allocator(allocator), rewriter(rewriter),
- dominance(dominance), info(std::move(info)), statistics(statistics) {
+ dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+ statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter);
+ Value stored = memOp.getStored(slot, rewriter, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], rewriter,
- reachingDef) == DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+ dataLayout) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
- MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
std::move(*info), statistics)
.promoteSlot();
promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..644d30f9f9f133 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
// -----
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- %1 = llvm.mlir.constant(0 : i32) : i32
- // CHECK: = llvm.alloca
- %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
- %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
- llvm.return %3 : i16
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @merge_point_cycle
llvm.func @merge_point_cycle() {
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
+ // CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
llvm.return %2 : i16
@@ -902,10 +889,10 @@ llvm.func @load_smaller_int() -> i16 {
// -----
-// CHECK-LABEL: @load_different_type_smaller
-llvm.func @load_different_type_smaller() -> f32 {
+// CHECK-LABEL: @load_different_type_same_size
+llvm.func @load_different_type_same_size() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
+ // CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
llvm.return %2 : f32
@@ -942,4 +929,121 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
llvm.return %2 : !llvm.ptr<2>
}
+
+  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size2
+  llvm.func @load_ptr_addrspace_cast_different_size2() -> !llvm.ptr<1> {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.ptr<2> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<1>
+ llvm.return %2 : !llvm.ptr<1>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int_type
+llvm.func @load_smaller_int_type() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i32
+ // CHECK: %[[RES:.*]] = llvm.trunc %{{.*}} : i64 to i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @load_smaller_int_type_big_endian
+ llvm.func @load_smaller_int_type_big_endian() -> i8 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i8
+ // CHECK: %[[SHIFT_WIDTH:.*]] = llvm.mlir.constant(56 : i64) : i64
+ // CHECK: %[[SHIFT:.*]] = llvm.lshr %{{.*}}, %[[SHIFT_WIDTH]]
+ // CHECK: %[[RES:.*]] = llvm.trunc %[[SHIFT]] : i64 to i8
+ // CHECK: llvm.return %[[RES]] : i8
+ llvm.return %2 : i8
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> f32
+ // CHECK: %[[TRUNC:.*]] = llvm.trunc %{{.*}} : i64 to i32
+ // CHECK: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+ // CHECK: llvm.return %[[RES]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_float_type
+llvm.func @load_smaller_float_type() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x f64 : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> f32
+ // CHECK: %[[CAST:.*]] = llvm.bitcast %{{.*}} : f64 to i64
+ // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[CAST]] : i64 to i32
+ // CHECK: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+ // CHECK: llvm.return %[[RES]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_first_vector_elem
+llvm.func @load_first_vector_elem() -> i16 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x vector<4xi16> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i16
+ // CHECK: %[[TRUNC:.*]] = llvm.bitcast %{{.*}} : vector<4xi16> to i64
+ // CHECK: %[[RES:.*]] = llvm.trunc %[[TRUNC]] : i64 to i16
+ // CHECK: llvm.return %[[RES]] : i16
+ llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_first_llvm_vector_elem
+llvm.func @load_first_llvm_vector_elem() -> i16 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.vec<4 x ptr> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i16
+ llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @scalable_vector
+llvm.func @scalable_vector() -> i16 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x vector<[4]xi16> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i16
+ llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @scalable_llvm_vector
+llvm.func @scalable_llvm_vector() -> i16 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.vec<? x 4 x ppc_fp128> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 : !llvm.ptr -> i16
+ llvm.return %2 : i16
}
More information about the Mlir-commits
mailing list