[Mlir-commits] [mlir] [MLIR][LLVM][SROA] Support incorrectly typed memory accesses (PR #85813)
Christian Ulmann
llvmlistbot at llvm.org
Wed Mar 20 06:48:40 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/85813
>From 30ab58297047bfe7737946fa2eff16f224dd8c4e Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 19 Mar 2024 15:33:34 +0000
Subject: [PATCH 1/2] [MLIR][LLVM][SROA] Support incorrectly typed memory
accesses
This commit relaxes the assumption of type consistency for LLVM dialect
load and store operations in SROA. Instead, SROA now checks that loads
and stores stay within the bounds of the sub-slot they access.
This commit additionally removes the corresponding patterns
(`AddFieldGetterToStructDirectUse` for loads and stores) from the type
consistency pass, as they are no longer necessary.
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 6 +-
.../LLVMIR/Transforms/TypeConsistency.h | 12 --
.../mlir/Interfaces/MemorySlotInterfaces.h | 4 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 111 ++++++++++++--
.../LLVMIR/Transforms/TypeConsistency.cpp | 101 -------------
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 7 +-
mlir/test/Dialect/LLVMIR/sroa.mlir | 91 ++++++++++++
.../test/Dialect/LLVMIR/type-consistency.mlir | 140 ++----------------
8 files changed, 209 insertions(+), 263 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b523374f6c06b5..f8f9264b3889be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -323,7 +323,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
}
def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_AnyPointer:$addr,
@@ -402,7 +403,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
}
def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_LoadableType:$value,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index b32ac56d7079c6..cacb241bfd7a10 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -29,18 +29,6 @@ namespace LLVM {
/// interpret pointee types as consistently as possible.
std::unique_ptr<Pass> createTypeConsistencyPass();
-/// Transforms uses of pointers to a whole struct to uses of pointers to the
-/// first element of a struct. This is achieved by inserting a GEP to the first
-/// element when possible.
-template <class User>
-class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
-public:
- using OpRewritePattern<User>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(User user,
- PatternRewriter &rewriter) const override;
-};
-
/// Canonicalizes GEPs of which the base type and the pointer's type hint do not
/// match. This is done by replacing the original GEP into a GEP with the type
/// hint as a base type when an element of the hinted type aligns with the
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 56e5e96aecd13c..87db1aaf39dea2 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -26,8 +26,8 @@ struct MemorySlot {
/// Memory slot attached with information about its destructuring procedure.
struct DestructurableMemorySlot : public MemorySlot {
- /// Maps an index within the memory slot to the type of the pointer that
- /// will be generated to access the element directly.
+ /// Maps an index within the memory slot to the element type of the pointer
+ /// that will be generated to access the element directly.
DenseMap<Attribute, Type> elementPtrs;
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 00b4559658fd4d..f9662789025764 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,8 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> allocaTypeMap;
- for (Attribute index : llvm::make_first_range(destructuredType.value()))
- allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
-
- return {
- DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
+ return {DestructurableMemorySlot{{getResult(), getElemType()},
+ *destructuredType}};
}
DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
return DeletionKind::Delete;
}
+/// Checks that an access of `accessType` is no larger than `slot`'s elemType.
+static bool isValidAccessType(const MemorySlot &slot, Type accessType,
+ const DataLayout &dataLayout) {
+ return dataLayout.getTypeSize(accessType) <=
+ dataLayout.getTypeSize(slot.elemType);
+}
+
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
- return success(getAddr() != slot.ptr || getType() == slot.elemType);
+ return success(getAddr() != slot.ptr ||
+ isValidAccessType(slot, getType(), dataLayout));
}
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
- getValue().getType() == slot.elemType);
+ isValidAccessType(slot, getValue().getType(), dataLayout));
+}
+
+/// Returns the subslot's type at the requested index.
+static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
+ Attribute index) {
+ auto subelementIndexMap =
+ slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+ if (!subelementIndexMap)
+ return {};
+ assert(!subelementIndexMap->empty());
+
+ // Note: Returns a null-type when no entry was found.
+ return subelementIndexMap->lookup(index);
+}
+
+bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+ // A load always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+ // The access can only be replaced when the subslot is read within its bounds.
+ if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
+}
+
+bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+ // A store always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+ // The store can only be rewired when it writes within the subslot's bounds.
+ if (dataLayout.getTypeSize(getValue().getType()) >
+ dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
}
//===----------------------------------------------------------------------===//
@@ -384,16 +468,17 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
// dynamic indices can never be properly rewired.
if (!getDynamicIndices().empty())
return false;
+ //// TODO: This is not necessary, I think.
+ // if (slot.elemType != getElemType())
+ // return false;
Type reachedType = getResultPtrElementType();
if (!reachedType || getIndices().size() < 2)
return false;
auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
if (!firstLevelIndex)
return false;
- assert(slot.elementPtrs.contains(firstLevelIndex));
- if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
- return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ assert(slot.elementPtrs.contains(firstLevelIndex));
usedIndices.insert(firstLevelIndex);
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b25c831bc7172a..3d700fe94e3b9c 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -49,104 +49,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
}
-//===----------------------------------------------------------------------===//
-// AddFieldGetterToStructDirectUse
-//===----------------------------------------------------------------------===//
-
-/// Gets the type of the first subelement of `type` if `type` is destructurable,
-/// nullptr otherwise.
-static Type getFirstSubelementType(Type type) {
- auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
- if (!destructurable)
- return nullptr;
-
- Type subelementType = destructurable.getTypeAtIndex(
- IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
- if (subelementType)
- return subelementType;
-
- return nullptr;
-}
-
-/// Extracts a pointer to the first field of an `elemType` from the address
-/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
-/// instead.
-template <class MemOp>
-static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
- Type elemType) {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- rewriter.setInsertionPointAfterValue(op.getAddr());
- SmallVector<GEPArg> firstTypeIndices{0, 0};
-
- Value properPtr = rewriter.create<GEPOp>(
- op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
- op.getAddr(), firstTypeIndices);
-
- rewriter.modifyOpInPlace(op,
- [&]() { op.getAddrMutable().assign(properPtr); });
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
- LoadOp load, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(load.getAddr(), load.getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
- DataLayout layout = DataLayout::closest(load);
- if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
- return failure();
-
- insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
-
- // If the load does not use the first type but a type that can be casted from
- // it, add a bitcast and change the load type.
- if (firstType != load.getResult().getType()) {
- rewriter.setInsertionPointAfterValue(load.getResult());
- BitcastOp bitcast = rewriter.create<BitcastOp>(
- load->getLoc(), load.getResult().getType(), load.getResult());
- rewriter.modifyOpInPlace(load,
- [&]() { load.getResult().setType(firstType); });
- rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
- bitcast);
- }
-
- return success();
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
- StoreOp store, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
-
- DataLayout layout = DataLayout::closest(store);
- // Check that the first field has the right type or can at least be bitcast
- // to the right type.
- if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
- return failure();
-
- insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
-
- rewriter.modifyOpInPlace(
- store, [&]() { store.getValueMutable().assign(store.getValue()); });
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
: public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
void runOnOperation() override {
RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
- &getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
rewritePatterns.add<BitcastStores>(&getContext());
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 7be4056fb2fc80..6c5250d527ade8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> indexMap;
- for (auto const &[index, type] : *destructuredType)
- indexMap.insert({index, MemRefType::get({}, type)});
-
- return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+ return {
+ DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}
DenseMap<Attribute, MemorySlot>
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index 02d25f27f978a6..73666afaf66b27 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -215,3 +215,94 @@ llvm.func @no_nested_dynamic_indexing(%arg: i32) -> i32 {
// CHECK: llvm.return %[[RES]] : i32
llvm.return %3 : i32
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field
+llvm.func @store_first_field(%arg: i32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %{{.*}}, %[[ALLOCA]] : i32
+ llvm.store %arg, %1 : i32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field_different_type
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_first_field_different_type(%arg: f32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ llvm.store %arg, %1 : f32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_sub_field
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_sub_field(%arg: f32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ llvm.store %arg, %1 : f32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field
+llvm.func @load_first_field() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
+ %2 = llvm.load %1 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field_different_type
+llvm.func @load_first_field_different_type() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
+ %2 = llvm.load %1 : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[RES]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_sub_field
+llvm.func @load_sub_field() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 : (i32) -> !llvm.ptr
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.struct<(i64, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %res = llvm.load %1 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_store_type_mismatch
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_store_type_mismatch(%arg: vector<4xi32>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x vector<4xf32>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 021151b929d8e2..a6176142f17463 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -26,63 +26,6 @@ llvm.func @same_address_keep_inbounds(%arg: i32) {
// -----
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field
-llvm.func @struct_store_instead_of_first_field(%arg: i32) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32
- llvm.store %arg, %1 : i32, !llvm.ptr
- llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-llvm.func @struct_store_instead_of_first_field_same_size(%arg: f32) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
- // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32
- llvm.store %arg, %1 : f32, !llvm.ptr
- llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field
-llvm.func @struct_load_instead_of_first_field() -> i32 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK: %[[RES:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
- %2 = llvm.load %1 : !llvm.ptr -> i32
- // CHECK: llvm.return %[[RES]] : i32
- llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field_same_size
-llvm.func @struct_load_instead_of_first_field_same_size() -> f32 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK: %[[LOADED:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
- // CHECK: %[[RES:.*]] = llvm.bitcast %[[LOADED]] : i32 to f32
- %2 = llvm.load %1 : !llvm.ptr -> f32
- // CHECK: llvm.return %[[RES]] : f32
- llvm.return %2 : f32
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @index_in_final_padding
llvm.func @index_in_final_padding(%arg: i32) {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -135,22 +78,6 @@ llvm.func @index_not_in_padding_because_packed(%arg: i16) {
// -----
-// CHECK-LABEL: llvm.func @index_to_struct
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-llvm.func @index_to_struct(%arg: i32) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
- // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"bar", (i32, i32)>
- %7 = llvm.getelementptr %1[4] : (!llvm.ptr) -> !llvm.ptr, i8
- // CHECK: llvm.store %[[ARG]], %[[GEP1]]
- llvm.store %arg, %7 : i32, !llvm.ptr
- llvm.return
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @no_crash_on_negative_gep_index
llvm.func @no_crash_on_negative_gep_index() {
%0 = llvm.mlir.constant(1.000000e+00 : f16) : f16
@@ -175,10 +102,9 @@ llvm.func @coalesced_store_ints(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
@@ -225,11 +151,9 @@ llvm.func @coalesced_store_floats(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (f32, f32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
- // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
- // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
@@ -298,10 +222,9 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
@@ -328,9 +251,8 @@ llvm.func @vector_write_split(%arg: vector<4xi32>) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32>
- // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr
+ // CHECK: llvm.store %[[EXTRACT]], %[[ALLOCA]] : i32, !llvm.ptr
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32>
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)>
@@ -405,36 +327,6 @@ llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {
// -----
-// CHECK-LABEL: llvm.func @type_consistent_vector_store
-// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
-llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)>
- // CHECK: llvm.store %[[ARG]], %[[GEP]]
- llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
- llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type
-// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
-llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
- // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
- // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
- llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
- // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
- llvm.return
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @bitcast_insertion
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @bitcast_insertion(%arg: i32) {
@@ -478,10 +370,9 @@ llvm.func @coalesced_store_ints_subaggregate(%arg: i64) {
%3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
// CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[TOP_GEP]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
@@ -520,10 +411,9 @@ llvm.func @overlapping_int_aggregate_store(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
// CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
@@ -531,8 +421,7 @@ llvm.func @overlapping_int_aggregate_store(%arg: i64) {
// Normal integer splitting of [[TRUNC]] follows:
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
- // CHECK: llvm.store %{{.*}}, %[[GEP]]
+ // CHECK: llvm.store %{{.*}}, %[[TOP_GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
@@ -557,14 +446,12 @@ llvm.func @overlapping_vector_aggregate_store(%arg: vector<4 x i16>) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32]
- // CHECK: llvm.store %[[EXTRACT]], %[[GEP]]
+ // CHECK: llvm.store %[[EXTRACT]], %[[ALLOCA]]
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32]
// CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
- // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
- // CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]
+ // CHECK: llvm.store %[[EXTRACT]], %[[GEP0]]
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32]
// CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
@@ -593,10 +480,9 @@ llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
// CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
@@ -604,8 +490,7 @@ llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
// Normal integer splitting of [[TRUNC]] follows:
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
- // CHECK: llvm.store %{{.*}}, %[[GEP]]
+ // CHECK: llvm.store %{{.*}}, %[[TOP_GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
@@ -651,10 +536,9 @@ llvm.func @coalesced_store_ints_array(%arg: i64) {
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x i32>
%1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
- // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+ // CHECK: llvm.store %[[TRUNC]], %[[ALLOCA]]
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
From 1d720cd60330688111a22ac3dea253c25151e189 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 20 Mar 2024 06:57:20 +0000
Subject: [PATCH 2/2] address review comments
---
.../mlir/Interfaces/MemorySlotInterfaces.h | 3 +--
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 5 +----
mlir/test/Dialect/LLVMIR/sroa.mlir | 15 +++++++--------
3 files changed, 9 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 87db1aaf39dea2..aaa261be6553f3 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -26,8 +26,7 @@ struct MemorySlot {
/// Memory slot attached with information about its destructuring procedure.
struct DestructurableMemorySlot : public MemorySlot {
- /// Maps an index within the memory slot to the element type of the pointer
- /// that will be generated to access the element directly.
+ /// Maps an index within the memory slot to the corresponding subelement type.
DenseMap<Attribute, Type> elementPtrs;
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f9662789025764..0ef1d105aca6cb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -251,7 +251,7 @@ bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
if (getVolatile_())
return false;
- // A load always accesses the first element of the destructured slot.
+ // A store always accesses the first element of the destructured slot.
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
Type subslotType = getTypeAtIndex(slot, index);
if (!subslotType)
@@ -468,9 +468,6 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
// dynamic indices can never be properly rewired.
if (!getDynamicIndices().empty())
return false;
- //// TODO: This is not necessary, I think.
- // if (slot.elemType != getElemType())
- // return false;
Type reachedType = getResultPtrElementType();
if (!reachedType || getIndices().size() < 2)
return false;
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index 73666afaf66b27..ca49b1298b0e91 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -223,7 +223,7 @@ llvm.func @store_first_field(%arg: i32) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: llvm.store %{{.*}}, %[[ALLOCA]] : i32
+ // CHECK-NEXT: llvm.store %{{.*}}, %[[ALLOCA]] : i32
llvm.store %arg, %1 : i32, !llvm.ptr
llvm.return
}
@@ -236,7 +236,7 @@ llvm.func @store_first_field_different_type(%arg: f32) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]] : f32
llvm.store %arg, %1 : f32, !llvm.ptr
llvm.return
}
@@ -249,7 +249,7 @@ llvm.func @store_sub_field(%arg: f32) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
- // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]] : f32
llvm.store %arg, %1 : f32, !llvm.ptr
llvm.return
}
@@ -261,7 +261,7 @@ llvm.func @load_first_field() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
+ // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
%2 = llvm.load %1 : !llvm.ptr -> i32
// CHECK: llvm.return %[[RES]] : i32
llvm.return %2 : i32
@@ -274,7 +274,7 @@ llvm.func @load_first_field_different_type() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
+ // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
%2 = llvm.load %1 : !llvm.ptr -> f32
// CHECK: llvm.return %[[RES]] : f32
llvm.return %2 : f32
@@ -286,9 +286,8 @@ llvm.func @load_first_field_different_type() -> f32 {
llvm.func @load_sub_field() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 : (i32) -> !llvm.ptr
- // CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x !llvm.struct<(i64, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ // CHECK-NEXT: %[[RES:.*]] = llvm.load %[[ALLOCA]]
%res = llvm.load %1 : !llvm.ptr -> i32
// CHECK: llvm.return %[[RES]] : i32
llvm.return %res : i32
@@ -302,7 +301,7 @@ llvm.func @vector_store_type_mismatch(%arg: vector<4xi32>) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x vector<4xf32>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
- // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+ // CHECK-NEXT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
llvm.return
}
More information about the Mlir-commits
mailing list