[Mlir-commits] [mlir] 0289ae5 - [MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813)

llvmlistbot at llvm.org
Fri Mar 22 00:31:21 PDT 2024


Author: Christian Ulmann
Date: 2024-03-22T08:31:17+01:00
New Revision: 0289ae51aa375fd297f1d03d27ff517223e5e998

URL: https://github.com/llvm/llvm-project/commit/0289ae51aa375fd297f1d03d27ff517223e5e998
DIFF: https://github.com/llvm/llvm-project/commit/0289ae51aa375fd297f1d03d27ff517223e5e998.diff

LOG: [MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813)

This commit relaxes the assumption of type consistency for LLVM dialect
load and store operations in SROA. Instead, loads and stores are now
checked to stay within the bounds of the sub-slot they access.
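
For illustration, SROA can now handle accesses such as the following
(mirroring the new @store_sub_field test in sroa.mlir); the f32 store is
smaller than the i64 sub-slot it targets, so it stays within bounds:

  %0 = llvm.mlir.constant(1 : i32) : i32
  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
  llvm.store %arg, %1 : f32, !llvm.ptr

Previously, SROA rejected any access whose type did not exactly match the
slot's element type; now such an access is simply rewired to the sub-slot
created for the first field.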

This commit additionally removes the corresponding patterns from the
type consistency pass, as they are no longer necessary.

Note: It will be necessary to extend Mem2Reg with logic for differently
sized accesses as well. This is nonetheless a strict upgrade for
production flows, as the type consistency pass can produce invalid IR for
some odd cases.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
    mlir/test/Dialect/LLVMIR/sroa.mlir
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b523374f6c06b5..f8f9264b3889be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -323,7 +323,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
 }
 
 def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins LLVM_AnyPointer:$addr,
@@ -402,7 +403,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
 }
 
 def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins LLVM_LoadableType:$value,

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index b32ac56d7079c6..cacb241bfd7a10 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -29,18 +29,6 @@ namespace LLVM {
 /// interpret pointee types as consistently as possible.
 std::unique_ptr<Pass> createTypeConsistencyPass();
 
-/// Transforms uses of pointers to a whole struct to uses of pointers to the
-/// first element of a struct. This is achieved by inserting a GEP to the first
-/// element when possible.
-template <class User>
-class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
-public:
-  using OpRewritePattern<User>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(User user,
-                                PatternRewriter &rewriter) const override;
-};
-
 /// Canonicalizes GEPs of which the base type and the pointer's type hint do not
 /// match. This is done by replacing the original GEP into a GEP with the type
 /// hint as a base type when an element of the hinted type aligns with the

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 56e5e96aecd13c..aaa261be6553f3 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -26,8 +26,7 @@ struct MemorySlot {
 
 /// Memory slot attached with information about its destructuring procedure.
 struct DestructurableMemorySlot : public MemorySlot {
-  /// Maps an index within the memory slot to the type of the pointer that
-  /// will be generated to access the element directly.
+  /// Maps an index within the memory slot to the corresponding subelement type.
   DenseMap<Attribute, Type> elementPtrs;
 };
 

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 00b4559658fd4d..0ef1d105aca6cb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,8 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
   if (!destructuredType)
     return {};
 
-  DenseMap<Attribute, Type> allocaTypeMap;
-  for (Attribute index : llvm::make_first_range(destructuredType.value()))
-    allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
-
-  return {
-      DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
+  return {DestructurableMemorySlot{{getResult(), getElemType()},
+                                   *destructuredType}};
 }
 
 DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
   return DeletionKind::Delete;
 }
 
+/// Checks if `slot` can be accessed through the provided access type.
+static bool isValidAccessType(const MemorySlot &slot, Type accessType,
+                              const DataLayout &dataLayout) {
+  return dataLayout.getTypeSize(accessType) <=
+         dataLayout.getTypeSize(slot.elemType);
+}
+
 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
     const DataLayout &dataLayout) {
-  return success(getAddr() != slot.ptr || getType() == slot.elemType);
+  return success(getAddr() != slot.ptr ||
+                 isValidAccessType(slot, getType(), dataLayout));
 }
 
 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
     const DataLayout &dataLayout) {
   return success(getAddr() != slot.ptr ||
-                 getValue().getType() == slot.elemType);
+                 isValidAccessType(slot, getValue().getType(), dataLayout));
+}
+
+/// Returns the subslot's type at the requested index.
+static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
+                           Attribute index) {
+  auto subelementIndexMap =
+      slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+  if (!subelementIndexMap)
+    return {};
+  assert(!subelementIndexMap->empty());
+
+  // Note: Returns a null type when no entry is found.
+  return subelementIndexMap->lookup(index);
+}
+
+bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+                             SmallPtrSetImpl<Attribute> &usedIndices,
+                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                             const DataLayout &dataLayout) {
+  if (getVolatile_())
+    return false;
+
+  // A load always accesses the first element of the destructured slot.
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  Type subslotType = getTypeAtIndex(slot, index);
+  if (!subslotType)
+    return false;
+
+  // The access can only be replaced when the subslot is read within its bounds.
+  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
+    return false;
+
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
+                                  DenseMap<Attribute, MemorySlot> &subslots,
+                                  RewriterBase &rewriter,
+                                  const DataLayout &dataLayout) {
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  auto it = subslots.find(index);
+  assert(it != subslots.end());
+
+  rewriter.modifyOpInPlace(
+      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  return DeletionKind::Keep;
+}
+
+bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+                              SmallPtrSetImpl<Attribute> &usedIndices,
+                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                              const DataLayout &dataLayout) {
+  if (getVolatile_())
+    return false;
+
+  // A store always accesses the first element of the destructured slot.
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  Type subslotType = getTypeAtIndex(slot, index);
+  if (!subslotType)
+    return false;
+
+  // The access can only be replaced when the subslot is written within its
+  // bounds.
+  if (dataLayout.getTypeSize(getValue().getType()) >
+      dataLayout.getTypeSize(subslotType))
+    return false;
+
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
+                                   DenseMap<Attribute, MemorySlot> &subslots,
+                                   RewriterBase &rewriter,
+                                   const DataLayout &dataLayout) {
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  auto it = subslots.find(index);
+  assert(it != subslots.end());
+
+  rewriter.modifyOpInPlace(
+      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  return DeletionKind::Keep;
 }
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
   auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
   if (!firstLevelIndex)
     return false;
-  assert(slot.elementPtrs.contains(firstLevelIndex));
-  if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
-    return false;
   mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+  assert(slot.elementPtrs.contains(firstLevelIndex));
   usedIndices.insert(firstLevelIndex);
   return true;
 }

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b25c831bc7172a..3d700fe94e3b9c 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -49,104 +49,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
                         layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
 }
 
-//===----------------------------------------------------------------------===//
-// AddFieldGetterToStructDirectUse
-//===----------------------------------------------------------------------===//
-
-/// Gets the type of the first subelement of `type` if `type` is destructurable,
-/// nullptr otherwise.
-static Type getFirstSubelementType(Type type) {
-  auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
-  if (!destructurable)
-    return nullptr;
-
-  Type subelementType = destructurable.getTypeAtIndex(
-      IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
-  if (subelementType)
-    return subelementType;
-
-  return nullptr;
-}
-
-/// Extracts a pointer to the first field of an `elemType` from the address
-/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
-/// instead.
-template <class MemOp>
-static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
-                                   Type elemType) {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  rewriter.setInsertionPointAfterValue(op.getAddr());
-  SmallVector<GEPArg> firstTypeIndices{0, 0};
-
-  Value properPtr = rewriter.create<GEPOp>(
-      op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
-      op.getAddr(), firstTypeIndices);
-
-  rewriter.modifyOpInPlace(op,
-                           [&]() { op.getAddrMutable().assign(properPtr); });
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
-    LoadOp load, PatternRewriter &rewriter) const {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  Type inconsistentElementType =
-      isElementTypeInconsistent(load.getAddr(), load.getType());
-  if (!inconsistentElementType)
-    return failure();
-  Type firstType = getFirstSubelementType(inconsistentElementType);
-  if (!firstType)
-    return failure();
-  DataLayout layout = DataLayout::closest(load);
-  if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
-    return failure();
-
-  insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
-
-  // If the load does not use the first type but a type that can be casted from
-  // it, add a bitcast and change the load type.
-  if (firstType != load.getResult().getType()) {
-    rewriter.setInsertionPointAfterValue(load.getResult());
-    BitcastOp bitcast = rewriter.create<BitcastOp>(
-        load->getLoc(), load.getResult().getType(), load.getResult());
-    rewriter.modifyOpInPlace(load,
-                             [&]() { load.getResult().setType(firstType); });
-    rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
-                                  bitcast);
-  }
-
-  return success();
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
-    StoreOp store, PatternRewriter &rewriter) const {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  Type inconsistentElementType =
-      isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
-  if (!inconsistentElementType)
-    return failure();
-  Type firstType = getFirstSubelementType(inconsistentElementType);
-  if (!firstType)
-    return failure();
-
-  DataLayout layout = DataLayout::closest(store);
-  // Check that the first field has the right type or can at least be bitcast
-  // to the right type.
-  if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
-    return failure();
-
-  insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
-
-  rewriter.modifyOpInPlace(
-      store, [&]() { store.getValueMutable().assign(store.getValue()); });
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // CanonicalizeAlignedGep
 //===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
     : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
   void runOnOperation() override {
     RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
-    rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
-        &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
     rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
     rewritePatterns.add<BitcastStores>(&getContext());

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 7be4056fb2fc80..6c5250d527ade8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
   if (!destructuredType)
     return {};
 
-  DenseMap<Attribute, Type> indexMap;
-  for (auto const &[index, type] : *destructuredType)
-    indexMap.insert({index, MemRefType::get({}, type)});
-
-  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+  return {
+      DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
 }
 
 DenseMap<Attribute, MemorySlot>

diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index 02d25f27f978a6..ca49b1298b0e91 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -215,3 +215,93 @@ llvm.func @no_nested_dynamic_indexing(%arg: i32) -> i32 {
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %3 : i32
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field
+llvm.func @store_first_field(%arg: i32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: llvm.store %{{.*}}, %[[ALLOCA]] : i32
+  llvm.store %arg, %1 : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field_different_type
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_first_field_different_type(%arg: f32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+  llvm.store %arg, %1 : f32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_sub_field
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_sub_field(%arg: f32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+  llvm.store %arg, %1 : f32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field
+llvm.func @load_first_field() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
+  %2 = llvm.load %1 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field_different_type
+llvm.func @load_first_field_different_type() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[RES]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_sub_field
+llvm.func @load_sub_field() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x !llvm.struct<(i64, i32)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %res = llvm.load %1 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_store_type_mismatch
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_store_type_mismatch(%arg: vector<4xi32>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x vector<4xf32>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+  // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+  llvm.return
+}

diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 021151b929d8e2..a6176142f17463 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -26,63 +26,6 @@ llvm.func @same_address_keep_inbounds(%arg: i32) {
 
 // -----
 
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field
-llvm.func @struct_store_instead_of_first_field(%arg: i32) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32
-  llvm.store %arg, %1 : i32, !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-llvm.func @struct_store_instead_of_first_field_same_size(%arg: f32) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
-  // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32
-  llvm.store %arg, %1 : f32, !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field
-llvm.func @struct_load_instead_of_first_field() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK: %[[RES:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
-  %2 = llvm.load %1 : !llvm.ptr -> i32
-  // CHECK: llvm.return %[[RES]] : i32
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field_same_size
-llvm.func @struct_load_instead_of_first_field_same_size() -> f32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK: %[[LOADED:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
-  // CHECK: %[[RES:.*]] = llvm.bitcast %[[LOADED]] : i32 to f32
-  %2 = llvm.load %1 : !llvm.ptr -> f32
-  // CHECK: llvm.return %[[RES]] : f32
-  llvm.return %2 : f32
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @index_in_final_padding
 llvm.func @index_in_final_padding(%arg: i32) {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -135,22 +78,6 @@ llvm.func @index_not_in_padding_because_packed(%arg: i16) {
 
 // -----
 
-// CHECK-LABEL: llvm.func @index_to_struct
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-llvm.func @index_to_struct(%arg: i32) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
-  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"bar", (i32, i32)>
-  %7 = llvm.getelementptr %1[4] : (!llvm.ptr) -> !llvm.ptr, i8
-  // CHECK: llvm.store %[[ARG]], %[[GEP1]]
-  llvm.store %arg, %7 : i32, !llvm.ptr
-  llvm.return
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @no_crash_on_negative_gep_index
 llvm.func @no_crash_on_negative_gep_index() {
   %0 = llvm.mlir.constant(1.000000e+00 : f16) : f16
@@ -175,10 +102,9 @@ llvm.func @coalesced_store_ints(%arg: i64) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
@@ -225,11 +151,9 @@ llvm.func @coalesced_store_floats(%arg: i64) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (f32, f32)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
-  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
@@ -298,10 +222,9 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
@@ -328,9 +251,8 @@ llvm.func @vector_write_split(%arg: vector<4xi32>) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
   // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
-  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+  // CHECK: llvm.store %[[EXTRACT]], %[[ALLOCA]] : i32, !llvm.ptr
 
   // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
@@ -405,36 +327,6 @@ llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {
 
 // -----
 
-// CHECK-LABEL: llvm.func @type_consistent_vector_store
-// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
-llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)>
-  // CHECK: llvm.store %[[ARG]], %[[GEP]]
-  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type
-// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
-llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
-  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
-  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
-  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
-  llvm.return
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @bitcast_insertion
 // CHECK-SAME: %[[ARG:.*]]: i32
 llvm.func @bitcast_insertion(%arg: i32) {
@@ -478,10 +370,9 @@ llvm.func @coalesced_store_ints_subaggregate(%arg: i64) {
   %3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
 
   // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[TOP_GEP]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i32, i32)>
@@ -520,10 +411,9 @@ llvm.func @overlapping_int_aggregate_store(%arg: i64) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
 
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
   // CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
@@ -531,8 +421,7 @@ llvm.func @overlapping_int_aggregate_store(%arg: i64) {
 
   // Normal integer splitting of [[TRUNC]] follows:
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
-  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: llvm.store %{{.*}}, %[[TOP_GEP]]
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
   // CHECK: llvm.store %{{.*}}, %[[GEP]]
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
@@ -557,14 +446,12 @@ llvm.func @overlapping_vector_aggregate_store(%arg: vector<4 x i16>) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
   // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32]
-  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]]
+  // CHECK: llvm.store %[[EXTRACT]], %[[ALLOCA]]
 
   // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32]
   // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
-  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
-  // CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP0]]
 
   // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32]
   // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
@@ -593,10 +480,9 @@ llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
 
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
   // CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
@@ -604,8 +490,7 @@ llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
 
   // Normal integer splitting of [[TRUNC]] follows:
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
-  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: llvm.store %{{.*}}, %[[TOP_GEP]]
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
   // CHECK: llvm.store %{{.*}}, %[[GEP]]
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
@@ -651,10 +536,9 @@ llvm.func @coalesced_store_ints_array(%arg: i64) {
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x i32>
   %1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
 
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.array<2 x i32>


        

