[llvm-branch-commits] [mlir] [MLIR][LLVM][SROA] Support incorrectly typed memory accesses (PR #85813)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Mar 19 09:26:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-memref

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access.

This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary.

Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is nonetheless a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases.

---

Patch is 33.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85813.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+4-2) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h (-12) 
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.h (+2-2) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+98-13) 
- (modified) mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp (-101) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+2-5) 
- (modified) mlir/test/Dialect/LLVMIR/sroa.mlir (+91) 
- (modified) mlir/test/Dialect/LLVMIR/type-consistency.mlir (+12-128) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b523374f6c06b5..f8f9264b3889be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -323,7 +323,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
 }
 
 def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins LLVM_AnyPointer:$addr,
@@ -402,7 +403,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
 }
 
 def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins LLVM_LoadableType:$value,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index b32ac56d7079c6..cacb241bfd7a10 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -29,18 +29,6 @@ namespace LLVM {
 /// interpret pointee types as consistently as possible.
 std::unique_ptr<Pass> createTypeConsistencyPass();
 
-/// Transforms uses of pointers to a whole struct to uses of pointers to the
-/// first element of a struct. This is achieved by inserting a GEP to the first
-/// element when possible.
-template <class User>
-class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
-public:
-  using OpRewritePattern<User>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(User user,
-                                PatternRewriter &rewriter) const override;
-};
-
 /// Canonicalizes GEPs of which the base type and the pointer's type hint do not
 /// match. This is done by replacing the original GEP into a GEP with the type
 /// hint as a base type when an element of the hinted type aligns with the
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 56e5e96aecd13c..87db1aaf39dea2 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -26,8 +26,8 @@ struct MemorySlot {
 
 /// Memory slot attached with information about its destructuring procedure.
 struct DestructurableMemorySlot : public MemorySlot {
-  /// Maps an index within the memory slot to the type of the pointer that
-  /// will be generated to access the element directly.
+  /// Maps an index within the memory slot to the element type of the pointer
+  /// that will be generated to access the element directly.
   DenseMap<Attribute, Type> elementPtrs;
 };
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 00b4559658fd4d..f9662789025764 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,8 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
   if (!destructuredType)
     return {};
 
-  DenseMap<Attribute, Type> allocaTypeMap;
-  for (Attribute index : llvm::make_first_range(destructuredType.value()))
-    allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
-
-  return {
-      DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
+  return {DestructurableMemorySlot{{getResult(), getElemType()},
+                                   *destructuredType}};
 }
 
 DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
   return DeletionKind::Delete;
 }
 
+/// Checks if `slot` can be accessed through the provided access type.
+static bool isValidAccessType(const MemorySlot &slot, Type accessType,
+                              const DataLayout &dataLayout) {
+  return dataLayout.getTypeSize(accessType) <=
+         dataLayout.getTypeSize(slot.elemType);
+}
+
 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
     const DataLayout &dataLayout) {
-  return success(getAddr() != slot.ptr || getType() == slot.elemType);
+  return success(getAddr() != slot.ptr ||
+                 isValidAccessType(slot, getType(), dataLayout));
 }
 
 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
     const DataLayout &dataLayout) {
   return success(getAddr() != slot.ptr ||
-                 getValue().getType() == slot.elemType);
+                 isValidAccessType(slot, getValue().getType(), dataLayout));
+}
+
+/// Returns the subslot's type at the requested index.
+static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
+                           Attribute index) {
+  auto subelementIndexMap =
+      slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+  if (!subelementIndexMap)
+    return {};
+  assert(!subelementIndexMap->empty());
+
+  // Note: Returns a null-type when no entry was found.
+  return subelementIndexMap->lookup(index);
+}
+
+bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+                             SmallPtrSetImpl<Attribute> &usedIndices,
+                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                             const DataLayout &dataLayout) {
+  if (getVolatile_())
+    return false;
+
+  // A load always accesses the first element of the destructured slot.
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  Type subslotType = getTypeAtIndex(slot, index);
+  if (!subslotType)
+    return false;
+
+  // The access can only be replaced when the subslot is read within its bounds.
+  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
+    return false;
+
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
+                                  DenseMap<Attribute, MemorySlot> &subslots,
+                                  RewriterBase &rewriter,
+                                  const DataLayout &dataLayout) {
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  auto it = subslots.find(index);
+  assert(it != subslots.end());
+
+  rewriter.modifyOpInPlace(
+      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  return DeletionKind::Keep;
+}
+
+bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+                              SmallPtrSetImpl<Attribute> &usedIndices,
+                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                              const DataLayout &dataLayout) {
+  if (getVolatile_())
+    return false;
+
+  // A load always accesses the first element of the destructured slot.
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  Type subslotType = getTypeAtIndex(slot, index);
+  if (!subslotType)
+    return false;
+
+  // The access can only be replaced when the subslot is read within its bounds.
+  if (dataLayout.getTypeSize(getValue().getType()) >
+      dataLayout.getTypeSize(subslotType))
+    return false;
+
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
+                                   DenseMap<Attribute, MemorySlot> &subslots,
+                                   RewriterBase &rewriter,
+                                   const DataLayout &dataLayout) {
+  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+  auto it = subslots.find(index);
+  assert(it != subslots.end());
+
+  rewriter.modifyOpInPlace(
+      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  return DeletionKind::Keep;
 }
 
 //===----------------------------------------------------------------------===//
@@ -384,16 +468,17 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
   // dynamic indices can never be properly rewired.
   if (!getDynamicIndices().empty())
     return false;
+  //// TODO: This is not necessary, I think.
+  // if (slot.elemType != getElemType())
+  //   return false;
   Type reachedType = getResultPtrElementType();
   if (!reachedType || getIndices().size() < 2)
     return false;
   auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
   if (!firstLevelIndex)
     return false;
-  assert(slot.elementPtrs.contains(firstLevelIndex));
-  if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
-    return false;
   mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+  assert(slot.elementPtrs.contains(firstLevelIndex));
   usedIndices.insert(firstLevelIndex);
   return true;
 }
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b25c831bc7172a..3d700fe94e3b9c 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -49,104 +49,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
                         layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
 }
 
-//===----------------------------------------------------------------------===//
-// AddFieldGetterToStructDirectUse
-//===----------------------------------------------------------------------===//
-
-/// Gets the type of the first subelement of `type` if `type` is destructurable,
-/// nullptr otherwise.
-static Type getFirstSubelementType(Type type) {
-  auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
-  if (!destructurable)
-    return nullptr;
-
-  Type subelementType = destructurable.getTypeAtIndex(
-      IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
-  if (subelementType)
-    return subelementType;
-
-  return nullptr;
-}
-
-/// Extracts a pointer to the first field of an `elemType` from the address
-/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
-/// instead.
-template <class MemOp>
-static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
-                                   Type elemType) {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  rewriter.setInsertionPointAfterValue(op.getAddr());
-  SmallVector<GEPArg> firstTypeIndices{0, 0};
-
-  Value properPtr = rewriter.create<GEPOp>(
-      op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
-      op.getAddr(), firstTypeIndices);
-
-  rewriter.modifyOpInPlace(op,
-                           [&]() { op.getAddrMutable().assign(properPtr); });
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
-    LoadOp load, PatternRewriter &rewriter) const {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  Type inconsistentElementType =
-      isElementTypeInconsistent(load.getAddr(), load.getType());
-  if (!inconsistentElementType)
-    return failure();
-  Type firstType = getFirstSubelementType(inconsistentElementType);
-  if (!firstType)
-    return failure();
-  DataLayout layout = DataLayout::closest(load);
-  if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
-    return failure();
-
-  insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
-
-  // If the load does not use the first type but a type that can be casted from
-  // it, add a bitcast and change the load type.
-  if (firstType != load.getResult().getType()) {
-    rewriter.setInsertionPointAfterValue(load.getResult());
-    BitcastOp bitcast = rewriter.create<BitcastOp>(
-        load->getLoc(), load.getResult().getType(), load.getResult());
-    rewriter.modifyOpInPlace(load,
-                             [&]() { load.getResult().setType(firstType); });
-    rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
-                                  bitcast);
-  }
-
-  return success();
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
-    StoreOp store, PatternRewriter &rewriter) const {
-  PatternRewriter::InsertionGuard guard(rewriter);
-
-  Type inconsistentElementType =
-      isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
-  if (!inconsistentElementType)
-    return failure();
-  Type firstType = getFirstSubelementType(inconsistentElementType);
-  if (!firstType)
-    return failure();
-
-  DataLayout layout = DataLayout::closest(store);
-  // Check that the first field has the right type or can at least be bitcast
-  // to the right type.
-  if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
-    return failure();
-
-  insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
-
-  rewriter.modifyOpInPlace(
-      store, [&]() { store.getValueMutable().assign(store.getValue()); });
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // CanonicalizeAlignedGep
 //===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
     : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
   void runOnOperation() override {
     RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
-    rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
-        &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
     rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
     rewritePatterns.add<BitcastStores>(&getContext());
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 7be4056fb2fc80..6c5250d527ade8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
   if (!destructuredType)
     return {};
 
-  DenseMap<Attribute, Type> indexMap;
-  for (auto const &[index, type] : *destructuredType)
-    indexMap.insert({index, MemRefType::get({}, type)});
-
-  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+  return {
+      DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
 }
 
 DenseMap<Attribute, MemorySlot>
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index 02d25f27f978a6..73666afaf66b27 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -215,3 +215,94 @@ llvm.func @no_nested_dynamic_indexing(%arg: i32) -> i32 {
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %3 : i32
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field
+llvm.func @store_first_field(%arg: i32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %{{.*}}, %[[ALLOCA]] : i32
+  llvm.store %arg, %1 : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field_different_type
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_first_field_different_type(%arg: f32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+  llvm.store %arg, %1 : f32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_sub_field
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_sub_field(%arg: f32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+  llvm.store %arg, %1 : f32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field
+llvm.func @load_first_field() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
+  %2 = llvm.load %1 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field_different_type
+llvm.func @load_first_field_different_type() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[RES]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_sub_field
+llvm.func @load_sub_field() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 : (i32) -> !llvm.ptr
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.struct<(i64, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %res = llvm.load %1 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_store_type_mismatch
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_store_type_mismatch(%arg: vector<4xi32>) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x vector<4xf32>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+  llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 021151b929d8e2..a6176142f17463 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -26,63 +26,6 @@ llvm.func @same_address_keep_inbounds(%arg: i32) {
 
 // -----
 
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field
-llvm.func @struct_store_instead_of_first_field(%arg: i32) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32
-  llvm.store %arg, %1 : i32, !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-llvm.func @struct_store_instead_of_first_field_same_size(%arg: f32) {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
-  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
-  // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
-  // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
-  // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32
-  llvm.store %arg, %1 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/85813


More information about the llvm-branch-commits mailing list