[llvm] [SandboxIR] Implement SandboxIR Type (PR #106294)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 28 09:46:41 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/106294
>From 03951d4e458b1961bc455af753edc44116ee1add Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 19 Aug 2024 11:26:19 -0700
Subject: [PATCH] [SandboxIR] Implement SandboxIR Type
This patch implements sandboxir::Type, a thin wrapper around llvm::Type.
This is designed very similarly to sandboxir::Value.
Context owns all sandboxir::Type objects and maintains a map between
llvm::Type and sandboxir::Type.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 71 ++---
llvm/include/llvm/SandboxIR/Type.h | 288 +++++++++++++++++++++
llvm/lib/SandboxIR/CMakeLists.txt | 1 +
llvm/lib/SandboxIR/SandboxIR.cpp | 110 ++++++--
llvm/lib/SandboxIR/Type.cpp | 48 ++++
llvm/unittests/SandboxIR/CMakeLists.txt | 1 +
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 83 +++---
llvm/unittests/SandboxIR/TrackerTest.cpp | 9 +-
llvm/unittests/SandboxIR/TypesTest.cpp | 242 +++++++++++++++++
9 files changed, 753 insertions(+), 100 deletions(-)
create mode 100644 llvm/include/llvm/SandboxIR/Type.h
create mode 100644 llvm/lib/SandboxIR/Type.cpp
create mode 100644 llvm/unittests/SandboxIR/TypesTest.cpp
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b7bdf9acd2ef45..f9581c0dc299e3 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -98,6 +98,7 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/SandboxIR/Tracker.h"
+#include "llvm/SandboxIR/Type.h"
#include "llvm/SandboxIR/Use.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>
@@ -378,7 +379,7 @@ class Value {
return Cnt == Num;
}
- Type *getType() const { return Val->getType(); }
+ /// \returns the sandboxir::Type wrapping the llvm::Type of this value. The
+ /// object is owned by the Context, which keeps one wrapper per llvm::Type.
+ Type *getType() const;
Context &getContext() const { return Ctx; }
@@ -566,8 +567,7 @@ class ConstantInt : public Constant {
public:
/// If Ty is a vector type, return a Constant with a splat of the given
/// value. Otherwise return a ConstantInt for the given value.
- static ConstantInt *get(Type *Ty, uint64_t V, Context &Ctx,
- bool IsSigned = false);
+ /// The Context is obtained from \p Ty, so it is not passed separately.
+ static ConstantInt *get(Type *Ty, uint64_t V, bool IsSigned = false);
// TODO: Implement missing functions.
@@ -1014,10 +1014,7 @@ class ExtractElementInst final
Value *getIndexOperand() { return getOperand(1); }
const Value *getVectorOperand() const { return getOperand(0); }
const Value *getIndexOperand() const { return getOperand(1); }
-
- VectorType *getVectorOperandType() const {
- return cast<VectorType>(getVectorOperand()->getType());
- }
+ /// \returns the type of the vector operand as a sandboxir::VectorType.
+ VectorType *getVectorOperandType() const;
};
class ShuffleVectorInst final
@@ -1062,9 +1059,7 @@ class ShuffleVectorInst final
}
/// Overload to return most specific vector type.
- VectorType *getType() const {
- return cast<llvm::ShuffleVectorInst>(Val)->getType();
- }
+ /// The result is the Context-owned sandboxir::VectorType wrapper.
+ VectorType *getType() const;
/// Return the shuffle mask value of this instruction for the given element
/// index. Return PoisonMaskElem if the element is undef.
@@ -1090,7 +1085,7 @@ class ShuffleVectorInst final
Constant *getShuffleMaskForBitcode() const;
+ /// The returned Constant is created in the Context that owns \p ResultTy.
static Constant *convertShuffleMaskForBitcode(ArrayRef<int> Mask,
- Type *ResultTy, Context &Ctx);
+ Type *ResultTy);
void setShuffleMask(ArrayRef<int> Mask);
@@ -1713,9 +1708,7 @@ class CallBase : public SingleLLVMInstructionImpl<llvm::CallBase> {
Opc == Instruction::ClassID::CallBr;
}
- FunctionType *getFunctionType() const {
- return cast<llvm::CallBase>(Val)->getFunctionType();
- }
+ /// \returns the sandboxir::FunctionType wrapping the call's llvm::FunctionType.
+ FunctionType *getFunctionType() const;
op_iterator data_operands_begin() { return op_begin(); }
const_op_iterator data_operands_begin() const {
@@ -2131,12 +2124,8 @@ class GetElementPtrInst final
return From->getSubclassID() == ClassID::GetElementPtr;
}
- Type *getSourceElementType() const {
- return cast<llvm::GetElementPtrInst>(Val)->getSourceElementType();
- }
- Type *getResultElementType() const {
- return cast<llvm::GetElementPtrInst>(Val)->getResultElementType();
- }
+ /// \returns the sandboxir::Type wrapping the GEP's source element type.
+ Type *getSourceElementType() const;
+ /// \returns the sandboxir::Type wrapping the GEP's result element type.
+ Type *getResultElementType() const;
unsigned getAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getAddressSpace();
}
@@ -2160,9 +2149,7 @@ class GetElementPtrInst final
static unsigned getPointerOperandIndex() {
return llvm::GetElementPtrInst::getPointerOperandIndex();
}
- Type *getPointerOperandType() const {
- return cast<llvm::GetElementPtrInst>(Val)->getPointerOperandType();
- }
+ /// \returns the sandboxir::Type wrapping the pointer operand's type.
+ Type *getPointerOperandType() const;
unsigned getPointerAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getPointerAddressSpace();
}
@@ -2709,9 +2696,7 @@ class AllocaInst final : public UnaryInstruction {
return const_cast<AllocaInst *>(this)->getArraySize();
}
/// Overload to return most specific pointer type.
- PointerType *getType() const {
- return cast<llvm::AllocaInst>(Val)->getType();
- }
+ /// The result is the Context-owned sandboxir::PointerType wrapper.
+ PointerType *getType() const;
/// Return the address space for the allocation.
unsigned getAddressSpace() const {
return cast<llvm::AllocaInst>(Val)->getAddressSpace();
@@ -2727,9 +2712,7 @@ class AllocaInst final : public UnaryInstruction {
return cast<llvm::AllocaInst>(Val)->getAllocationSizeInBits(DL);
}
/// Return the type that is being allocated by the instruction.
- Type *getAllocatedType() const {
- return cast<llvm::AllocaInst>(Val)->getAllocatedType();
- }
+ /// The result is a Context-owned sandboxir::Type wrapper.
+ Type *getAllocatedType() const;
/// for use only in special circumstances that need to generically
/// transform a whole instruction (eg: IR linking and vectorization).
void setAllocatedType(Type *Ty);
@@ -2811,8 +2794,8 @@ class CastInst : public UnaryInstruction {
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);
- Type *getSrcTy() const { return cast<llvm::CastInst>(Val)->getSrcTy(); }
- Type *getDestTy() const { return cast<llvm::CastInst>(Val)->getDestTy(); }
+ /// \returns the sandboxir::Type of the cast's source operand.
+ Type *getSrcTy() const;
+ /// \returns the sandboxir::Type the cast produces.
+ Type *getDestTy() const;
};
/// Instruction that can have a nneg flag (zext/uitofp).
@@ -2992,6 +2975,8 @@ class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
class Context {
protected:
LLVMContext &LLVMCtx;
+ friend class Type; // For LLVMCtx.
+ friend class PointerType; // For LLVMCtx.
Tracker IRTracker;
/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
@@ -2999,6 +2984,16 @@ class Context {
DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
LLVMValueToValueMap;
+ /// Type has a protected destructor to prohibit the user from managing the
+ /// lifetime of the Type objects. Context is a friend of Type, and this custom
+ /// deleter can destroy Type.
+ struct TypeDeleter {
+ void operator()(Type *Ty) const { delete Ty; }
+ };
+ /// Maps LLVM Type to the corresponding sandboxir::Type. Owns all Sandbox IR
+ /// Type objects.
+ DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
+
/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -3033,7 +3028,6 @@ class Context {
/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
-
friend class BasicBlock; // For getOrCreateValue().
IRBuilder<ConstantFolder> LLVMIRBuilder;
@@ -3119,6 +3113,15 @@ class Context {
const sandboxir::Value *getValue(const llvm::Value *V) const {
return getValue(const_cast<llvm::Value *>(V));
}
+
+ /// \returns the sandboxir::Type wrapping \p LLVMTy. The wrapper is created
+ /// the first time a given llvm::Type is requested and cached in
+ /// LLVMTypeToTypeMap, so later calls return the very same object.
+ Type *getType(llvm::Type *LLVMTy) {
+ auto [It, Inserted] = LLVMTypeToTypeMap.try_emplace(LLVMTy);
+ if (Inserted)
+ It->second.reset(new Type(LLVMTy, *this));
+ return It->second.get();
+ }
+
/// Create a sandboxir::Function for an existing LLVM IR \p F, including all
/// blocks and instructions.
/// This is the main API function for creating Sandbox IR.
@@ -3165,9 +3168,7 @@ class Function : public Constant {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
}
- FunctionType *getFunctionType() const {
- return cast<llvm::Function>(Val)->getFunctionType();
- }
+ /// \returns the sandboxir::FunctionType wrapping this function's
+ /// llvm::FunctionType.
+ FunctionType *getFunctionType() const;
#ifndef NDEBUG
void verify() const final {
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
new file mode 100644
index 00000000000000..75e477d953f689
--- /dev/null
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -0,0 +1,288 @@
+//===- llvm/SandboxIR/Type.h - Classes for handling data types --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a thin wrapper over llvm::Type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SANDBOXIR_TYPE_H
+#define LLVM_SANDBOXIR_TYPE_H
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm::sandboxir {
+
+class Context;
+// Forward declare friend classes for MSVC.
+class PointerType;
+class VectorType;
+class FunctionType;
+// Also forward-declare every instruction class listed in SandboxIRValues.def,
+// because class Type below befriends all of them through the same macro.
+#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
+#include "llvm/SandboxIR/SandboxIRValues.def"
+
+/// Just like llvm::Type these are immutable and unique, and can only be
+/// created via static factory methods.
+///
+/// A sandboxir::Type is a thin wrapper around an llvm::Type. The constructor
+/// and destructor are protected and Context is a friend, so Type objects are
+/// created and destroyed by the Context (which stores them in unique_ptrs).
+/// Context::getType() returns one wrapper per llvm::Type, so wrappers can be
+/// compared by pointer.
+class Type {
+protected:
+  llvm::Type *LLVMTy;
+  friend class VectorType;   // For LLVMTy.
+  friend class PointerType;  // For LLVMTy.
+  friend class FunctionType; // For LLVMTy.
+  friend class Function;     // For LLVMTy.
+  friend class CallBase;     // For LLVMTy.
+  friend class ConstantInt;  // For LLVMTy.
+  // Friend all instruction classes because `create()` functions use LLVMTy.
+#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
+  // TODO: Friend DEF_CONST()
+#include "llvm/SandboxIR/SandboxIRValues.def"
+  Context &Ctx;
+
+  Type(llvm::Type *LLVMTy, Context &Ctx) : LLVMTy(LLVMTy), Ctx(Ctx) {}
+  friend class Context; // For constructor and ~Type().
+  ~Type() = default;
+
+public:
+  /// \returns the Context that owns this Type.
+  Context &getContext() const { return Ctx; }
+
+  /// Return true if this is 'void'.
+  bool isVoidTy() const { return LLVMTy->isVoidTy(); }
+
+  /// Return true if this is 'half', a 16-bit IEEE fp type.
+  bool isHalfTy() const { return LLVMTy->isHalfTy(); }
+
+  /// Return true if this is 'bfloat', a 16-bit bfloat type.
+  bool isBFloatTy() const { return LLVMTy->isBFloatTy(); }
+
+  /// Return true if this is a 16-bit float type.
+  bool is16bitFPTy() const { return LLVMTy->is16bitFPTy(); }
+
+  /// Return true if this is 'float', a 32-bit IEEE fp type.
+  bool isFloatTy() const { return LLVMTy->isFloatTy(); }
+
+  /// Return true if this is 'double', a 64-bit IEEE fp type.
+  bool isDoubleTy() const { return LLVMTy->isDoubleTy(); }
+
+  /// Return true if this is x86 long double.
+  bool isX86_FP80Ty() const { return LLVMTy->isX86_FP80Ty(); }
+
+  /// Return true if this is 'fp128'.
+  bool isFP128Ty() const { return LLVMTy->isFP128Ty(); }
+
+  /// Return true if this is powerpc long double.
+  bool isPPC_FP128Ty() const { return LLVMTy->isPPC_FP128Ty(); }
+
+  /// Return true if this is a well-behaved IEEE-like type, which has a IEEE
+  /// compatible layout as defined by APFloat::isIEEE(), and does not have
+  /// non-IEEE values, such as x86_fp80's unnormal values.
+  bool isIEEELikeFPTy() const { return LLVMTy->isIEEELikeFPTy(); }
+
+  /// Return true if this is one of the floating-point types
+  bool isFloatingPointTy() const { return LLVMTy->isFloatingPointTy(); }
+
+  /// Returns true if this is a floating-point type that is an unevaluated sum
+  /// of multiple floating-point units.
+  /// An example of such a type is ppc_fp128, also known as double-double, which
+  /// consists of two IEEE 754 doubles.
+  bool isMultiUnitFPType() const { return LLVMTy->isMultiUnitFPType(); }
+
+  /// Return the floating-point semantics of this type, as reported by
+  /// llvm::Type::getFltSemantics().
+  const fltSemantics &getFltSemantics() const {
+    return LLVMTy->getFltSemantics();
+  }
+
+  /// Return true if this is X86 AMX.
+  bool isX86_AMXTy() const { return LLVMTy->isX86_AMXTy(); }
+
+  /// Return true if this is a target extension type.
+  bool isTargetExtTy() const { return LLVMTy->isTargetExtTy(); }
+
+  /// Return true if this is a target extension type with a scalable layout.
+  bool isScalableTargetExtTy() const { return LLVMTy->isScalableTargetExtTy(); }
+
+  /// Return true if this is a type whose size is a known multiple of vscale.
+  bool isScalableTy() const { return LLVMTy->isScalableTy(); }
+
+  /// Return true if this is a FP type or a vector of FP.
+  bool isFPOrFPVectorTy() const { return LLVMTy->isFPOrFPVectorTy(); }
+
+  /// Return true if this is 'label'.
+  bool isLabelTy() const { return LLVMTy->isLabelTy(); }
+
+  /// Return true if this is 'metadata'.
+  bool isMetadataTy() const { return LLVMTy->isMetadataTy(); }
+
+  /// Return true if this is 'token'.
+  bool isTokenTy() const { return LLVMTy->isTokenTy(); }
+
+  /// True if this is an instance of IntegerType.
+  bool isIntegerTy() const { return LLVMTy->isIntegerTy(); }
+
+  /// Return true if this is an IntegerType of the given width.
+  bool isIntegerTy(unsigned Bitwidth) const {
+    return LLVMTy->isIntegerTy(Bitwidth);
+  }
+
+  /// Return true if this is an integer type or a vector of integer types.
+  bool isIntOrIntVectorTy() const { return LLVMTy->isIntOrIntVectorTy(); }
+
+  /// Return true if this is an integer type or a vector of integer types of
+  /// the given width.
+  bool isIntOrIntVectorTy(unsigned BitWidth) const {
+    return LLVMTy->isIntOrIntVectorTy(BitWidth);
+  }
+
+  /// Return true if this is an integer type or a pointer type.
+  bool isIntOrPtrTy() const { return LLVMTy->isIntOrPtrTy(); }
+
+  /// True if this is an instance of FunctionType.
+  bool isFunctionTy() const { return LLVMTy->isFunctionTy(); }
+
+  /// True if this is an instance of StructType.
+  bool isStructTy() const { return LLVMTy->isStructTy(); }
+
+  /// True if this is an instance of ArrayType.
+  bool isArrayTy() const { return LLVMTy->isArrayTy(); }
+
+  /// True if this is an instance of PointerType.
+  bool isPointerTy() const { return LLVMTy->isPointerTy(); }
+
+  /// Return true if this is a pointer type or a vector of pointer types.
+  bool isPtrOrPtrVectorTy() const { return LLVMTy->isPtrOrPtrVectorTy(); }
+
+  /// True if this is an instance of VectorType.
+  bool isVectorTy() const { return LLVMTy->isVectorTy(); }
+
+  /// Return true if this type could be converted with a lossless BitCast to
+  /// type 'Ty'. For example, i8* to i32*. BitCasts are valid for types of the
+  /// same size only where no re-interpretation of the bits is done.
+  bool canLosslesslyBitCastTo(Type *Ty) const {
+    return LLVMTy->canLosslesslyBitCastTo(Ty->LLVMTy);
+  }
+
+  /// Return true if this type is empty, that is, it has no elements or all of
+  /// its elements are empty.
+  bool isEmptyTy() const { return LLVMTy->isEmptyTy(); }
+
+  /// Return true if the type is "first class", meaning it is a valid type for a
+  /// Value.
+  bool isFirstClassType() const { return LLVMTy->isFirstClassType(); }
+
+  /// Return true if the type is a valid type for a register in codegen. This
+  /// includes all first-class types except struct and array types.
+  bool isSingleValueType() const { return LLVMTy->isSingleValueType(); }
+
+  /// Return true if the type is an aggregate type. This means it is valid as
+  /// the first operand of an insertvalue or extractvalue instruction. This
+  /// includes struct and array types, but does not include vector types.
+  bool isAggregateType() const { return LLVMTy->isAggregateType(); }
+
+  /// Return true if it makes sense to take the size of this type. To get the
+  /// actual size for a particular target, it is reasonable to use the
+  /// DataLayout subsystem to do this.
+  ///
+  /// \p Visited is optional. When it is provided, the llvm::Types wrapped by
+  /// its elements are copied into the set handed to llvm::Type::isSized().
+  /// NOTE: types that LLVM adds to its set during the query are not written
+  /// back into \p Visited.
+  bool isSized(SmallPtrSetImpl<Type *> *Visited = nullptr) const {
+    // Visited defaults to nullptr, so it must not be dereferenced blindly.
+    if (Visited == nullptr)
+      return LLVMTy->isSized();
+    SmallPtrSet<llvm::Type *, 8> LLVMVisited;
+    LLVMVisited.reserve(Visited->size());
+    for (Type *Ty : *Visited)
+      LLVMVisited.insert(Ty->LLVMTy);
+    return LLVMTy->isSized(&LLVMVisited);
+  }
+
+  /// Return the basic size of this type if it is a primitive type. These are
+  /// fixed by LLVM and are not target-dependent.
+  /// This will return zero if the type does not have a size or is not a
+  /// primitive type.
+  ///
+  /// If this is a scalable vector type, the scalable property will be set and
+  /// the runtime size will be a positive integer multiple of the base size.
+  ///
+  /// Note that this may not reflect the size of memory allocated for an
+  /// instance of the type or the number of bytes that are written when an
+  /// instance of the type is stored to memory. The DataLayout class provides
+  /// additional query functions to provide this information.
+  ///
+  TypeSize getPrimitiveSizeInBits() const {
+    return LLVMTy->getPrimitiveSizeInBits();
+  }
+
+  /// If this is a vector type, return the getPrimitiveSizeInBits value for the
+  /// element type. Otherwise return the getPrimitiveSizeInBits value for this
+  /// type.
+  unsigned getScalarSizeInBits() const { return LLVMTy->getScalarSizeInBits(); }
+
+  /// Return the width of the mantissa of this type. This is only valid on
+  /// floating-point types. If the FP type does not have a stable mantissa (e.g.
+  /// ppc long double), this method returns -1.
+  int getFPMantissaWidth() const { return LLVMTy->getFPMantissaWidth(); }
+
+  /// Return whether the type is IEEE compatible, as defined by the eponymous
+  /// method in APFloat.
+  bool isIEEE() const { return LLVMTy->isIEEE(); }
+
+  /// If this is a vector type, return the element type, otherwise return
+  /// 'this'.
+  Type *getScalarType() const;
+
+  // TODO: ADD MISSING
+
+  static Type *getInt64Ty(Context &Ctx);
+  static Type *getInt32Ty(Context &Ctx);
+  static Type *getInt16Ty(Context &Ctx);
+  static Type *getInt8Ty(Context &Ctx);
+  static Type *getInt1Ty(Context &Ctx);
+  static Type *getDoubleTy(Context &Ctx);
+  static Type *getFloatTy(Context &Ctx);
+  // TODO: missing get*
+
+  /// Get the address space of this pointer or pointer vector type.
+  unsigned getPointerAddressSpace() const {
+    return LLVMTy->getPointerAddressSpace();
+  }
+
+#ifndef NDEBUG
+  /// Print the wrapped llvm::Type to \p OS (no trailing newline).
+  void dumpOS(raw_ostream &OS) const { LLVMTy->print(OS); }
+  /// Print the wrapped llvm::Type to dbgs(), followed by a newline.
+  LLVM_DUMP_METHOD void dump() const {
+    dumpOS(dbgs());
+    dbgs() << "\n";
+  }
+#endif // NDEBUG
+};
+
+/// Sandbox IR counterpart of llvm::PointerType. It adds no state to Type;
+/// classof() inspects the wrapped llvm::Type.
+class PointerType : public Type {
+public:
+  // TODO: add missing functions
+  /// \returns the pointer type in \p AddressSpace, using the Context that
+  /// owns \p ElementType.
+  static PointerType *get(Type *ElementType, unsigned AddressSpace);
+  /// \returns the pointer type in \p AddressSpace owned by \p Ctx.
+  static PointerType *get(Context &Ctx, unsigned AddressSpace);
+
+  static bool classof(const Type *From) {
+    const llvm::Type *Wrapped = From->LLVMTy;
+    return isa<llvm::PointerType>(Wrapped);
+  }
+};
+
+/// Sandbox IR counterpart of llvm::VectorType; classof() checks the wrapped
+/// llvm::Type.
+class VectorType : public Type {
+public:
+  // TODO: add missing functions
+  static bool classof(const Type *From) {
+    const llvm::Type *Wrapped = From->LLVMTy;
+    return isa<llvm::VectorType>(Wrapped);
+  }
+};
+
+/// Sandbox IR counterpart of llvm::FunctionType; classof() checks the wrapped
+/// llvm::Type.
+class FunctionType : public Type {
+public:
+  // TODO: add missing functions
+  static bool classof(const Type *From) {
+    const llvm::Type *Wrapped = From->LLVMTy;
+    return isa<llvm::FunctionType>(Wrapped);
+  }
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_SANDBOXIR_TYPE_H
diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt
index 6c0666b186b8a6..d94f0642ccc4a1 100644
--- a/llvm/lib/SandboxIR/CMakeLists.txt
+++ b/llvm/lib/SandboxIR/CMakeLists.txt
@@ -1,6 +1,7 @@
add_llvm_component_library(LLVMSandboxIR
SandboxIR.cpp
Tracker.cpp
+ Type.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b75424909f0835..921bab66e8f1cc 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -135,6 +135,8 @@ Value::user_iterator Value::user_begin() {
unsigned Value::getNumUses() const { return range_size(Val->users()); }
+// Returns the Context-owned wrapper of the LLVM value's type; Context::getType()
+// creates it on first use.
+Type *Value::getType() const { return Ctx.getType(Val->getType()); }
+
void Value::replaceUsesWithIf(
Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
assert(getType() == OtherV->getType() && "Can't replace with different type");
@@ -583,7 +585,8 @@ VAArgInst *VAArgInst::create(Value *List, Type *Ty, BBIterator WhereIt,
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
- auto *LLVMI = cast<llvm::VAArgInst>(Builder.CreateVAArg(List->Val, Ty, Name));
+ // The LLVM IRBuilder takes the wrapped llvm::Type, not the sandboxir::Type.
+ auto *LLVMI =
+ cast<llvm::VAArgInst>(Builder.CreateVAArg(List->Val, Ty->LLVMTy, Name));
return Ctx.createVAArgInst(LLVMI);
}
@@ -754,7 +757,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(BeforeIR);
auto *NewLI =
- Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name);
+ Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name);
auto *NewSBI = Ctx.createLoadInst(NewLI);
return NewSBI;
}
@@ -771,7 +774,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
auto *NewLI =
- Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name);
+ Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name);
auto *NewSBI = Ctx.createLoadInst(NewLI);
return NewSBI;
}
@@ -886,6 +889,11 @@ Value *ReturnInst::getReturnValue() const {
return LLVMRetVal != nullptr ? Ctx.getValue(LLVMRetVal) : nullptr;
}
+// Wrap the LLVM call's function type. cast<> (asserting) relies on
+// FunctionType::classof(), which inspects the wrapped llvm::Type.
+FunctionType *CallBase::getFunctionType() const {
+ return cast<FunctionType>(
+ Ctx.getType(cast<llvm::CallBase>(Val)->getFunctionType()));
+}
+
Value *CallBase::getCalledOperand() const {
return Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledOperand());
}
@@ -911,8 +919,9 @@ void CallBase::setCalledFunction(Function *F) {
// Note: This may break if `setCalledFunction()` early returns if `F`
// is already set, but we do have a unit test for it.
setCalledOperand(F);
- cast<llvm::CallBase>(Val)->setCalledFunction(F->getFunctionType(),
- cast<llvm::Function>(F->Val));
+ // F->getFunctionType() returns a sandboxir::FunctionType; unwrap it for LLVM.
+ cast<llvm::CallBase>(Val)->setCalledFunction(
+ cast<llvm::FunctionType>(F->getFunctionType()->LLVMTy),
+ cast<llvm::Function>(F->Val));
}
CallInst *CallInst::create(FunctionType *FTy, Value *Func,
@@ -928,7 +937,8 @@ CallInst *CallInst::create(FunctionType *FTy, Value *Func,
LLVMArgs.reserve(Args.size());
for (Value *Arg : Args)
LLVMArgs.push_back(Arg->Val);
- llvm::CallInst *NewCI = Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr);
+ // The LLVM IRBuilder needs the wrapped llvm::FunctionType.
+ llvm::CallInst *NewCI = Builder.CreateCall(
+ cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val, LLVMArgs, NameStr);
return Ctx.createCallInst(NewCI);
}
@@ -961,7 +971,8 @@ InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
for (Value *Arg : Args)
LLVMArgs.push_back(Arg->Val);
+ // The LLVM IRBuilder needs the wrapped llvm::FunctionType.
llvm::InvokeInst *Invoke = Builder.CreateInvoke(
- FTy, Func->Val, cast<llvm::BasicBlock>(IfNormal->Val),
+ cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val,
+ cast<llvm::BasicBlock>(IfNormal->Val),
cast<llvm::BasicBlock>(IfException->Val), LLVMArgs, NameStr);
return Ctx.createInvokeInst(Invoke);
}
@@ -1032,9 +1043,10 @@ CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func,
for (Value *Arg : Args)
LLVMArgs.push_back(Arg->Val);
- llvm::CallBrInst *CallBr = Builder.CreateCallBr(
- FTy, Func->Val, cast<llvm::BasicBlock>(DefaultDest->Val),
- LLVMIndirectDests, LLVMArgs, NameStr);
+ // Unwrap FTy to the llvm::FunctionType expected by the IRBuilder.
+ llvm::CallBrInst *CallBr =
+ Builder.CreateCallBr(cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val,
+ cast<llvm::BasicBlock>(DefaultDest->Val),
+ LLVMIndirectDests, LLVMArgs, NameStr);
return Ctx.createCallBrInst(CallBr);
}
@@ -1107,7 +1119,7 @@ LandingPadInst *LandingPadInst::create(Type *RetTy, unsigned NumReservedClauses,
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
llvm::LandingPadInst *LLVMI =
- Builder.CreateLandingPad(RetTy, NumReservedClauses, Name);
+ Builder.CreateLandingPad(RetTy->LLVMTy, NumReservedClauses, Name);
return Ctx.createLandingPadInst(LLVMI);
}
@@ -1288,7 +1300,8 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
LLVMIdxList.reserve(IdxList.size());
for (Value *Idx : IdxList)
LLVMIdxList.push_back(Idx->Val);
- llvm::Value *NewV = Builder.CreateGEP(Ty, Ptr->Val, LLVMIdxList, NameStr);
+ // CreateGEP() may fold to a Constant, hence the dyn_cast below.
+ llvm::Value *NewV =
+ Builder.CreateGEP(Ty->LLVMTy, Ptr->Val, LLVMIdxList, NameStr);
if (auto *NewGEP = dyn_cast<llvm::GetElementPtrInst>(NewV))
return Ctx.createGetElementPtrInst(NewGEP);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
@@ -1312,10 +1325,25 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
InsertAtEnd, Ctx, NameStr);
}
+// The element-type getters wrap the llvm::GetElementPtrInst results in
+// Context-owned sandboxir::Type objects.
+Type *GetElementPtrInst::getSourceElementType() const {
+ return Ctx.getType(
+ cast<llvm::GetElementPtrInst>(Val)->getSourceElementType());
+}
+
+Type *GetElementPtrInst::getResultElementType() const {
+ return Ctx.getType(
+ cast<llvm::GetElementPtrInst>(Val)->getResultElementType());
+}
+
Value *GetElementPtrInst::getPointerOperand() const {
return Ctx.getValue(cast<llvm::GetElementPtrInst>(Val)->getPointerOperand());
}
+// Wraps the llvm::GetElementPtrInst pointer operand type.
+Type *GetElementPtrInst::getPointerOperandType() const {
+ return Ctx.getType(
+ cast<llvm::GetElementPtrInst>(Val)->getPointerOperandType());
+}
+
BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
return cast<BasicBlock>(Ctx.getValue(LLVMBB));
}
@@ -1323,8 +1351,9 @@ BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
- llvm::PHINode *NewPHI = llvm::PHINode::Create(
- Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction());
+ // llvm::PHINode::Create() takes the wrapped llvm::Type.
+ llvm::PHINode *NewPHI =
+ llvm::PHINode::Create(Ty->LLVMTy, NumReservedValues, Name,
+ InsertBefore->getTopmostLLVMInstruction());
return Ctx.createPHINode(NewPHI);
}
@@ -1943,7 +1972,8 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace, BBIterator WhereIt,
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
else
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
- auto *NewAlloca = Builder.CreateAlloca(Ty, AddrSpace, ArraySize->Val, Name);
+ auto *NewAlloca =
+ Builder.CreateAlloca(Ty->LLVMTy, AddrSpace, ArraySize->Val, Name);
return Ctx.createAllocaInst(NewAlloca);
}
@@ -1961,11 +1991,15 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace,
Name);
}
+// Returns the Context-owned wrapper of the LLVM alloca's allocated type.
+Type *AllocaInst::getAllocatedType() const {
+ return Ctx.getType(cast<llvm::AllocaInst>(Val)->getAllocatedType());
+}
+
void AllocaInst::setAllocatedType(Type *Ty) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AllocaInst::getAllocatedType,
&AllocaInst::setAllocatedType>>(this);
- cast<llvm::AllocaInst>(Val)->setAllocatedType(Ty);
+ // Store the wrapped llvm::Type in the underlying LLVM instruction.
+ cast<llvm::AllocaInst>(Val)->setAllocatedType(Ty->LLVMTy);
}
void AllocaInst::setAlignment(Align Align) {
@@ -1987,6 +2021,10 @@ Value *AllocaInst::getArraySize() {
return Ctx.getValue(cast<llvm::AllocaInst>(Val)->getArraySize());
}
+// cast<> (asserting) relies on PointerType::classof(), which inspects the
+// wrapped llvm::Type.
+PointerType *AllocaInst::getType() const {
+ return cast<PointerType>(Ctx.getType(cast<llvm::AllocaInst>(Val)->getType()));
+}
+
Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
const Twine &Name) {
@@ -1997,7 +2035,7 @@ Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
else
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
auto *NewV =
- Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy, Name);
+ Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy->LLVMTy, Name);
if (auto *NewCI = dyn_cast<llvm::CastInst>(NewV))
return Ctx.createCastInst(NewCI);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
@@ -2022,6 +2060,14 @@ bool CastInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Cast;
}
+// The source/destination type getters wrap the llvm::CastInst types.
+Type *CastInst::getSrcTy() const {
+ return Ctx.getType(cast<llvm::CastInst>(Val)->getSrcTy());
+}
+
+Type *CastInst::getDestTy() const {
+ return Ctx.getType(cast<llvm::CastInst>(Val)->getDestTy());
+}
+
void PossiblyNonNegInst::setNonNeg(bool B) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&PossiblyNonNegInst::hasNonNeg,
@@ -2134,15 +2180,25 @@ void ShuffleVectorInst::setShuffleMask(ArrayRef<int> Mask) {
cast<llvm::ShuffleVectorInst>(Val)->setShuffleMask(Mask);
}
+// Returns the Context-owned VectorType wrapper of the shuffle's result type.
+VectorType *ShuffleVectorInst::getType() const {
+ return cast<VectorType>(
+ Ctx.getType(cast<llvm::ShuffleVectorInst>(Val)->getType()));
+}
+
Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const {
return Ctx.getOrCreateConstant(
cast<llvm::ShuffleVectorInst>(Val)->getShuffleMaskForBitcode());
}
-Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(
- llvm::ArrayRef<int> Mask, llvm::Type *ResultTy, Context &Ctx) {
- return Ctx.getOrCreateConstant(
- llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy));
+// The Context is recovered from ResultTy, so callers no longer pass one in.
+Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef<int> Mask,
+ Type *ResultTy) {
+ return ResultTy->getContext().getOrCreateConstant(
+ llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask,
+ ResultTy->LLVMTy));
+}
+
+// getVectorOperand()->getType() already returns the Context's unique
+// sandboxir::Type wrapper, so no second Context lookup is needed.
+VectorType *ExtractElementInst::getVectorOperandType() const {
+ return cast<VectorType>(getVectorOperand()->getType());
}
#ifndef NDEBUG
@@ -2152,10 +2208,14 @@ void Constant::dumpOS(raw_ostream &OS) const {
}
#endif // NDEBUG
-ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, Context &Ctx,
- bool IsSigned) {
- auto *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
- return cast<ConstantInt>(Ctx.getOrCreateConstant(LLVMC));
+// The constant is looked up or created in the Context that owns Ty.
+ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, bool IsSigned) {
+ auto *LLVMC = llvm::ConstantInt::get(Ty->LLVMTy, V, IsSigned);
+ return cast<ConstantInt>(Ty->getContext().getOrCreateConstant(LLVMC));
+}
+
+// Wraps this function's llvm::FunctionType in its sandboxir::FunctionType.
+FunctionType *Function::getFunctionType() const {
+ return cast<FunctionType>(
+ Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
}
#ifndef NDEBUG
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
new file mode 100644
index 00000000000000..6f850b82d2e996
--- /dev/null
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -0,0 +1,48 @@
+//===- Type.cpp - Sandbox IR Type -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SandboxIR/Type.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+
+using namespace llvm::sandboxir;
+
+// For a vector type this is the element type; otherwise it is the type
+// itself. Either way the Context-owned wrapper is returned.
+Type *Type::getScalarType() const {
+  llvm::Type *LLVMScalarTy = LLVMTy->getScalarType();
+  return Ctx.getType(LLVMScalarTy);
+}
+
+// Each factory asks LLVM for the type in the wrapped LLVMContext and returns
+// the Context's cached sandboxir::Type wrapper for it.
+Type *Type::getInt64Ty(Context &Ctx) {
+  llvm::Type *IntTy = llvm::Type::getInt64Ty(Ctx.LLVMCtx);
+  return Ctx.getType(IntTy);
+}
+
+Type *Type::getInt32Ty(Context &Ctx) {
+  llvm::Type *IntTy = llvm::Type::getInt32Ty(Ctx.LLVMCtx);
+  return Ctx.getType(IntTy);
+}
+
+Type *Type::getInt16Ty(Context &Ctx) {
+  llvm::Type *IntTy = llvm::Type::getInt16Ty(Ctx.LLVMCtx);
+  return Ctx.getType(IntTy);
+}
+
+Type *Type::getInt8Ty(Context &Ctx) {
+  llvm::Type *IntTy = llvm::Type::getInt8Ty(Ctx.LLVMCtx);
+  return Ctx.getType(IntTy);
+}
+
+Type *Type::getInt1Ty(Context &Ctx) {
+  llvm::Type *IntTy = llvm::Type::getInt1Ty(Ctx.LLVMCtx);
+  return Ctx.getType(IntTy);
+}
+
+Type *Type::getDoubleTy(Context &Ctx) {
+  llvm::Type *FPTy = llvm::Type::getDoubleTy(Ctx.LLVMCtx);
+  return Ctx.getType(FPTy);
+}
+
+Type *Type::getFloatTy(Context &Ctx) {
+  llvm::Type *FPTy = llvm::Type::getFloatTy(Ctx.LLVMCtx);
+  return Ctx.getType(FPTy);
+}
+
+// Builds the LLVM pointer type from the element type's wrapped llvm::Type and
+// returns the wrapper owned by the element type's Context.
+PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) {
+  Context &OwnerCtx = ElementType->getContext();
+  llvm::PointerType *LLVMPtrTy =
+      llvm::PointerType::get(ElementType->LLVMTy, AddressSpace);
+  return cast<PointerType>(OwnerCtx.getType(LLVMPtrTy));
+}
+
+// Builds the LLVM pointer type in Ctx's LLVMContext and returns its wrapper.
+PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
+  llvm::PointerType *LLVMPtrTy =
+      llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace);
+  return cast<PointerType>(Ctx.getType(LLVMPtrTy));
+}
diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt
index 3f43f6337b919b..2da936bffa02bf 100644
--- a/llvm/unittests/SandboxIR/CMakeLists.txt
+++ b/llvm/unittests/SandboxIR/CMakeLists.txt
@@ -7,4 +7,5 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(SandboxIRTests
SandboxIRTest.cpp
TrackerTest.cpp
+ TypesTest.cpp
)
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bc3fddf9e163dc..d1217f56ec29e7 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -121,10 +121,12 @@ define void @foo(i32 %v0) {
auto *FortyTwo = cast<sandboxir::ConstantInt>(Add0->getOperand(1));
// Check that creating an identical constant gives us the same object.
- auto *NewCI = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+ auto *NewCI =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
EXPECT_EQ(NewCI, FortyTwo);
// Check new constant.
- auto *FortyThree = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 43, Ctx);
+ auto *FortyThree =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 43);
EXPECT_NE(FortyThree, FortyTwo);
}
@@ -603,7 +605,7 @@ define void @foo(ptr %va) {
EXPECT_EQ(sandboxir::VAArgInst::getPointerOperandIndex(),
llvm::VAArgInst::getPointerOperandIndex());
// Check create().
- auto *NewVATy = Type::getInt8Ty(C);
+ auto *NewVATy = sandboxir::Type::getInt8Ty(Ctx);
auto *NewVA = sandboxir::VAArgInst::create(Arg, NewVATy, Ret->getIterator(),
Ret->getParent(), Ctx, "NewVA");
EXPECT_EQ(NewVA->getNextNode(), Ret);
@@ -743,10 +745,10 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
}
{
// Check SelectInst::create() Folded.
- auto *False = sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 0, Ctx,
- /*IsSigned=*/false);
+ auto *False = sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx),
+ 0, /*IsSigned=*/false);
auto *FortyTwo =
- sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 42, Ctx,
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx), 42,
/*IsSigned=*/false);
auto *NewSel =
sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx);
@@ -838,7 +840,7 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
auto *LLVMArg0 = LLVMF.getArg(0);
auto *LLVMArgVec = LLVMF.getArg(2);
- auto *Zero = sandboxir::ConstantInt::get(Type::getInt8Ty(C), 0, Ctx);
+ auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0);
auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0);
EXPECT_EQ(
sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero),
@@ -950,7 +952,7 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
// convertShuffleMaskForBitcode
{
auto *C = sandboxir::ShuffleVectorInst::convertShuffleMaskForBitcode(
- ArrayRef<int>({2, 3}), ArgV1->getType(), Ctx);
+ ArrayRef<int>({2, 3}), ArgV1->getType());
SmallVector<int, 2> Result;
sandboxir::ShuffleVectorInst::getShuffleMask(C, Result);
EXPECT_THAT(Result, testing::ElementsAre(2, 3));
@@ -1607,7 +1609,8 @@ define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) {
// Check classof(Value *).
EXPECT_TRUE(isa<sandboxir::CallBase>((sandboxir::Value *)Call));
// Check getFunctionType().
- EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType());
+ EXPECT_EQ(Call->getFunctionType(),
+ Ctx.getType(LLVMCall->getFunctionType()));
// Check data_ops().
EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops()));
auto DataOpIt = Call->data_operands_begin();
@@ -1738,7 +1741,7 @@ define i8 @foo(i8 %arg) {
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
EXPECT_EQ(Call->getNumOperands(), 2u);
EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
- FunctionType *FTy = F.getFunctionType();
+ sandboxir::FunctionType *FTy = F.getFunctionType();
SmallVector<sandboxir::Value *, 1> Args;
Args.push_back(Arg0);
{
@@ -2027,8 +2030,8 @@ define void @foo() {
auto *BBRet = &*BB->begin();
auto *NewLPad =
cast<sandboxir::LandingPadInst>(sandboxir::LandingPadInst::create(
- Type::getInt8Ty(C), 0, BBRet->getIterator(), BBRet->getParent(), Ctx,
- "NewLPad"));
+ sandboxir::Type::getInt8Ty(Ctx), 0, BBRet->getIterator(),
+ BBRet->getParent(), Ctx, "NewLPad"));
EXPECT_EQ(NewLPad->getNextNode(), BBRet);
EXPECT_FALSE(NewLPad->isCleanup());
#ifndef NDEBUG
@@ -2287,9 +2290,11 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
// Check classof().
auto *GEP = cast<sandboxir::GetElementPtrInst>(Ctx.getValue(LLVMGEP));
// Check getSourceElementType().
- EXPECT_EQ(GEP->getSourceElementType(), LLVMGEP->getSourceElementType());
+ EXPECT_EQ(GEP->getSourceElementType(),
+ Ctx.getType(LLVMGEP->getSourceElementType()));
// Check getResultElementType().
- EXPECT_EQ(GEP->getResultElementType(), LLVMGEP->getResultElementType());
+ EXPECT_EQ(GEP->getResultElementType(),
+ Ctx.getType(LLVMGEP->getResultElementType()));
// Check getAddressSpace().
EXPECT_EQ(GEP->getAddressSpace(), LLVMGEP->getAddressSpace());
// Check indices().
@@ -2305,7 +2310,8 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
// Check getPointerOperandIndex().
EXPECT_EQ(GEP->getPointerOperandIndex(), LLVMGEP->getPointerOperandIndex());
// Check getPointerOperandType().
- EXPECT_EQ(GEP->getPointerOperandType(), LLVMGEP->getPointerOperandType());
+ EXPECT_EQ(GEP->getPointerOperandType(),
+ Ctx.getType(LLVMGEP->getPointerOperandType()));
// Check getPointerAddressSpace().
EXPECT_EQ(GEP->getPointerAddressSpace(), LLVMGEP->getPointerAddressSpace());
// Check getNumIndices().
@@ -2666,8 +2672,8 @@ define void @foo(i32 %cond0, i32 %cond1) {
Switch->setSuccessor(0, OrigSucc);
EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
// Check case_begin(), case_end(), CaseIt.
- auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
- auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+ auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
+ auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1);
auto CaseIt = Switch->case_begin();
{
sandboxir::SwitchInst::CaseHandle Case = *CaseIt++;
@@ -2704,7 +2710,7 @@ define void @foo(i32 %cond0, i32 %cond1) {
EXPECT_EQ(Switch->findCaseDest(BB1), One);
EXPECT_EQ(Switch->findCaseDest(Entry), nullptr);
// Check addCase().
- auto *Two = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 2, Ctx);
+ auto *Two = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 2);
Switch->addCase(Two, Entry);
auto CaseTwoIt = Switch->findCaseValue(Two);
auto CaseTwo = *CaseTwoIt;
@@ -2969,7 +2975,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
}
{
// Check create() when it gets folded.
- auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+ auto *FortyTwo =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
auto *NewV = sandboxir::BinaryOperator::create(
sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo,
/*InsertBefore=*/Ret, Ctx, "Folded");
@@ -3025,7 +3032,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
}
{
// Check createWithCopiedFlags() when it gets folded.
- auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+ auto *FortyTwo =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
auto *NewV = sandboxir::BinaryOperator::createWithCopiedFlags(
sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo, CopyFrom,
/*InsertBefore=*/Ret, Ctx, "Folded");
@@ -3447,8 +3455,8 @@ define void @foo() {
EXPECT_EQ(AllocaArray->getArraySize(),
Ctx.getValue(LLVMAllocaArray->getArraySize()));
// Check getType().
- EXPECT_EQ(AllocaScalar->getType(), LLVMAllocaScalar->getType());
- EXPECT_EQ(AllocaArray->getType(), LLVMAllocaArray->getType());
+ EXPECT_EQ(AllocaScalar->getType(), Ctx.getType(LLVMAllocaScalar->getType()));
+ EXPECT_EQ(AllocaArray->getType(), Ctx.getType(LLVMAllocaArray->getType()));
// Check getAddressSpace().
EXPECT_EQ(AllocaScalar->getAddressSpace(),
LLVMAllocaScalar->getAddressSpace());
@@ -3465,12 +3473,12 @@ define void @foo() {
LLVMAllocaArray->getAllocationSizeInBits(DL));
// Check getAllocatedType().
EXPECT_EQ(AllocaScalar->getAllocatedType(),
- LLVMAllocaScalar->getAllocatedType());
+ Ctx.getType(LLVMAllocaScalar->getAllocatedType()));
EXPECT_EQ(AllocaArray->getAllocatedType(),
- LLVMAllocaArray->getAllocatedType());
+ Ctx.getType(LLVMAllocaArray->getAllocatedType()));
// Check setAllocatedType().
auto *OrigType = AllocaScalar->getAllocatedType();
- auto *NewType = PointerType::get(C, 0);
+ auto *NewType = sandboxir::PointerType::get(Ctx, 0);
EXPECT_NE(NewType, OrigType);
AllocaScalar->setAllocatedType(NewType);
EXPECT_EQ(AllocaScalar->getAllocatedType(), NewType);
@@ -3501,10 +3509,10 @@ define void @foo() {
AllocaScalar->setUsedWithInAlloca(OrigUsedWithInAlloca);
EXPECT_EQ(AllocaScalar->isUsedWithInAlloca(), OrigUsedWithInAlloca);
- auto *Ty = Type::getInt32Ty(C);
+ auto *Ty = sandboxir::Type::getInt32Ty(Ctx);
unsigned AddrSpace = 42;
- auto *PtrTy = PointerType::get(C, AddrSpace);
- auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43, Ctx);
+ auto *PtrTy = sandboxir::PointerType::get(Ctx, AddrSpace);
+ auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43);
{
// Check create() WhereIt, WhereBB.
auto *NewI = cast<sandboxir::AllocaInst>(sandboxir::AllocaInst::create(
@@ -3581,13 +3589,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) {
auto *BB = &*F->begin();
auto It = BB->begin();
- Type *Ti64 = Type::getInt64Ty(C);
- Type *Ti32 = Type::getInt32Ty(C);
- Type *Ti16 = Type::getInt16Ty(C);
- Type *Tdouble = Type::getDoubleTy(C);
- Type *Tfloat = Type::getFloatTy(C);
- Type *Tptr = Tfloat->getPointerTo();
- Type *Tptr1 = Tfloat->getPointerTo(1);
+ auto *Ti64 = sandboxir::Type::getInt64Ty(Ctx);
+ auto *Ti32 = sandboxir::Type::getInt32Ty(Ctx);
+ auto *Ti16 = sandboxir::Type::getInt16Ty(Ctx);
+ auto *Tdouble = sandboxir::Type::getDoubleTy(Ctx);
+ auto *Tfloat = sandboxir::Type::getFloatTy(Ctx);
+ auto *Tptr = sandboxir::PointerType::get(Tfloat, 0);
+ auto *Tptr1 = sandboxir::PointerType::get(Tfloat, 1);
// Check classof(), getOpcode(), getSrcTy(), getDstTy()
auto *ZExt = cast<sandboxir::CastInst>(&*It++);
@@ -3799,10 +3807,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) {
/// CastInst's subclasses are very similar so we can use a common test function
/// for them.
template <typename SubclassT, sandboxir::Instruction::Opcode OpcodeT>
-void testCastInst(llvm::Module &M, Type *SrcTy, Type *DstTy) {
+void testCastInst(llvm::Module &M, llvm::Type *LLVMSrcTy,
+ llvm::Type *LLVMDstTy) {
Function &LLVMF = *M.getFunction("foo");
sandboxir::Context Ctx(M.getContext());
sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+ sandboxir::Type *SrcTy = Ctx.getType(LLVMSrcTy);
+ sandboxir::Type *DstTy = Ctx.getType(LLVMDstTy);
unsigned ArgIdx = 0;
auto *Arg = F->getArg(ArgIdx++);
auto *BB = &*F->begin();
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index ca6effb727bf37..c189100fbd6947 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -938,9 +938,10 @@ define void @foo(i32 %cond0, i32 %cond1) {
Ctx.revert();
EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
// Check addCase().
- auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
- auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
- auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+ auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
+ auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1);
+ auto *FortyTwo =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
Ctx.save();
Switch->addCase(FortyTwo, Entry);
EXPECT_EQ(Switch->getNumCases(), 3u);
@@ -1187,7 +1188,7 @@ define void @foo(i8 %arg) {
// Check setAllocatedType().
Ctx.save();
auto *OrigTy = Alloca->getAllocatedType();
- auto *NewTy = Type::getInt64Ty(C);
+ auto *NewTy = sandboxir::Type::getInt64Ty(Ctx);
EXPECT_NE(NewTy, OrigTy);
Alloca->setAllocatedType(NewTy);
EXPECT_EQ(Alloca->getAllocatedType(), NewTy);
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
new file mode 100644
index 00000000000000..d72725883a7dd7
--- /dev/null
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -0,0 +1,256 @@
+//===- TypesTest.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Fixture owning the LLVMContext and the Module parsed from each test's IR.
+struct SandboxTypeTest : public testing::Test {
+ LLVMContext C;
+ std::unique_ptr<Module> M;
+
+ /// Parses \p IR into M. On failure M is null and the diagnostic is printed
+ /// to errs().
+ void parseIR(LLVMContext &C, const char *IR) {
+ SMDiagnostic Err;
+ M = parseAssemblyString(IR, Err, C);
+ if (!M)
+ Err.print("SandboxTypeTest", errs());
+ }
+ /// Returns the block of \p F named \p Name; hits llvm_unreachable if absent.
+ BasicBlock *getBasicBlockByName(Function &F, StringRef Name) {
+ for (BasicBlock &BB : F)
+ if (BB.getName() == Name)
+ return &BB;
+ llvm_unreachable("Expected to find basic block!");
+ }
+};
+
+TEST_F(SandboxTypeTest, Type) {
+ parseIR(C, R"IR(
+define void @foo(i32 %v0) {
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ sandboxir::Type *I32Ty = F->getArg(0)->getType();
+
+ auto *LLVMInt32Ty = llvm::Type::getInt32Ty(C);
+ auto *LLVMFloatTy = llvm::Type::getFloatTy(C);
+ auto *LLVMInt8Ty = llvm::Type::getInt8Ty(C);
+
+ auto *Int32Ty = Ctx.getType(LLVMInt32Ty);
+ auto *FloatTy = Ctx.getType(LLVMFloatTy);
+
+ // Check that Ctx returns one wrapper object per llvm::Type.
+ EXPECT_EQ(I32Ty, Int32Ty);
+ // Check getContext().
+ EXPECT_EQ(&I32Ty->getContext(), &Ctx);
+
+// Wraps llvm::Type::LLVMCreate(C) via Ctx and calls the sandboxir predicate
+// SBCheck() on the result.
+#define CHK(LLVMCreate, SBCheck) \
+ Ctx.getType(llvm::Type::LLVMCreate(C))->SBCheck()
+ // Check isVoidTy().
+ EXPECT_TRUE(CHK(getVoidTy, isVoidTy));
+ EXPECT_FALSE(CHK(getInt32Ty, isVoidTy));
+ // Check isHalfTy().
+ EXPECT_TRUE(CHK(getHalfTy, isHalfTy));
+ // Check isBFloatTy().
+ EXPECT_TRUE(CHK(getBFloatTy, isBFloatTy));
+ // Check is16bitFPTy().
+ EXPECT_TRUE(CHK(getHalfTy, is16bitFPTy));
+ // Check isFloatTy().
+ EXPECT_TRUE(CHK(getFloatTy, isFloatTy));
+ // Check isDoubleTy().
+ EXPECT_TRUE(CHK(getDoubleTy, isDoubleTy));
+ // Check isX86_FP80Ty().
+ EXPECT_TRUE(CHK(getX86_FP80Ty, isX86_FP80Ty));
+ // Check isFP128Ty().
+ EXPECT_TRUE(CHK(getFP128Ty, isFP128Ty));
+ // Check isPPC_FP128Ty().
+ EXPECT_TRUE(CHK(getPPC_FP128Ty, isPPC_FP128Ty));
+ // Check isIEEELikeFPTy().
+ EXPECT_TRUE(CHK(getFloatTy, isIEEELikeFPTy));
+ // Check isFloatingPointTy().
+ EXPECT_TRUE(CHK(getFloatTy, isFloatingPointTy));
+ EXPECT_TRUE(CHK(getDoubleTy, isFloatingPointTy));
+ // Check isMultiUnitFPType().
+ EXPECT_TRUE(CHK(getPPC_FP128Ty, isMultiUnitFPType));
+ EXPECT_FALSE(CHK(getFloatTy, isMultiUnitFPType));
+ // Check getFltSemantics().
+ EXPECT_EQ(&sandboxir::Type::getFloatTy(Ctx)->getFltSemantics(),
+ &llvm::Type::getFloatTy(C)->getFltSemantics());
+ // Check isX86_AMXTy().
+ EXPECT_TRUE(CHK(getX86_AMXTy, isX86_AMXTy));
+ // Check isTargetExtTy().
+ EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "foo"))->isTargetExtTy());
+ // Check isScalableTargetExtTy().
+ EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "aarch64.svcount"))
+ ->isScalableTargetExtTy());
+ // Check isScalableTy().
+ EXPECT_TRUE(Ctx.getType(llvm::ScalableVectorType::get(LLVMInt32Ty, 2u))
+ ->isScalableTy());
+ // Check isFPOrFPVectorTy().
+ EXPECT_TRUE(CHK(getFloatTy, isFPOrFPVectorTy));
+ EXPECT_FALSE(CHK(getInt32Ty, isFPOrFPVectorTy));
+ // Check isLabelTy().
+ EXPECT_TRUE(CHK(getLabelTy, isLabelTy));
+ // Check isMetadataTy().
+ EXPECT_TRUE(CHK(getMetadataTy, isMetadataTy));
+ // Check isTokenTy().
+ EXPECT_TRUE(CHK(getTokenTy, isTokenTy));
+ // Check isIntegerTy().
+ EXPECT_TRUE(CHK(getInt32Ty, isIntegerTy));
+ EXPECT_FALSE(CHK(getFloatTy, isIntegerTy));
+#undef CHK
+ // Check isIntegerTy(Bitwidth) on the sandboxir wrappers.
+ EXPECT_TRUE(Int32Ty->isIntegerTy(32u));
+ EXPECT_FALSE(Int32Ty->isIntegerTy(31u));
+ EXPECT_FALSE(FloatTy->isIntegerTy(32u));
+ // Check isIntOrIntVectorTy().
+ EXPECT_TRUE(Int32Ty->isIntOrIntVectorTy());
+ EXPECT_TRUE(Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8))
+ ->isIntOrIntVectorTy());
+ EXPECT_FALSE(FloatTy->isIntOrIntVectorTy());
+ EXPECT_FALSE(Ctx.getType(llvm::FixedVectorType::get(LLVMFloatTy, 8))
+ ->isIntOrIntVectorTy());
+ // Check isIntOrPtrTy().
+ EXPECT_TRUE(Int32Ty->isIntOrPtrTy());
+ EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isIntOrPtrTy());
+ EXPECT_FALSE(FloatTy->isIntOrPtrTy());
+ // Check isFunctionTy().
+ EXPECT_TRUE(Ctx.getType(llvm::FunctionType::get(LLVMInt32Ty, {}, false))
+ ->isFunctionTy());
+ // Check isStructTy().
+ EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isStructTy());
+ // Check isArrayTy().
+ EXPECT_TRUE(Ctx.getType(llvm::ArrayType::get(LLVMInt32Ty, 10))->isArrayTy());
+ // Check isPointerTy().
+ EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isPointerTy());
+ // Check isPtrOrPtrVectorTy().
+ EXPECT_TRUE(
+ Ctx.getType(llvm::FixedVectorType::get(llvm::PointerType::get(C, 0u), 8u))
+ ->isPtrOrPtrVectorTy());
+ // Check isVectorTy().
+ EXPECT_TRUE(
+ Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->isVectorTy());
+ // Check canLosslesslyBitCastTo().
+ auto *VecTy32x4 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 4u));
+ auto *VecTy32x2 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 2u));
+ auto *VecTy8x16 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt8Ty, 16u));
+ EXPECT_TRUE(VecTy32x4->canLosslesslyBitCastTo(VecTy8x16));
+ EXPECT_FALSE(VecTy32x4->canLosslesslyBitCastTo(VecTy32x2));
+ // Check isEmptyTy().
+ EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isEmptyTy());
+ // Check isFirstClassType().
+ EXPECT_TRUE(Int32Ty->isFirstClassType());
+ // Check isSingleValueType().
+ EXPECT_TRUE(Int32Ty->isSingleValueType());
+ // Check isAggregateType().
+ EXPECT_FALSE(Int32Ty->isAggregateType());
+ // Check isSized().
+ SmallPtrSet<sandboxir::Type *, 1> Visited;
+ EXPECT_TRUE(Int32Ty->isSized(&Visited));
+ // Check getPrimitiveSizeInBits().
+ EXPECT_EQ(VecTy32x2->getPrimitiveSizeInBits(), 32u * 2);
+ // Check getScalarSizeInBits().
+ EXPECT_EQ(VecTy32x2->getScalarSizeInBits(), 32u);
+ // Check getFPMantissaWidth().
+ EXPECT_EQ(FloatTy->getFPMantissaWidth(), LLVMFloatTy->getFPMantissaWidth());
+ // Check isIEEE().
+ EXPECT_EQ(FloatTy->isIEEE(), LLVMFloatTy->isIEEE());
+ // Check getScalarType().
+ EXPECT_EQ(
+ Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->getScalarType(),
+ Int32Ty);
+
+#define CHK_GET(TY) \
+ EXPECT_EQ(Ctx.getType(llvm::Type::get##TY##Ty(C)), \
+ sandboxir::Type::get##TY##Ty(Ctx))
+ // Check getInt64Ty().
+ CHK_GET(Int64);
+ // Check getInt32Ty().
+ CHK_GET(Int32);
+ // Check getInt16Ty().
+ CHK_GET(Int16);
+ // Check getInt8Ty().
+ CHK_GET(Int8);
+ // Check getInt1Ty().
+ CHK_GET(Int1);
+ // Check getDoubleTy().
+ CHK_GET(Double);
+ // Check getFloatTy().
+ CHK_GET(Float);
+#undef CHK_GET
+}
+
+TEST_F(SandboxTypeTest, PointerType) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ // Check classof(), creation.
+ EXPECT_TRUE(isa<sandboxir::PointerType>(F->getArg(0)->getType()));
+ EXPECT_FALSE(isa<sandboxir::PointerType>(sandboxir::Type::getInt32Ty(Ctx)));
+ auto *PtrTy = cast<sandboxir::PointerType>(F->getArg(0)->getType());
+ // Check get(ElementType, AddressSpace).
+ auto *NewPtrTy =
+ sandboxir::PointerType::get(sandboxir::Type::getInt32Ty(Ctx), 0u);
+ EXPECT_EQ(NewPtrTy, PtrTy);
+ // Check get(Ctx, AddressSpace).
+ auto *NewPtrTy2 = sandboxir::PointerType::get(Ctx, 0u);
+ EXPECT_EQ(NewPtrTy2, PtrTy);
+}
+
+TEST_F(SandboxTypeTest, VectorType) {
+ parseIR(C, R"IR(
+define void @foo(<2 x i8> %v0) {
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ // Check classof(), creation: the <2 x i8> argument is a VectorType.
+ EXPECT_TRUE(isa<sandboxir::VectorType>(F->getArg(0)->getType()));
+ EXPECT_FALSE(isa<sandboxir::VectorType>(sandboxir::Type::getInt8Ty(Ctx)));
+}
+
+TEST_F(SandboxTypeTest, FunctionType) {
+ parseIR(C, R"IR(
+define void @foo() {
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ // Check classof(), creation. Go through sandboxir::Type * so that classof()
+ // is consulted: isa<> from a FunctionType * is decided statically.
+ sandboxir::Type *FTy = F->getFunctionType();
+ EXPECT_TRUE(isa<sandboxir::FunctionType>(FTy));
+ EXPECT_FALSE(isa<sandboxir::FunctionType>(sandboxir::Type::getInt32Ty(Ctx)));
+}
More information about the llvm-commits
mailing list