[Mlir-commits] [mlir] [MLIR][Mem2Reg][LLVM] Enhance partial load support (PR #89094)

Christian Ulmann llvmlistbot at llvm.org
Wed Apr 17 09:15:36 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/89094

From 5e37eebaa3df3be5f6569320e2dcbcbfde2b4102 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 5 Apr 2024 08:37:31 +0000
Subject: [PATCH 1/3] [MLIR][Mem2Reg][LLVM] Enhance partial load support

This commit improves the LLVM dialect's Mem2Reg interfaces to support
promoting partial loads from larger memory slots. To this end, the
Mem2Reg interface methods are extended with an additional data layout
parameter. The data layout is required to determine type sizes and
produce correct conversion sequences.
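
As an illustrative sketch (not part of the patch; it assumes the default
little-endian data layout and uses made-up names), a partial load such as

  llvm.func @load_low_bits(%v: i64) -> i32 {
    %c1 = llvm.mlir.constant(1 : i32) : i32
    // A 64-bit slot that is only partially read back.
    %slot = llvm.alloca %c1 x i64 : (i32) -> !llvm.ptr
    llvm.store %v, %slot : i64, !llvm.ptr
    %r = llvm.load %slot : !llvm.ptr -> i32
    llvm.return %r : i32
  }

can now be promoted; the alloca disappears and the partial load becomes a
truncation of the stored value:

  llvm.func @load_low_bits(%v: i64) -> i32 {
    // Conversion sequence produced for the promoted partial load.
    %r = llvm.trunc %v : i64 to i32
    llvm.return %r : i32
  }

On big-endian layouts an llvm.lshr by the size difference is emitted before
the truncation (see the new tests below).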
---
 .../mlir/Interfaces/MemorySlotInterfaces.td   |   6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 174 ++++++++++++++----
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |  14 +-
 mlir/lib/Transforms/Mem2Reg.cpp               |  17 +-
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         | 129 +++++++++++--
 5 files changed, 270 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9db89361c78002..8c642c0ed26aca 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::RewriterBase &":$rewriter,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
         Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
            "::mlir::RewriterBase &":$rewriter,
-           "::mlir::Value":$reachingDefinition)
+           "::mlir::Value":$reachingDefinition,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index c7ca0b4a5843ad..0c4d019f5654ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                              const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
-  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
-                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
-                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+/// Checks if `type` can be used in any kind of conversion sequence.
+static bool isSupportedTypeForConversion(Type type) {
+  // Aggregate types are not bitcastable.
+  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
+    return false;
+
+  // LLVM vector types are only used for pointers or target-specific types.
+  // These types cannot be cast in the general case, so the memory
+  // optimizations do not support them.
+  if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+    return false;
+
+  // Scalable types are not supported.
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    return !vectorType.isScalable();
+  return true;
+}
+
+/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
+/// truncations.
+static bool areConversionCompatible(const DataLayout &layout, Type lhs,
+                                    Type rhs) {
+  if (lhs == rhs)
+    return true;
+
+  // Aggregate types cannot be cast.
+  if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
+    return false;
+  return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
 }
 
+/// Checks if `dataLayout` describes a little endian layout.
+static bool isLittleEndian(const DataLayout &dataLayout) {
+  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
+  return !endiannessStr || endiannessStr == "little";
+}
+
+/// The size of a byte in bits.
+constexpr const static uint64_t kBitsInByte = 8;
+
 /// Constructs operations that convert `inputValue` into a new value of type
 /// `targetType`. Assumes that this conversion is possible.
 static Value createConversionSequence(RewriterBase &rewriter, Location loc,
-                                      Value inputValue, Type targetType) {
-  if (inputValue.getType() == targetType)
-    return inputValue;
+                                      Value srcValue, Type targetType,
+                                      const DataLayout &dataLayout) {
+  // Get the types of the source and destination values.
+  Type srcType = srcValue.getType();
+
+  uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
+  uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
+
+  // Nothing has to be done if the types are already the same.
+  if (srcType == targetType)
+    return srcValue;
+
+  // The code below is currently not capable of handling aggregate types as it
+  // makes use of bitcasts. Aggregates cannot be bitcast.
+  // TODO: We should have a `LLVMAggregateType` base class to easily perform
+  // this `isa`.
+  if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
+      isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
+    return nullptr;
+
+  // In the special case of casting one pointer to another, we want to generate
+  // an address space cast. Bitcasts of pointers are not allowed, and
+  // pointer-to-integer conversions are not equivalent due to the loss of
+  // provenance.
+  if (isa<LLVM::LLVMPointerType>(targetType) &&
+      isa<LLVM::LLVMPointerType>(srcType)) {
+    // Abort the conversion if the pointers have different bitwidths.
+    if (srcTypeSize != targetTypeSize)
+      return nullptr;
+    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                        srcValue);
+  }
 
-  if (!isa<LLVM::LLVMPointerType>(targetType) &&
-      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+  IntegerType valueSizeInteger =
+      rewriter.getIntegerType(srcTypeSize * kBitsInByte);
+  Value replacement = srcValue;
+
+  // First, cast the value to a same-sized integer type.
+  if (isa<LLVM::LLVMPointerType>(srcType))
+    replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
+                                                          replacement);
+  else if (replacement.getType() != valueSizeInteger)
+    replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
+                                                         replacement);
+
+  // Truncate the integer if the size of the read is less than the value.
+  if (targetTypeSize != srcTypeSize) {
+    if (!isLittleEndian(dataLayout)) {
+      uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
+      auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+      replacement =
+          rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+    }
 
-  if (!isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+    replacement = rewriter.create<LLVM::TruncOp>(
+        loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
+        replacement);
+  }
 
-  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+  // Now cast the integer to the actual destination type if required.
+  if (isa<LLVM::LLVMPointerType>(targetType))
+    replacement =
+        rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
+  else if (replacement.getType() != targetType)
+    replacement =
+        rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
 
-  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                      inputValue);
+  return replacement;
 }
 
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return createConversionSequence(rewriter, getLoc(), getValue(),
-                                  slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                               const DataLayout &dataLayout) {
+  return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
+                                  dataLayout);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         areConversionCompatible(dataLayout, getResult().getType(),
+                                 slot.elemType) &&
          !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult = createConversionSequence(
-      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  Value newResult =
+      createConversionSequence(rewriter, getLoc(), reachingDefinition,
+                               getResult().getType(), dataLayout);
   rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
          getValue() != slot.ptr &&
-         areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
+         areConversionCompatible(dataLayout, slot.elemType,
+                                 getValue().getType()) &&
          !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
   return getDst() == slot.ptr;
 }
 
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemsetOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
+                                      RewriterBase &rewriter,
+                                      const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
   rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
   return getMemRef() == slot.ptr;
 }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return getValue();
 }
 
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
 public:
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                      RewriterBase &rewriter, DominanceInfo &dominance,
-                     MemorySlotPromotionInfo info,
+                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                      const Mem2RegStatistics &statistics);
 
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
   DenseMap<PromotableMemOpInterface, Value> reachingDefs;
   DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
   DominanceInfo &dominance;
+  const DataLayout &dataLayout;
   MemorySlotPromotionInfo info;
   const Mem2RegStatistics &statistics;
 };
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
     RewriterBase &rewriter, DominanceInfo &dominance,
-    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+    const Mem2RegStatistics &statistics)
     : slot(slot), allocator(allocator), rewriter(rewriter),
-      dominance(dominance), info(std::move(info)), statistics(statistics) {
+      dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+      statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
       if (memOp.storesTo(slot)) {
         rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter);
+        Value stored = memOp.getStored(slot, rewriter, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
 
       rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], rewriter,
-              reachingDef) == DeletionKind::Delete)
+              slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+              dataLayout) == DeletionKind::Delete)
         toErase.push_back(toPromote);
       if (toPromoteMemOp.storesTo(slot))
         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
       MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
-        MemorySlotPromoter(slot, allocator, rewriter, dominance,
+        MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
                            std::move(*info), statistics)
             .promoteSlot();
         promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  %1 = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: = llvm.alloca
-  %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
-  %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
-  llvm.return %3 : i16
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @merge_point_cycle
 llvm.func @merge_point_cycle() {
   // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
 // CHECK-LABEL: @load_smaller_int
 llvm.func @load_smaller_int() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
+  // CHECK-NOT: llvm.alloca
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
   llvm.return %2 : i16
@@ -902,10 +889,10 @@ llvm.func @load_smaller_int() -> i16 {
 
 // -----
 
-// CHECK-LABEL: @load_different_type_smaller
-llvm.func @load_different_type_smaller() -> f32 {
+// CHECK-LABEL: @load_different_type_same_size
+llvm.func @load_different_type_same_size() -> f32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
+  // CHECK-NOT: llvm.alloca
   %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
   llvm.return %2 : f32
@@ -943,3 +930,111 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     llvm.return %2 : !llvm.ptr<2>
   }
 }
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int_type
+llvm.func @load_smaller_int_type() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> i32
+  // CHECK: %[[RES:.*]] = llvm.trunc %{{.*}} : i64 to i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+  // CHECK-LABEL: @load_smaller_int_type_big_endian
+  llvm.func @load_smaller_int_type_big_endian() -> i8 {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NOT: llvm.alloca
+    %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+    %2 = llvm.load %1 : !llvm.ptr -> i8
+    // CHECK: %[[SHIFT_WIDTH:.*]] = llvm.mlir.constant(56 : i64) : i64
+    // CHECK: %[[SHIFT:.*]] = llvm.lshr %{{.*}}, %[[SHIFT_WIDTH]]
+    // CHECK: %[[RES:.*]] = llvm.trunc %[[SHIFT]] : i64 to i8
+    // CHECK: llvm.return %[[RES]] : i8
+    llvm.return %2 : i8
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %{{.*}} : i64 to i32
+  // CHECK: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: llvm.return %[[RES]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_float_type
+llvm.func @load_smaller_float_type() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f64 : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  // CHECK: %[[CAST:.*]] = llvm.bitcast %{{.*}} : f64 to i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[CAST]] : i64 to i32
+  // CHECK: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: llvm.return %[[RES]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_first_vector_elem
+llvm.func @load_first_vector_elem() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<4xi16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> i16
+  // CHECK: %[[TRUNC:.*]] = llvm.bitcast %{{.*}} : vector<4xi16> to i64
+  // CHECK: %[[RES:.*]] = llvm.trunc %[[TRUNC]] : i64 to i16
+  // CHECK: llvm.return %[[RES]] : i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_first_llvm_vector_elem
+llvm.func @load_first_llvm_vector_elem() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.vec<4 x ptr> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @scalable_vector
+llvm.func @scalable_vector() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x vector<[4]xi16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @scalable_llvm_vector
+llvm.func @scalable_llvm_vector() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.vec<? x 4 x ppc_fp128> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}

From a0a1619f13210fd268d94d96e033f8fab2cc8201 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 17 Apr 2024 16:07:43 +0000
Subject: [PATCH 2/3] improve inline comments + outdated code removal

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 23 +++++--------------
 1 file changed, 6 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 0c4d019f5654ac..6c24706f6a3633 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -148,7 +148,6 @@ static bool areConversionCompatible(const DataLayout &layout, Type lhs,
   if (lhs == rhs)
     return true;
 
-  // Aggregate types cannot be cast.
   if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
     return false;
   return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
@@ -168,8 +167,10 @@ constexpr const static uint64_t kBitsInByte = 8;
 static Value createConversionSequence(RewriterBase &rewriter, Location loc,
                                       Value srcValue, Type targetType,
                                       const DataLayout &dataLayout) {
-  // Get the types of the source and destination values.
+  // Get the types of the source and target values.
   Type srcType = srcValue.getType();
+  assert(areConversionCompatible(dataLayout, targetType, srcType) &&
+         "expected that the compatibility was checked before");
 
   uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
   uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
@@ -178,26 +179,14 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
   if (srcType == targetType)
     return srcValue;
 
-  // The code below is currently not capable of handling aggregate types as it
-  // makes use of bitcasts. Aggregates cannot be bitcast.
-  // TODO: We should have a `LLVMAggregateType` base class to easily perform
-  // this `isa`.
-  if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
-      isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
-    return nullptr;
-
   // In the special case of casting one pointer to another, we want to generate
   // an address space cast. Bitcasts of pointers are not allowed, and
   // pointer-to-integer conversions are not equivalent due to the loss of
   // provenance.
   if (isa<LLVM::LLVMPointerType>(targetType) &&
-      isa<LLVM::LLVMPointerType>(srcType)) {
-    // Abort the conversion if the pointers have different bitwidths.
-    if (srcTypeSize != targetTypeSize)
-      return nullptr;
+      isa<LLVM::LLVMPointerType>(srcType))
     return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                         srcValue);
-  }
 
   IntegerType valueSizeInteger =
       rewriter.getIntegerType(srcTypeSize * kBitsInByte);
@@ -211,7 +200,7 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
     replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
                                                          replacement);
 
-  // Truncate the integer if the size of the read is less than the value.
+  // Truncate the integer if the size of the target is less than the value.
   if (targetTypeSize != srcTypeSize) {
     if (!isLittleEndian(dataLayout)) {
       uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
@@ -226,7 +215,7 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
         replacement);
   }
 
-  // Now cast the integer to the actual destination type if required.
+  // Now cast the integer to the actual target type if required.
   if (isa<LLVM::LLVMPointerType>(targetType))
     replacement =
         rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);

From 55399a6d6f49d40e3d9eee5feedd5bfe7a8d72a3 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 17 Apr 2024 16:15:24 +0000
Subject: [PATCH 3/3] add stronger check for pointer types
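
Illustrative sketch of the pattern the stronger check now rejects (not part
of the patch; it assumes a data layout in which pointers in address spaces 1
and 2 have different bit widths): promoting the slot would require an address
space cast between differently sized pointers, so the alloca has to stay.

  llvm.func @keep_alloca_on_pointer_size_mismatch() -> !llvm.ptr<1> {
    %c1 = llvm.mlir.constant(1 : i32) : i32
    // The slot holds a pointer in address space 2 ...
    %slot = llvm.alloca %c1 x !llvm.ptr<2> : (i32) -> !llvm.ptr
    // ... but is read back as a differently sized pointer in space 1.
    %v = llvm.load %slot : !llvm.ptr -> !llvm.ptr<1>
    llvm.return %v : !llvm.ptr<1>
  }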

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6 ++++++
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         | 9 +++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 6c24706f6a3633..fa67f64cc77144 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -150,6 +150,12 @@ static bool areConversionCompatible(const DataLayout &layout, Type lhs,
 
   if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
     return false;
+
+  // Pointer casts will only be sane when the bitsize of both pointer types is
+  // the same.
+  if (isa<LLVM::LLVMPointerType>(lhs) && isa<LLVM::LLVMPointerType>(rhs))
+    return layout.getTypeSize(lhs) == layout.getTypeSize(rhs);
+
   return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index e724c2e8679501..644d30f9f9f133 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -929,6 +929,15 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
     llvm.return %2 : !llvm.ptr<2>
   }
+
+  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size2
+  llvm.func @load_ptr_addrspace_cast_different_size2() -> !llvm.ptr<1> {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: llvm.alloca
+    %1 = llvm.alloca %0 x !llvm.ptr<2> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<1>
+    llvm.return %2 : !llvm.ptr<1>
+  }
 }
 
 // -----


