[Mlir-commits] [mlir] 4ed502e - [mlir] Add a generic SROA implementation.
Tobias Gysi
llvmlistbot at llvm.org
Mon May 22 02:02:31 PDT 2023
Author: Théo Degioanni
Date: 2023-05-22T09:01:09Z
New Revision: 4ed502ef4fcc5d9592a3ded662a9cac701fb4392
URL: https://github.com/llvm/llvm-project/commit/4ed502ef4fcc5d9592a3ded662a9cac701fb4392
DIFF: https://github.com/llvm/llvm-project/commit/4ed502ef4fcc5d9592a3ded662a9cac701fb4392.diff
LOG: [mlir] Add a generic SROA implementation.
This revision introduces a generic implementation of Scalar Replacement
Of Aggregates. In contrast to the implementation in LLVM, this focuses
on the core of SROA: destructuring aggregates. By implementing
interfaces on allocators and accessors, memory allocators can be
destructured into smaller allocators, through the MemorySlot
abstraction.
This pass only works on aggregates that are accessed in a "type-safe"
way, that is within the bounds and respecting the type of a given memory
slot. The destructuring pattern and functions only peel off the first
layer of aggregates and can safely be applied repeatedly. For
convenience, the transformation is also available as a pass that will
apply the pattern repeatedly.
Depends on D149958
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D150186
Added:
mlir/include/mlir/Transforms/SROA.h
mlir/lib/Transforms/SROA.cpp
mlir/test/Dialect/LLVMIR/sroa-statistics.mlir
mlir/test/Dialect/LLVMIR/sroa.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
mlir/include/mlir/Transforms/Mem2Reg.h
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Interfaces/MemorySlotInterfaces.cpp
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Mem2Reg.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e6aed104b6d3b..232e180ab8975 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -172,8 +172,9 @@ def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp<
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, "fneg", "FNeg">;
// Memory-related operations.
-def LLVM_AllocaOp : LLVM_Op<"alloca",
- [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]>,
+def LLVM_AllocaOp : LLVM_Op<"alloca",
+ [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]>,
LLVM_MemOpPatterns {
let arguments = (ins AnyInteger:$arraySize,
OptionalAttr<I64Attr>:$alignment,
@@ -232,7 +233,9 @@ def LLVM_AllocaOp : LLVM_Op<"alloca",
}
def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
- DeclareOpInterfaceMethods<PromotableOpInterface>]> {
+ DeclareOpInterfaceMethods<PromotableOpInterface>,
+ DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
DenseI32ArrayAttr:$rawConstantIndices,
@@ -316,7 +319,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
}
def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
- [DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+ [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins Arg<LLVM_PointerTo<LLVM_LoadableType>, "", [MemRead]>:$addr,
OptionalAttr<I64Attr>:$alignment,
UnitAttr:$volatile_,
@@ -388,7 +392,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
}
def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
- [DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+ [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_LoadableType:$value,
Arg<LLVM_PointerTo<LLVM_LoadableType>,"",[MemWrite]>:$addr,
OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 9ae0ba6365dfa..ba2f14f173aa0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -16,6 +16,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include <optional>
namespace llvm {
@@ -103,6 +104,7 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait,
+ DestructurableTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
@@ -198,6 +200,12 @@ class LLVMStructType
LogicalResult verifyEntries(DataLayoutEntryListRef entries,
Location loc) const;
+
+ /// Destructures the struct into its indexed field types.
+ std::optional<DenseMap<Attribute, Type>> getSubelementIndexMap();
+
+ /// Returns which type is stored at a given integer index within the struct.
+ Type getTypeAtIndex(Attribute index);
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index caf4b58b87f56..e26d9d8acc79e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
/// Base class for all LLVM dialect types.
class LLVMType<string typeName, string typeMnemonic, list<Trait> traits = []>
@@ -24,7 +25,8 @@ class LLVMType<string typeName, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def LLVMArrayType : LLVMType<"LLVMArray", "array", [
- DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>]> {
+ DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getTypeSize"]>,
+ DeclareTypeInterfaceMethods<DestructurableTypeInterface>]> {
let summary = "LLVM array type";
let description = [{
The `!llvm.array` type represents a fixed-size array of element types.
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 0b42dfec5facb..71a980056a4ea 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -19,6 +19,8 @@ add_mlir_interface(ViewLikeInterface)
set(LLVM_TARGET_DEFINITIONS MemorySlotInterfaces.td)
mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(MemorySlotTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(MemorySlotTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index c0f8b2f8ee9ce..56e5e96aecd13 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -24,6 +24,13 @@ struct MemorySlot {
Type elemType;
};
+/// Memory slot attached with information about its destructuring procedure.
+struct DestructurableMemorySlot : public MemorySlot {
+ /// Maps an index within the memory slot to the type of the pointer that
+ /// will be generated to access the element directly.
+ DenseMap<Attribute, Type> elementPtrs;
+};
+
/// Returned by operation promotion logic requesting the deletion of an
/// operation.
enum class DeletionKind {
@@ -36,5 +43,6 @@ enum class DeletionKind {
} // namespace mlir
#include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc"
+#include "mlir/Interfaces/MemorySlotTypeInterfaces.h.inc"
#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 73061f79521af..5ef646b24ff9f 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -215,4 +215,158 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
];
}
+def DestructurableAllocationOpInterface
+ : OpInterface<"DestructurableAllocationOpInterface"> {
+ let description = [{
+ Describes operations allocating memory slots of aggregates that can be
+ destructured into multiple smaller allocations.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the list of slots for which destructuring should be attempted,
+ specifying in which way the slot should be destructured into subslots.
+ The subslots are indexed by attributes. This computes the type of the
+ pointer for each subslot to be generated. The type of the memory slot
+ must implement `DestructurableTypeInterface`.
+
+ No IR mutation is allowed in this method.
+ }],
+ "::llvm::SmallVector<::mlir::DestructurableMemorySlot>",
+ "getDestructurableSlots",
+ (ins)
+ >,
+ InterfaceMethod<[{
+ Destructures this slot into multiple subslots. The newly generated slots
+ may belong to a different allocator. The original slot must still exist
+ at the end of this call. Only generates subslots for the indices found in
+ `usedIndices` since all other subslots are unused.
+
+ The rewriter is located at the beginning of the block where the slot
+ pointer is defined. All IR mutations must happen through the rewriter.
+ }],
+ "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>",
+ "destructure",
+ (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+ "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
+ "::mlir::RewriterBase &":$rewriter)
+ >,
+ InterfaceMethod<[{
+ Hook triggered once the destructuring of a slot is complete, meaning the
+ original slot is no longer being referred to and could be deleted.
+ This will only be called for slots declared by this operation.
+
+ All IR mutations must happen through the rewriter.
+ }],
+ "void", "handleDestructuringComplete",
+ (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+ "::mlir::RewriterBase &":$rewriter)
+ >,
+ ];
+}
+
+def SafeMemorySlotAccessOpInterface
+ : OpInterface<"SafeMemorySlotAccessOpInterface"> {
+ let description = [{
+ Describes operations using memory slots in a type-safe manner.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns whether all accesses in this operation to the provided slot are
+ done in a type-safe manner. To be type-safe, the access must only load
+ the value as the type of the slot, and without assuming any
+ context around the slot. For example, a type-safe load must not load
+ outside the bounds of the slot.
+
+ If the type-safety of the accesses depends on the type-safety of the
+ accesses to further memory slots, the result of this method will be
+ conditional on the type-safety of the accesses to the slots added by
+ this method to `mustBeSafelyUsed`.
+
+ No IR mutation is allowed in this method.
+ }],
+ "::mlir::LogicalResult",
+ "ensureOnlySafeAccesses",
+ (ins "const ::mlir::MemorySlot &":$slot,
+ "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+ >
+ ];
+}
+
+def DestructurableAccessorOpInterface
+ : OpInterface<"DestructurableAccessorOpInterface"> {
+ let description = [{
+ Describes operations that can access a sub-element of a destructurable slot.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ For a given destructurable memory slot, returns whether this operation can
+ rewire its uses of the slot to use the slots generated after
+ destructuring. This may involve creating new operations, and usually
+ amounts to checking if the pointer types match.
+
+ This method must also register the indices it will access within the
+ `usedIndices` set. If the accessor generates new slots mapping to
+ subelements, they must be registered in `mustBeSafelyUsed` to ensure
+ they are used in a locally type-safe manner.
+
+ No IR mutation is allowed in this method.
+ }],
+ "bool",
+ "canRewire",
+ (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+ "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
+ "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+ >,
+ InterfaceMethod<[{
+ Rewires the use of a slot to the generated subslots, without deleting
+ any operation. Returns whether the accessor should be deleted.
+
+ All IR mutations must happen through the rewriter. Deletion of
+ operations is not allowed, only the accessor can be scheduled for
+ deletion by returning the appropriate value.
+ }],
+ "::mlir::DeletionKind",
+ "rewire",
+ (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+ "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots,
+ "::mlir::RewriterBase &":$rewriter)
+ >
+ ];
+}
+
+def DestructurableTypeInterface
+ : TypeInterface<"DestructurableTypeInterface"> {
+ let description = [{
+ Describes a type that can be broken down into indexable sub-element types.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Destructures the type into a map from index attributes to the types of
+ its subelements. Returns nothing if the type cannot be destructured.
+ }],
+ "::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>",
+ "getSubelementIndexMap",
+ (ins)
+ >,
+ InterfaceMethod<[{
+ Indicates which type is held at the provided index, returning a null
+ Type if no type could be computed. While this can return information
+ even when the type cannot be completely destructured, it must be coherent
+ with the types returned by `getSubelementIndexMap` when they exist.
+ }],
+ "::mlir::Type",
+ "getTypeAtIndex",
+ (ins "::mlir::Attribute":$index)
+ >
+ ];
+}
+
#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 46b2a1f56d21e..89244feb21754 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -17,8 +17,11 @@
namespace mlir {
+/// Statistics collected while applying mem2reg.
struct Mem2RegStatistics {
+ /// Total amount of memory slots promoted.
llvm::Statistic *promotedAmount = nullptr;
+ /// Total amount of new block arguments inserted in blocks.
llvm::Statistic *newBlockArgumentAmount = nullptr;
};
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index f5f76076c8e07..9110b64d55a63 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -36,6 +36,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_MEM2REG
#define GEN_PASS_DECL_PRINTIRPASS
#define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_SROA
#define GEN_PASS_DECL_STRIPDEBUGINFO
#define GEN_PASS_DECL_SCCP
#define GEN_PASS_DECL_SYMBOLDCE
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 62b8dd075f21f..125ba2ffbac72 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -199,10 +199,10 @@ def Mem2Reg : Pass<"mem2reg"> {
let statistics = [
Statistic<"promotedAmount",
"promoted slots",
- "Number of promoted memory slot">,
+ "Total amount of memory slot promoted">,
Statistic<"newBlockArgumentAmount",
"new block args",
- "Total number of block arguments added">,
+ "Total amount of new block argument inserted in blocks">,
];
}
@@ -229,6 +229,42 @@ def SCCP : Pass<"sccp"> {
let constructor = "mlir::createSCCPPass()";
}
+def SROA : Pass<"sroa"> {
+ let summary = "Scalar Replacement of Aggregates";
+ let description = [{
+ Scalar Replacement of Aggregates. Replaces allocations of aggregates with
+ independent allocations of their elements.
+
+ Allocators must implement `DestructurableAllocationOpInterface` to provide
+ the list of memory slots for which destructuring should be attempted.
+
+ This pass will only be applied if all accessors of the aggregate implement
+ the `DestructurableAccessorOpInterface`. If the accessors provide a view
+ into the struct, users of the view must ensure it is used in a type-safe
+ manner and within bounds by implementing `SafeMemorySlotAccessOpInterface`.
+ }];
+
+ let statistics = [
+ Statistic<
+ "destructuredAmount",
+ "destructured slots",
+ "Total amount of memory slots destructured"
+ >,
+ Statistic<
+ "slotsWithMemoryBenefit",
+ "slots with memory benefit",
+ "Total amount of memory slots in which the destructured size was smaller "
+ "than the total size after eliminating unused fields"
+ >,
+ Statistic<
+ "maxSubelementAmount",
+ "max subelement number",
+ "Maximal number of sub-elements a successfully destructured slot "
+ "initially had"
+ >,
+ ];
+}
+
def StripDebugInfo : Pass<"strip-debuginfo"> {
let summary = "Strip debug info from all operations";
let description = [{
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
new file mode 100644
index 0000000000000..3a44dc032966b
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -0,0 +1,57 @@
+//===-- SROA.h - Scalar Replacement Of Aggregates ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SROA_H
+#define MLIR_TRANSFORMS_SROA_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Statistic.h"
+
+namespace mlir {
+
+/// Statistics collected while applying SROA.
+struct SROAStatistics {
+ /// Total amount of memory slots destructured.
+ llvm::Statistic *destructuredAmount = nullptr;
+ /// Total amount of memory slots in which the destructured size was smaller
+ /// than the total size after eliminating unused fields.
+ llvm::Statistic *slotsWithMemoryBenefit = nullptr;
+ /// Maximal number of sub-elements a successfully destructured slot initially
+ /// had.
+ llvm::Statistic *maxSubelementAmount = nullptr;
+};
+
+/// Pattern applying SROA to the regions of the operations on which it
+/// matches.
+class SROAPattern
+ : public OpInterfaceRewritePattern<DestructurableAllocationOpInterface> {
+public:
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+ SROAPattern(MLIRContext *context, SROAStatistics statistics = {},
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
+
+ LogicalResult matchAndRewrite(DestructurableAllocationOpInterface allocator,
+ PatternRewriter &rewriter) const override;
+
+private:
+ SROAStatistics statistics;
+};
+
+/// Attempts to destructure the slots of destructurable allocators. Returns
+/// failure if no slot was destructured.
+LogicalResult tryToDestructureMemorySlots(
+ ArrayRef<DestructurableAllocationOpInterface> allocators,
+ RewriterBase &rewriter, SROAStatistics statistics = {});
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SROA_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 51c49892d71c9..89693ec69cf6c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -12,6 +12,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
using namespace mlir;
@@ -56,6 +60,57 @@ void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
rewriter.eraseOp(*this);
}
+SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
+ if (!mlir::matchPattern(getArraySize(), m_One()))
+ return {};
+
+ Type elemType = getAllocaElementType(*this);
+ auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType);
+ if (!destructurable)
+ return {};
+
+ std::optional<DenseMap<Attribute, Type>> destructuredType =
+ destructurable.getSubelementIndexMap();
+ if (!destructuredType)
+ return {};
+
+ DenseMap<Attribute, Type> allocaTypeMap;
+ for (Attribute index : llvm::make_first_range(destructuredType.value()))
+ allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
+
+ return {DestructurableMemorySlot{{getResult(), elemType}, {allocaTypeMap}}};
+}
+
+DenseMap<Attribute, MemorySlot>
+LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices,
+ RewriterBase &rewriter) {
+ assert(slot.ptr == getResult());
+ Type elemType =
+ getElemType() ? *getElemType() : getResult().getType().getElementType();
+
+ rewriter.setInsertionPointAfter(*this);
+
+ auto destructurableType = cast<DestructurableTypeInterface>(elemType);
+ DenseMap<Attribute, MemorySlot> slotMap;
+ for (Attribute index : usedIndices) {
+ Type elemType = destructurableType.getTypeAtIndex(index);
+ assert(elemType && "used index must exist");
+ auto subAlloca = rewriter.create<LLVM::AllocaOp>(
+ getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
+ getArraySize());
+ slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
+ }
+
+ return slotMap;
+}
+
+void LLVM::AllocaOp::handleDestructuringComplete(
+ const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+ assert(slot.ptr == getResult());
+ rewriter.eraseOp(*this);
+}
+
//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//
@@ -115,12 +170,23 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
// `canUsesBeRemoved` checked this blocking use must be the stored slot
// pointer.
for (Operation *user : slot.ptr.getUsers())
- if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
+ if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
declareOp.getVarInfo());
return DeletionKind::Delete;
}
+LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ return success(getAddr() != slot.ptr || getType() == slot.elemType);
+}
+
+LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ return success(getAddr() != slot.ptr ||
+ getValue().getType() == slot.elemType);
+}
+
//===----------------------------------------------------------------------===//
// Interfaces for discardable OPs
//===----------------------------------------------------------------------===//
@@ -189,6 +255,10 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
return DeletionKind::Delete;
}
+//===----------------------------------------------------------------------===//
+// Interfaces for GEPOp
+//===----------------------------------------------------------------------===//
+
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
return llvm::all_of(gepOp.getIndices(), [](auto index) {
auto indexAttr = index.template dyn_cast<IntegerAttr>();
@@ -196,10 +266,6 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
});
}
-//===----------------------------------------------------------------------===//
-// Interfaces for GEPOp
-//===----------------------------------------------------------------------===//
-
bool LLVM::GEPOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses) {
@@ -213,3 +279,171 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
+
+/// Returns the type the resulting pointer of the GEP points to. If such a type
+/// is not clear, returns null type.
+static Type computeReachedGEPType(LLVM::GEPOp gep) {
+ if (gep.getIndices().empty())
+ return {};
+
+ // Ensures all indices are static and fetches them.
+ SmallVector<IntegerAttr> indices;
+ for (auto index : gep.getIndices()) {
+ IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
+ if (!indexInt)
+ return {};
+ indices.push_back(indexInt);
+ }
+
+ // Check the pointer indexing only targets the first element.
+ if (indices[0].getInt() != 0)
+ return {};
+
+ // Set the initial type currently being used for indexing. This will be
+ // updated as the indices get walked over.
+ std::optional<Type> maybeSelectedType = gep.getElemType();
+ if (!maybeSelectedType)
+ return {};
+ Type selectedType = *maybeSelectedType;
+
+ // Follow the indexed elements in the gep.
+ for (IntegerAttr index : llvm::drop_begin(indices)) {
+ // Ensure the structure of the type being indexed can be reasoned about.
+ // This includes rejecting any potential typed pointer.
+ auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
+ if (!destructurable)
+ return {};
+
+ // Follow the type at the index the gep is accessing, making it the new type
+ // used for indexing.
+ Type field = destructurable.getTypeAtIndex(index);
+ if (!field)
+ return {};
+ selectedType = field;
+ }
+
+ // When there are no more indices, the type currently being used for indexing
+ // is the type of the value pointed at by the returned indexed pointer.
+ return selectedType;
+}
+
+LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ if (getBase() != slot.ptr)
+ return success();
+ if (slot.elemType != getElemType())
+ return failure();
+ Type reachedType = computeReachedGEPType(*this);
+ if (!reachedType)
+ return failure();
+ mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ return success();
+}
+
+bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
+ if (!basePtrType)
+ return false;
+
+ // Typed pointers are not supported. This should be removed once typed
+ // pointers are removed from the LLVM dialect.
+ if (!basePtrType.isOpaque())
+ return false;
+
+ if (getBase() != slot.ptr || slot.elemType != getElemType())
+ return false;
+ Type reachedType = computeReachedGEPType(*this);
+ if (!reachedType || getIndices().size() < 2)
+ return false;
+ auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
+ assert(slot.elementPtrs.contains(firstLevelIndex));
+ if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
+ return false;
+ mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ usedIndices.insert(firstLevelIndex);
+ return true;
+}
+
+DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter) {
+ IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
+ const MemorySlot &newSlot = subslots.at(firstLevelIndex);
+
+ ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
+
+ // If the GEP would become trivial after this transformation, eliminate it.
+ // A GEP should only be eliminated if it has no indices (except the first
+ // pointer index), as simplifying GEPs with all-zero indices would eliminate
+ // structure information useful for further destructuring.
+ if (remainingIndices.empty()) {
+ rewriter.replaceAllUsesWith(getResult(), newSlot.ptr);
+ return DeletionKind::Delete;
+ }
+
+ rewriter.updateRootInPlace(*this, [&]() {
+ // Rewire the indices by popping off the second index.
+ // Start with a single zero, then add the indices beyond the second.
+ SmallVector<int32_t> newIndices(1);
+ newIndices.append(remainingIndices.begin(), remainingIndices.end());
+ setRawConstantIndices(newIndices);
+
+ // Rewire the pointed type.
+ setElemType(newSlot.elemType);
+
+ // Rewire the pointer.
+ getBaseMutable().assign(newSlot.ptr);
+ });
+
+ return DeletionKind::Keep;
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for destructurable types
+//===----------------------------------------------------------------------===//
+
+std::optional<DenseMap<Attribute, Type>>
+LLVM::LLVMStructType::getSubelementIndexMap() {
+ Type i32 = IntegerType::get(getContext(), 32);
+ DenseMap<Attribute, Type> destructured;
+ for (const auto &[index, elemType] : llvm::enumerate(getBody()))
+ destructured.insert({IntegerAttr::get(i32, index), elemType});
+ return destructured;
+}
+
+Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
+ auto indexAttr = index.dyn_cast<IntegerAttr>();
+ if (!indexAttr || !indexAttr.getType().isInteger(32))
+ return {};
+ int32_t indexInt = indexAttr.getInt();
+ ArrayRef<Type> body = getBody();
+ if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
+ return {};
+ return body[indexInt];
+}
+
+std::optional<DenseMap<Attribute, Type>>
+LLVM::LLVMArrayType::getSubelementIndexMap() const {
+ constexpr size_t maxArraySizeForDestructuring = 16;
+ if (getNumElements() > maxArraySizeForDestructuring)
+ return {};
+ int32_t numElements = getNumElements();
+
+ Type i32 = IntegerType::get(getContext(), 32);
+ DenseMap<Attribute, Type> destructured;
+ for (int32_t index = 0; index < numElements; ++index)
+ destructured.insert({IntegerAttr::get(i32, index), getElementType()});
+ return destructured;
+}
+
+Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
+ auto indexAttr = index.dyn_cast<IntegerAttr>();
+ if (!indexAttr || !indexAttr.getType().isInteger(32))
+ return {};
+ int32_t indexInt = indexAttr.getInt();
+ if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
+ return {};
+ return getElementType();
+}
diff --git a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
index 431febf31f89a..2c9e23250e9ee 100644
--- a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
@@ -9,3 +9,4 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc"
+#include "mlir/Interfaces/MemorySlotTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index b7e1cd927d6b9..72dd7ab94e909 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms
OpStats.cpp
PrintIR.cpp
SCCP.cpp
+ SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
SymbolPrivatize.cpp
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 3b303f9836cf5..5152d10fbf6b4 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -578,8 +578,6 @@ void MemorySlotPromoter::promoteSlot() {
LogicalResult mlir::tryToPromoteMemorySlots(
ArrayRef<PromotableAllocationOpInterface> allocators,
RewriterBase &rewriter, Mem2RegStatistics statistics) {
- DominanceInfo dominance;
-
bool promotedAny = false;
for (PromotableAllocationOpInterface allocator : allocators) {
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
new file mode 100644
index 0000000000000..3ceda51e1c894
--- /dev/null
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -0,0 +1,235 @@
+//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SROA.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SROA
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "sroa"
+
+using namespace mlir;
+
+namespace {
+
+/// Result of the destructurability analysis of a memory slot. It is filled in
+/// by `computeDestructuringInfo`, which only returns it when destructuring is
+/// possible, and is then consumed by `destructureSlot` to perform the actual
+/// destructuring.
+struct MemorySlotDestructuringInfo {
+ /// Set of the indices that are actually used when accessing the subelements.
+ /// Populated by the accessors' `canRewire` hook and handed to the allocator's
+ /// `destructure` hook.
+ SmallPtrSet<Attribute, 8> usedIndices;
+ /// For each operation, the operands that are blocking uses and must be
+ /// eliminated through `PromotableOpInterface::removeBlockingUses`.
+ DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+ /// Accessors of the memory slot for which `canRewire` succeeded and that
+ /// therefore need rewiring to the subslots.
+ SmallVector<DestructurableAccessorOpInterface> accessors;
+};
+
+} // namespace
+
+/// Analyzes `slot` to determine whether it can be destructured and, if so,
+/// gathers what is needed to do it: the indices that are accessed, the
+/// accessors to rewire and the blocking uses to remove. Returns std::nullopt
+/// when the slot pointer has no uses, or when an operation holding a blocking
+/// use is not promotable or refuses to have its blocking uses removed.
+static std::optional<MemorySlotDestructuringInfo>
+computeDestructuringInfo(DestructurableMemorySlot &slot) {
+  assert(isa<DestructurableTypeInterface>(slot.elemType));
+
+  if (slot.ptr.use_empty())
+    return std::nullopt;
+
+  MemorySlotDestructuringInfo result;
+  SmallVector<MemorySlot> pendingSubslots;
+
+  // Registers `use` as a blocking use of the operation that owns it.
+  auto markBlocking = [&](OpOperand &use) {
+    result.userToBlockingUses.getOrInsertDefault(use.getOwner()).insert(&use);
+  };
+
+  // Step 1: classify the direct users of the slot pointer. Accessors that can
+  // be rewired are kept aside; every other use is a candidate for removal
+  // through promotion.
+  for (OpOperand &use : slot.ptr.getUses()) {
+    auto accessor =
+        dyn_cast<DestructurableAccessorOpInterface>(use.getOwner());
+    if (accessor &&
+        accessor.canRewire(slot, result.usedIndices, pendingSubslots)) {
+      result.accessors.push_back(accessor);
+      continue;
+    }
+    markBlocking(use);
+  }
+
+  // Step 2: every subslot handed back by the accessors must itself only be
+  // used safely. Uses that cannot be proven safe become blocking uses.
+  SmallPtrSet<OpOperand *, 16> seenUses;
+  while (!pendingSubslots.empty()) {
+    MemorySlot subslot = pendingSubslots.pop_back_val();
+    for (OpOperand &use : subslot.ptr.getUses()) {
+      bool firstVisit = seenUses.insert(&use).second;
+      if (!firstVisit)
+        continue;
+
+      auto safeAccess =
+          dyn_cast<SafeMemorySlotAccessOpInterface>(use.getOwner());
+      bool provenSafe =
+          safeAccess && succeeded(safeAccess.ensureOnlySafeAccesses(
+                            subslot, pendingSubslots));
+      if (!provenSafe)
+        markBlocking(use);
+    }
+  }
+
+  // Step 3: walk the forward slice of the slot pointer and make sure every
+  // operation with blocking uses is able to drop them, propagating any new
+  // blocking uses this creates on the users of its results.
+  SetVector<Operation *> slice;
+  mlir::getForwardSlice(slot.ptr, &slice);
+  for (Operation *op : slice) {
+    auto entry = result.userToBlockingUses.find(op);
+    if (entry == result.userToBlockingUses.end())
+      continue;
+
+    // Only promotable operations know how to get rid of blocking uses.
+    auto promotable = dyn_cast<PromotableOpInterface>(op);
+    if (!promotable)
+      return std::nullopt;
+
+    SmallVector<OpOperand *> propagated;
+    if (!promotable.canUsesBeRemoved(entry->second, propagated))
+      return std::nullopt;
+
+    // `entry` must not be touched past this point: the insertions below may
+    // grow the map.
+    for (OpOperand *next : propagated) {
+      assert(llvm::is_contained(op->getResults(), next->get()));
+      result.userToBlockingUses.getOrInsertDefault(next->getOwner())
+          .insert(next);
+    }
+  }
+
+  return result;
+}
+
+/// Destructures `slot` according to `info`. The allocator first creates the
+/// subslots, then every accessor is rewired and every blocking use removed,
+/// and finally the operations reported as deletable are erased. Statistics
+/// whose pointer is non-null are updated along the way.
+static void destructureSlot(DestructurableMemorySlot &slot,
+                            DestructurableAllocationOpInterface allocator,
+                            RewriterBase &rewriter,
+                            MemorySlotDestructuringInfo &info,
+                            const SROAStatistics &statistics) {
+  // Restores the rewriter's insertion point when leaving the function.
+  RewriterBase::InsertionGuard guard(rewriter);
+
+  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+  DenseMap<Attribute, MemorySlot> subslots =
+      allocator.destructure(slot, info.usedIndices, rewriter);
+
+  const auto subelementCount = slot.elementPtrs.size();
+  if (statistics.slotsWithMemoryBenefit &&
+      subelementCount != info.usedIndices.size())
+    (*statistics.slotsWithMemoryBenefit)++;
+  if (statistics.maxSubelementAmount)
+    statistics.maxSubelementAmount->updateMax(subelementCount);
+
+  // Collect every operation that needs an update: first the holders of
+  // blocking uses, then the accessors.
+  SetVector<Operation *> pending;
+  for (auto &entry : info.userToBlockingUses)
+    pending.insert(entry.first);
+  for (DestructurableAccessorOpInterface accessor : info.accessors)
+    pending.insert(accessor);
+  SetVector<Operation *> ordered = mlir::topologicalSort(pending);
+
+  // Process the operations in reverse topological order. Erasure is deferred
+  // until all of them have been handled.
+  llvm::SmallVector<Operation *> doomed;
+  for (Operation *op : llvm::reverse(ordered)) {
+    rewriter.setInsertionPointAfter(op);
+
+    DeletionKind outcome;
+    if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(op)) {
+      outcome = accessor.rewire(slot, subslots, rewriter);
+    } else {
+      auto promotable = cast<PromotableOpInterface>(op);
+      outcome = promotable.removeBlockingUses(info.userToBlockingUses[op],
+                                              rewriter);
+    }
+
+    if (outcome == DeletionKind::Delete)
+      doomed.push_back(op);
+  }
+
+  for (Operation *op : doomed)
+    rewriter.eraseOp(op);
+
+  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
+                                 "pointer should no longer be used");
+
+  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
+                          << "\n");
+
+  if (statistics.destructuredAmount)
+    (*statistics.destructuredAmount)++;
+
+  allocator.handleDestructuringComplete(slot, rewriter);
+}
+
+/// Tries to destructure every destructurable slot of every allocator in
+/// `allocators`. Slots for which the analysis yields no destructuring info are
+/// left untouched. Returns success if at least one slot was destructured.
+LogicalResult mlir::tryToDestructureMemorySlots(
+    ArrayRef<DestructurableAllocationOpInterface> allocators,
+    RewriterBase &rewriter, SROAStatistics statistics) {
+  bool changed = false;
+
+  for (DestructurableAllocationOpInterface allocator : allocators) {
+    for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+      if (std::optional<MemorySlotDestructuringInfo> analysis =
+              computeDestructuringInfo(slot)) {
+        destructureSlot(slot, allocator, rewriter, *analysis, statistics);
+        changed = true;
+      }
+    }
+  }
+
+  return success(changed);
+}
+
+LogicalResult
+SROAPattern::matchAndRewrite(DestructurableAllocationOpInterface allocator,
+ PatternRewriter &rewriter) const {
+ hasBoundedRewriteRecursion();
+ return tryToDestructureMemorySlots({allocator}, rewriter, statistics);
+}
+
+namespace {
+
+/// Pass wrapper that applies `SROAPattern` greedily on the operation the pass
+/// runs on, reporting results through the pass statistics.
+struct SROA : public impl::SROABase<SROA> {
+  using impl::SROABase<SROA>::SROABase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    // Expose the pass statistics to the pattern so it can update them.
+    SROAStatistics stats{&destructuredAmount, &slotsWithMemoryBenefit,
+                         &maxSubelementAmount};
+
+    RewritePatternSet patterns(context);
+    patterns.add<SROAPattern>(context, stats);
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    // The pass fails if the greedy driver reports a failure.
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozenPatterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/LLVMIR/sroa-statistics.mlir b/mlir/test/Dialect/LLVMIR/sroa-statistics.mlir
new file mode 100644
index 0000000000000..e859fbb07e1fb
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/sroa-statistics.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file --mlir-pass-statistics 2>&1 >/dev/null | FileCheck %s
+
+// Only field 1 of a two-field struct is loaded: one slot is destructured, it
+// has two subelements, and since a single one of them is used the slot counts
+// as having a memory benefit.
+// CHECK: SROA
+// CHECK-NEXT: (S) 1 destructured slots
+// CHECK-NEXT: (S) 2 max subelement number
+// CHECK-NEXT: (S) 1 slots with memory benefit
+llvm.func @basic() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Both fields of the struct are loaded, so the slot is destructured but no
+// memory benefit is reported.
+// CHECK: SROA
+// CHECK-NEXT: (S) 1 destructured slots
+// CHECK-NEXT: (S) 2 max subelement number
+// CHECK-NEXT: (S) 0 slots with memory benefit
+llvm.func @basic_no_memory_benefit() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+ %3 = llvm.getelementptr inbounds %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+ %4 = llvm.load %2 : !llvm.ptr -> i32
+ %5 = llvm.load %3 : !llvm.ptr -> i32
+ %6 = llvm.add %4, %5 : i32
+ llvm.return %6 : i32
+}
+
+// -----
+
+// Only element 2 of a ten-element array is loaded: the maximum subelement
+// count is 10 and the slot counts as having a memory benefit.
+// CHECK: SROA
+// CHECK-NEXT: (S) 1 destructured slots
+// CHECK-NEXT: (S) 10 max subelement number
+// CHECK-NEXT: (S) 1 slots with memory benefit
+llvm.func @basic_array() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// SROA is applied repeatedly here, peeling off layers of aggregates one after
+// the other, four times: struct "foo", struct "bar", the outer array and the
+// inner array. Each application destructures one slot with a memory benefit.
+
+// CHECK: SROA
+// CHECK-NEXT: (S) 4 destructured slots
+// CHECK-NEXT: (S) 10 max subelement number
+// CHECK-NEXT: (S) 4 slots with memory benefit
+llvm.func @multi_level_direct() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
new file mode 100644
index 0000000000000..d7bf7942686ff
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -0,0 +1,211 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
+
+// Only field 2 of the struct is accessed: the struct alloca is expected to be
+// replaced by an i32 alloca that the load reads from directly.
+// CHECK-LABEL: llvm.func @basic_struct
+llvm.func @basic_struct() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Only element 2 of the array is accessed: the array alloca is expected to be
+// replaced by an i32 alloca that the load reads from directly.
+// CHECK-LABEL: llvm.func @basic_array
+llvm.func @basic_array() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// A single GEP reaches an i32 through every nesting level of the aggregate:
+// the expected result is one i32 alloca read directly by the load.
+// CHECK-LABEL: llvm.func @multi_level_direct
+llvm.func @multi_level_direct() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// The first application of SROA would generate a GEP with indices [0, 0]. This
+// test ensures this GEP is not eliminated during the first application. Even
+// though doing it would be correct, it would prevent the second application
+// of SROA to eliminate the array. GEPs should be eliminated only when they are
+// truly trivial (with indices [0]). The expected end result is a single i32
+// alloca read directly by the load.
+
+// CHECK-LABEL: llvm.func @multi_level_direct_two_applications
+llvm.func @multi_level_direct_two_applications() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, array<10 x i32>, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, array<10 x i32>, i8)>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// The i32 is reached through two chained GEPs (the second one indexes the
+// result of the first): the expected result is still a single i32 alloca.
+// CHECK-LABEL: llvm.func @multi_level_indirect
+llvm.func @multi_level_indirect() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)>
+ %3 = llvm.getelementptr inbounds %2[0, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %4 = llvm.load %3 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %4 : i32
+}
+
+// -----
+
+// Two distinct GEPs address the same field (one of them without `inbounds`).
+// The store through one and the load through the other are expected to end up
+// on the same new i32 alloca.
+// CHECK-LABEL: llvm.func @resolve_alias
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+llvm.func @resolve_alias(%arg: i32) -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ %3 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.store %arg, %2 : i32, !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %4 = llvm.load %3 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %4 : i32
+}
+
+// -----
+
+// Negative test: the alloca allocates two elements (its size operand is the
+// constant 2), so the struct alloca is expected to stay and no other alloca
+// to be created.
+// CHECK-LABEL: llvm.func @no_non_single_support
+llvm.func @no_non_single_support() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant
+ %0 = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ // CHECK-NOT: = llvm.alloca
+ %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Negative test: the first GEP index is 1 instead of 0, i.e. the GEP steps
+// over the pointer itself, so the struct alloca is expected to stay and no
+// other alloca to be created.
+// CHECK-LABEL: llvm.func @no_pointer_indexing
+llvm.func @no_pointer_indexing() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ // CHECK-NOT: = llvm.alloca
+ %2 = llvm.getelementptr %1[1, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Negative test: the slot pointer itself is passed to a call to @use, so the
+// struct alloca is expected to stay and no other alloca to be created.
+// CHECK-LABEL: llvm.func @no_direct_use
+llvm.func @no_direct_use() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ // CHECK-NOT: = llvm.alloca
+ %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.call @use(%1) : (!llvm.ptr) -> ()
+ llvm.return %3 : i32
+}
+
+// External function used above as a consumer of the slot pointer.
+llvm.func @use(!llvm.ptr)
+
+// -----
+
+// The slot pointer is used directly by a lifetime intrinsic in addition to
+// the GEP. Destructuring is still expected to happen, leaving an i32 alloca.
+// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine
+llvm.func @direct_promotable_use_is_fine() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // This is a direct use of the slot but it can be removed because it implements PromotableOpInterface.
+ llvm.intr.lifetime.start 2, %1 : !llvm.ptr
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Same scenario as a lifetime intrinsic on the slot pointer, except that the
+// intrinsic here uses the GEP result. Destructuring is still expected to
+// happen, leaving an i32 alloca.
+// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine_on_accessor
+llvm.func @direct_promotable_use_is_fine_on_accessor() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // This does not provide side-effect info but it can be removed because it implements PromotableOpInterface.
+ llvm.intr.lifetime.start 2, %2 : !llvm.ptr
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Negative test: the array is indexed with a function argument rather than a
+// constant, so the array alloca and the GEP are both expected to stay.
+// CHECK-LABEL: llvm.func @no_dynamic_indexing
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+llvm.func @no_dynamic_indexing(%arg: i32) -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ // CHECK-NOT: = llvm.alloca
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, %[[ARG]]]
+ %2 = llvm.getelementptr %1[0, %arg] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+ // CHECK: %[[RES:.*]] = llvm.load %[[GEP]]
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// Negative test: the alloca and the GEP use typed pointers instead of opaque
+// ones, so the array alloca is expected to stay and no other alloca to be
+// created.
+// CHECK-LABEL: llvm.func @no_typed_pointers
+llvm.func @no_typed_pointers() -> i32 {
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr<array<10 x i32>>
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr<array<10 x i32>>
+ // CHECK-NOT: = llvm.alloca
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr<array<10 x i32>>) -> !llvm.ptr<i32>
+ %3 = llvm.load %2 : !llvm.ptr<i32>
+ llvm.return %3 : i32
+}
More information about the Mlir-commits
mailing list