[llvm] [SandboxIR] Implement SandboxIR Type (PR #106294)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 27 14:26:09 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/106294

This patch implements sandboxir::Type, a thin wrapper around llvm::Type. This is designed very similarly to sandboxir::Value. Context owns all sandboxir::Type objects and maintains a map between llvm::Type and sandboxir::Type.

>From f0d6392f70f6d81cf80253678032890df06898b9 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 19 Aug 2024 11:26:19 -0700
Subject: [PATCH] [SandboxIR] Implement SandboxIR Type

This patch implements sandboxir::Type, a thin wrapper around llvm::Type.
This is designed very similarly to sandboxir::Value.
Context owns all sandboxir::Type objects and maintains a map between
llvm::Type and sandboxir::Type.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h    |  71 +++---
 llvm/include/llvm/SandboxIR/Type.h         | 282 +++++++++++++++++++++
 llvm/lib/SandboxIR/CMakeLists.txt          |   1 +
 llvm/lib/SandboxIR/SandboxIR.cpp           | 110 ++++++--
 llvm/lib/SandboxIR/Type.cpp                |  48 ++++
 llvm/unittests/SandboxIR/CMakeLists.txt    |   1 +
 llvm/unittests/SandboxIR/SandboxIRTest.cpp |  83 +++---
 llvm/unittests/SandboxIR/TrackerTest.cpp   |   9 +-
 llvm/unittests/SandboxIR/TypesTest.cpp     | 242 ++++++++++++++++++
 9 files changed, 747 insertions(+), 100 deletions(-)
 create mode 100644 llvm/include/llvm/SandboxIR/Type.h
 create mode 100644 llvm/lib/SandboxIR/Type.cpp
 create mode 100644 llvm/unittests/SandboxIR/TypesTest.cpp

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index b7bdf9acd2ef45..f9581c0dc299e3 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -98,6 +98,7 @@
 #include "llvm/IR/User.h"
 #include "llvm/IR/Value.h"
 #include "llvm/SandboxIR/Tracker.h"
+#include "llvm/SandboxIR/Type.h"
 #include "llvm/SandboxIR/Use.h"
 #include "llvm/Support/raw_ostream.h"
 #include <iterator>
@@ -378,7 +379,7 @@ class Value {
     return Cnt == Num;
   }
 
-  Type *getType() const { return Val->getType(); }
+  Type *getType() const;
 
   Context &getContext() const { return Ctx; }
 
@@ -566,8 +567,7 @@ class ConstantInt : public Constant {
 public:
   /// If Ty is a vector type, return a Constant with a splat of the given
   /// value. Otherwise return a ConstantInt for the given value.
-  static ConstantInt *get(Type *Ty, uint64_t V, Context &Ctx,
-                          bool IsSigned = false);
+  static ConstantInt *get(Type *Ty, uint64_t V, bool IsSigned = false);
 
   // TODO: Implement missing functions.
 
@@ -1014,10 +1014,7 @@ class ExtractElementInst final
   Value *getIndexOperand() { return getOperand(1); }
   const Value *getVectorOperand() const { return getOperand(0); }
   const Value *getIndexOperand() const { return getOperand(1); }
-
-  VectorType *getVectorOperandType() const {
-    return cast<VectorType>(getVectorOperand()->getType());
-  }
+  VectorType *getVectorOperandType() const;
 };
 
 class ShuffleVectorInst final
@@ -1062,9 +1059,7 @@ class ShuffleVectorInst final
   }
 
   /// Overload to return most specific vector type.
-  VectorType *getType() const {
-    return cast<llvm::ShuffleVectorInst>(Val)->getType();
-  }
+  VectorType *getType() const;
 
   /// Return the shuffle mask value of this instruction for the given element
   /// index. Return PoisonMaskElem if the element is undef.
@@ -1090,7 +1085,7 @@ class ShuffleVectorInst final
   Constant *getShuffleMaskForBitcode() const;
 
   static Constant *convertShuffleMaskForBitcode(ArrayRef<int> Mask,
-                                                Type *ResultTy, Context &Ctx);
+                                                Type *ResultTy);
 
   void setShuffleMask(ArrayRef<int> Mask);
 
@@ -1713,9 +1708,7 @@ class CallBase : public SingleLLVMInstructionImpl<llvm::CallBase> {
            Opc == Instruction::ClassID::CallBr;
   }
 
-  FunctionType *getFunctionType() const {
-    return cast<llvm::CallBase>(Val)->getFunctionType();
-  }
+  FunctionType *getFunctionType() const;
 
   op_iterator data_operands_begin() { return op_begin(); }
   const_op_iterator data_operands_begin() const {
@@ -2131,12 +2124,8 @@ class GetElementPtrInst final
     return From->getSubclassID() == ClassID::GetElementPtr;
   }
 
-  Type *getSourceElementType() const {
-    return cast<llvm::GetElementPtrInst>(Val)->getSourceElementType();
-  }
-  Type *getResultElementType() const {
-    return cast<llvm::GetElementPtrInst>(Val)->getResultElementType();
-  }
+  Type *getSourceElementType() const;
+  Type *getResultElementType() const;
   unsigned getAddressSpace() const {
     return cast<llvm::GetElementPtrInst>(Val)->getAddressSpace();
   }
@@ -2160,9 +2149,7 @@ class GetElementPtrInst final
   static unsigned getPointerOperandIndex() {
     return llvm::GetElementPtrInst::getPointerOperandIndex();
   }
-  Type *getPointerOperandType() const {
-    return cast<llvm::GetElementPtrInst>(Val)->getPointerOperandType();
-  }
+  Type *getPointerOperandType() const;
   unsigned getPointerAddressSpace() const {
     return cast<llvm::GetElementPtrInst>(Val)->getPointerAddressSpace();
   }
@@ -2709,9 +2696,7 @@ class AllocaInst final : public UnaryInstruction {
     return const_cast<AllocaInst *>(this)->getArraySize();
   }
   /// Overload to return most specific pointer type.
-  PointerType *getType() const {
-    return cast<llvm::AllocaInst>(Val)->getType();
-  }
+  PointerType *getType() const;
   /// Return the address space for the allocation.
   unsigned getAddressSpace() const {
     return cast<llvm::AllocaInst>(Val)->getAddressSpace();
@@ -2727,9 +2712,7 @@ class AllocaInst final : public UnaryInstruction {
     return cast<llvm::AllocaInst>(Val)->getAllocationSizeInBits(DL);
   }
   /// Return the type that is being allocated by the instruction.
-  Type *getAllocatedType() const {
-    return cast<llvm::AllocaInst>(Val)->getAllocatedType();
-  }
+  Type *getAllocatedType() const;
   /// for use only in special circumstances that need to generically
   /// transform a whole instruction (eg: IR linking and vectorization).
   void setAllocatedType(Type *Ty);
@@ -2811,8 +2794,8 @@ class CastInst : public UnaryInstruction {
                        const Twine &Name = "");
   /// For isa/dyn_cast.
   static bool classof(const Value *From);
-  Type *getSrcTy() const { return cast<llvm::CastInst>(Val)->getSrcTy(); }
-  Type *getDestTy() const { return cast<llvm::CastInst>(Val)->getDestTy(); }
+  Type *getSrcTy() const;
+  Type *getDestTy() const;
 };
 
 /// Instruction that can have a nneg flag (zext/uitofp).
@@ -2992,6 +2975,8 @@ class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
 class Context {
 protected:
   LLVMContext &LLVMCtx;
+  friend class Type;        // For LLVMCtx.
+  friend class PointerType; // For LLVMCtx.
   Tracker IRTracker;
 
   /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
@@ -2999,6 +2984,16 @@ class Context {
   DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
       LLVMValueToValueMap;
 
+  /// Type has a protected destructor to prohibit the user from managing the
+  /// lifetime of the Type objects. Context is a friend of Type, so this
+  /// custom deleter, being a member of Context, is allowed to destroy them.
+  struct TypeDeleter {
+    void operator()(Type *Ty) { delete Ty; }
+  };
+  /// Maps LLVM Type to the corresponding sandboxir::Type. Owns all Sandbox IR
+  /// Type objects; entries are created lazily by getType().
+  DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
+
   /// Remove \p V from the maps and returns the unique_ptr.
   std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
   /// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -3033,7 +3028,6 @@ class Context {
   /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
   /// also create all contents of the block.
   BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
-
   friend class BasicBlock; // For getOrCreateValue().
 
   IRBuilder<ConstantFolder> LLVMIRBuilder;
@@ -3119,6 +3113,15 @@ class Context {
   const sandboxir::Value *getValue(const llvm::Value *V) const {
     return getValue(const_cast<llvm::Value *>(V));
   }
+
+  /// Returns the sandboxir::Type that wraps \p LLVMTy, creating it on first
+  /// use. The returned object is owned by this Context.
+  Type *getType(llvm::Type *LLVMTy) {
+    auto [It, Inserted] = LLVMTypeToTypeMap.try_emplace(LLVMTy, nullptr);
+    if (Inserted)
+      It->second.reset(new Type(LLVMTy, *this));
+    return It->second.get();
+  }
+
   /// Create a sandboxir::Function for an existing LLVM IR \p F, including all
   /// blocks and instructions.
   /// This is the main API function for creating Sandbox IR.
@@ -3165,9 +3168,7 @@ class Function : public Constant {
     LLVMBBToBB BBGetter(Ctx);
     return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
   }
-  FunctionType *getFunctionType() const {
-    return cast<llvm::Function>(Val)->getFunctionType();
-  }
+  FunctionType *getFunctionType() const;
 
 #ifndef NDEBUG
   void verify() const final {
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
new file mode 100644
index 00000000000000..c002f3b400eb73
--- /dev/null
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -0,0 +1,282 @@
+//===- llvm/SandboxIR/Type.h - Classes for handling data types --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a thin wrapper over llvm::Type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SANDBOXIR_TYPE_H
+#define LLVM_SANDBOXIR_TYPE_H
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm::sandboxir {
+
+class Context;
+
+/// A thin wrapper over llvm::Type.
+///
+/// Just like llvm::Type these are immutable and unique, and they can only be
+/// obtained through factory methods. Every sandboxir::Type is owned by the
+/// sandboxir::Context, which keeps exactly one wrapper per llvm::Type, so
+/// users can neither create, copy nor destroy them directly.
+class Type {
+protected:
+  llvm::Type *LLVMTy;
+  friend class VectorType;   // For LLVMTy.
+  friend class PointerType;  // For LLVMTy.
+  friend class FunctionType; // For LLVMTy.
+  friend class Function;     // For LLVMTy.
+  friend class CallBase;     // For LLVMTy.
+  friend class ConstantInt;  // For LLVMTy.
+  // Friend all instruction classes because `create()` functions use LLVMTy.
+#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
+  // TODO: Friend DEF_CONST()
+#include "llvm/SandboxIR/SandboxIRValues.def"
+  Context &Ctx;
+
+  Type(llvm::Type *LLVMTy, Context &Ctx) : LLVMTy(LLVMTy), Ctx(Ctx) {}
+  friend class Context; // For constructor and ~Type().
+  ~Type() = default;
+
+public:
+  /// Types are uniqued per Context; a copy would be a second wrapper for the
+  /// same llvm::Type that the Context neither tracks nor owns.
+  Type(const Type &) = delete;
+  Type &operator=(const Type &) = delete;
+
+  Context &getContext() const { return Ctx; }
+
+  /// Return true if this is 'void'.
+  bool isVoidTy() const { return LLVMTy->isVoidTy(); }
+
+  /// Return true if this is 'half', a 16-bit IEEE fp type.
+  bool isHalfTy() const { return LLVMTy->isHalfTy(); }
+
+  /// Return true if this is 'bfloat', a 16-bit bfloat type.
+  bool isBFloatTy() const { return LLVMTy->isBFloatTy(); }
+
+  /// Return true if this is a 16-bit float type.
+  bool is16bitFPTy() const { return LLVMTy->is16bitFPTy(); }
+
+  /// Return true if this is 'float', a 32-bit IEEE fp type.
+  bool isFloatTy() const { return LLVMTy->isFloatTy(); }
+
+  /// Return true if this is 'double', a 64-bit IEEE fp type.
+  bool isDoubleTy() const { return LLVMTy->isDoubleTy(); }
+
+  /// Return true if this is x86 long double.
+  bool isX86_FP80Ty() const { return LLVMTy->isX86_FP80Ty(); }
+
+  /// Return true if this is 'fp128'.
+  bool isFP128Ty() const { return LLVMTy->isFP128Ty(); }
+
+  /// Return true if this is powerpc long double.
+  bool isPPC_FP128Ty() const { return LLVMTy->isPPC_FP128Ty(); }
+
+  /// Return true if this is a well-behaved IEEE-like type, which has a IEEE
+  /// compatible layout as defined by APFloat::isIEEE(), and does not have
+  /// non-IEEE values, such as x86_fp80's unnormal values.
+  bool isIEEELikeFPTy() const { return LLVMTy->isIEEELikeFPTy(); }
+
+  /// Return true if this is one of the floating-point types.
+  bool isFloatingPointTy() const { return LLVMTy->isFloatingPointTy(); }
+
+  /// Returns true if this is a floating-point type that is an unevaluated sum
+  /// of multiple floating-point units.
+  /// An example of such a type is ppc_fp128, also known as double-double, which
+  /// consists of two IEEE 754 doubles.
+  bool isMultiUnitFPType() const { return LLVMTy->isMultiUnitFPType(); }
+
+  const fltSemantics &getFltSemantics() const {
+    return LLVMTy->getFltSemantics();
+  }
+
+  /// Return true if this is X86 AMX.
+  bool isX86_AMXTy() const { return LLVMTy->isX86_AMXTy(); }
+
+  /// Return true if this is a target extension type.
+  bool isTargetExtTy() const { return LLVMTy->isTargetExtTy(); }
+
+  /// Return true if this is a target extension type with a scalable layout.
+  bool isScalableTargetExtTy() const { return LLVMTy->isScalableTargetExtTy(); }
+
+  /// Return true if this is a type whose size is a known multiple of vscale.
+  bool isScalableTy() const { return LLVMTy->isScalableTy(); }
+
+  /// Return true if this is a FP type or a vector of FP.
+  bool isFPOrFPVectorTy() const { return LLVMTy->isFPOrFPVectorTy(); }
+
+  /// Return true if this is 'label'.
+  bool isLabelTy() const { return LLVMTy->isLabelTy(); }
+
+  /// Return true if this is 'metadata'.
+  bool isMetadataTy() const { return LLVMTy->isMetadataTy(); }
+
+  /// Return true if this is 'token'.
+  bool isTokenTy() const { return LLVMTy->isTokenTy(); }
+
+  /// True if this is an instance of IntegerType.
+  bool isIntegerTy() const { return LLVMTy->isIntegerTy(); }
+
+  /// Return true if this is an IntegerType of the given width.
+  bool isIntegerTy(unsigned Bitwidth) const {
+    return LLVMTy->isIntegerTy(Bitwidth);
+  }
+
+  /// Return true if this is an integer type or a vector of integer types.
+  bool isIntOrIntVectorTy() const { return LLVMTy->isIntOrIntVectorTy(); }
+
+  /// Return true if this is an integer type or a vector of integer types of
+  /// the given width.
+  bool isIntOrIntVectorTy(unsigned BitWidth) const {
+    return LLVMTy->isIntOrIntVectorTy(BitWidth);
+  }
+
+  /// Return true if this is an integer type or a pointer type.
+  bool isIntOrPtrTy() const { return LLVMTy->isIntOrPtrTy(); }
+
+  /// True if this is an instance of FunctionType.
+  bool isFunctionTy() const { return LLVMTy->isFunctionTy(); }
+
+  /// True if this is an instance of StructType.
+  bool isStructTy() const { return LLVMTy->isStructTy(); }
+
+  /// True if this is an instance of ArrayType.
+  bool isArrayTy() const { return LLVMTy->isArrayTy(); }
+
+  /// True if this is an instance of PointerType.
+  bool isPointerTy() const { return LLVMTy->isPointerTy(); }
+
+  /// Return true if this is a pointer type or a vector of pointer types.
+  bool isPtrOrPtrVectorTy() const { return LLVMTy->isPtrOrPtrVectorTy(); }
+
+  /// True if this is an instance of VectorType.
+  bool isVectorTy() const { return LLVMTy->isVectorTy(); }
+
+  /// Return true if this type could be converted with a lossless BitCast to
+  /// type 'Ty'. For example, i8* to i32*. BitCasts are valid for types of the
+  /// same size only where no re-interpretation of the bits is done.
+  bool canLosslesslyBitCastTo(Type *Ty) const {
+    return LLVMTy->canLosslesslyBitCastTo(Ty->LLVMTy);
+  }
+
+  /// Return true if this type is empty, that is, it has no elements or all of
+  /// its elements are empty.
+  bool isEmptyTy() const { return LLVMTy->isEmptyTy(); }
+
+  /// Return true if the type is "first class", meaning it is a valid type for a
+  /// Value.
+  bool isFirstClassType() const { return LLVMTy->isFirstClassType(); }
+
+  /// Return true if the type is a valid type for a register in codegen. This
+  /// includes all first-class types except struct and array types.
+  bool isSingleValueType() const { return LLVMTy->isSingleValueType(); }
+
+  /// Return true if the type is an aggregate type. This means it is valid as
+  /// the first operand of an insertvalue or extractvalue instruction. This
+  /// includes struct and array types, but does not include vector types.
+  bool isAggregateType() const { return LLVMTy->isAggregateType(); }
+
+  /// Return true if it makes sense to take the size of this type. To get the
+  /// actual size for a particular target, it is reasonable to use the
+  /// DataLayout subsystem to do this.
+  /// \p Visited is optional; when non-null, its types are translated to the
+  /// corresponding llvm::Types and passed on to llvm::Type::isSized().
+  bool isSized(SmallPtrSetImpl<Type *> *Visited = nullptr) const {
+    // Visited defaults to nullptr, so it must not be dereferenced blindly.
+    if (Visited == nullptr)
+      return LLVMTy->isSized();
+    SmallPtrSet<llvm::Type *, 8> LLVMVisited;
+    LLVMVisited.reserve(Visited->size());
+    for (Type *Ty : *Visited)
+      LLVMVisited.insert(Ty->LLVMTy);
+    // NOTE: Anything llvm::Type::isSized() may record in LLVMVisited is not
+    // copied back into Visited.
+    return LLVMTy->isSized(&LLVMVisited);
+  }
+
+  /// Return the basic size of this type if it is a primitive type. These are
+  /// fixed by LLVM and are not target-dependent.
+  /// This will return zero if the type does not have a size or is not a
+  /// primitive type.
+  ///
+  /// If this is a scalable vector type, the scalable property will be set and
+  /// the runtime size will be a positive integer multiple of the base size.
+  ///
+  /// Note that this may not reflect the size of memory allocated for an
+  /// instance of the type or the number of bytes that are written when an
+  /// instance of the type is stored to memory. The DataLayout class provides
+  /// additional query functions to provide this information.
+  ///
+  TypeSize getPrimitiveSizeInBits() const {
+    return LLVMTy->getPrimitiveSizeInBits();
+  }
+
+  /// If this is a vector type, return the getPrimitiveSizeInBits value for the
+  /// element type. Otherwise return the getPrimitiveSizeInBits value for this
+  /// type.
+  unsigned getScalarSizeInBits() const { return LLVMTy->getScalarSizeInBits(); }
+
+  /// Return the width of the mantissa of this type. This is only valid on
+  /// floating-point types. If the FP type does not have a stable mantissa (e.g.
+  /// ppc long double), this method returns -1.
+  int getFPMantissaWidth() const { return LLVMTy->getFPMantissaWidth(); }
+
+  /// Return whether the type is IEEE compatible, as defined by the eponymous
+  /// method in APFloat.
+  bool isIEEE() const { return LLVMTy->isIEEE(); }
+
+  /// If this is a vector type, return the element type, otherwise return
+  /// 'this'.
+  Type *getScalarType() const;
+
+  // TODO: ADD MISSING
+
+  static Type *getInt64Ty(Context &Ctx);
+  static Type *getInt32Ty(Context &Ctx);
+  static Type *getInt16Ty(Context &Ctx);
+  static Type *getInt8Ty(Context &Ctx);
+  static Type *getInt1Ty(Context &Ctx);
+  static Type *getDoubleTy(Context &Ctx);
+  static Type *getFloatTy(Context &Ctx);
+  // TODO: missing get*
+
+  /// Get the address space of this pointer or pointer vector type.
+  unsigned getPointerAddressSpace() const {
+    return LLVMTy->getPointerAddressSpace();
+  }
+
+#ifndef NDEBUG
+  void dumpOS(raw_ostream &OS) const { LLVMTy->print(OS); }
+  LLVM_DUMP_METHOD void dump() const {
+    dumpOS(dbgs());
+    dbgs() << "\n";
+  }
+#endif // NDEBUG
+};
+
+/// Sandbox IR counterpart of llvm::PointerType.
+class PointerType : public Type {
+public:
+  /// For isa/dyn_cast: true when \p From wraps an llvm::PointerType.
+  static bool classof(const Type *From) {
+    return llvm::isa<llvm::PointerType>(From->LLVMTy);
+  }
+
+  // TODO: add missing functions
+  static PointerType *get(Type *ElementType, unsigned AddressSpace);
+  static PointerType *get(Context &Ctx, unsigned AddressSpace);
+};
+
+/// Sandbox IR counterpart of llvm::VectorType.
+class VectorType : public Type {
+public:
+  /// For isa/dyn_cast: true when \p From wraps an llvm::VectorType.
+  static bool classof(const Type *From) {
+    return llvm::isa<llvm::VectorType>(From->LLVMTy);
+  }
+  // TODO: add missing functions
+};
+
+/// Sandbox IR counterpart of llvm::FunctionType.
+class FunctionType : public Type {
+public:
+  /// For isa/dyn_cast: true when \p From wraps an llvm::FunctionType.
+  static bool classof(const Type *From) {
+    return llvm::isa<llvm::FunctionType>(From->LLVMTy);
+  }
+  // TODO: add missing functions
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_SANDBOXIR_TYPE_H
diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt
index 6c0666b186b8a6..d94f0642ccc4a1 100644
--- a/llvm/lib/SandboxIR/CMakeLists.txt
+++ b/llvm/lib/SandboxIR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_component_library(LLVMSandboxIR
   SandboxIR.cpp
   Tracker.cpp
+  Type.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index b75424909f0835..921bab66e8f1cc 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -135,6 +135,8 @@ Value::user_iterator Value::user_begin() {
 
 unsigned Value::getNumUses() const { return range_size(Val->users()); }
 
+Type *Value::getType() const {
+  // The Context keeps one wrapper per llvm::Type; fetch the one for Val.
+  llvm::Type *LLVMTy = Val->getType();
+  return Ctx.getType(LLVMTy);
+}
+
 void Value::replaceUsesWithIf(
     Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
   assert(getType() == OtherV->getType() && "Can't replace with different type");
@@ -583,7 +585,8 @@ VAArgInst *VAArgInst::create(Value *List, Type *Ty, BBIterator WhereIt,
     Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
   else
     Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
-  auto *LLVMI = cast<llvm::VAArgInst>(Builder.CreateVAArg(List->Val, Ty, Name));
+  auto *LLVMI =
+      cast<llvm::VAArgInst>(Builder.CreateVAArg(List->Val, Ty->LLVMTy, Name));
   return Ctx.createVAArgInst(LLVMI);
 }
 
@@ -754,7 +757,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
   auto &Builder = Ctx.getLLVMIRBuilder();
   Builder.SetInsertPoint(BeforeIR);
   auto *NewLI =
-      Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name);
+      Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name);
   auto *NewSBI = Ctx.createLoadInst(NewLI);
   return NewSBI;
 }
@@ -771,7 +774,7 @@ LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
   auto &Builder = Ctx.getLLVMIRBuilder();
   Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
   auto *NewLI =
-      Builder.CreateAlignedLoad(Ty, Ptr->Val, Align, IsVolatile, Name);
+      Builder.CreateAlignedLoad(Ty->LLVMTy, Ptr->Val, Align, IsVolatile, Name);
   auto *NewSBI = Ctx.createLoadInst(NewLI);
   return NewSBI;
 }
@@ -886,6 +889,11 @@ Value *ReturnInst::getReturnValue() const {
   return LLVMRetVal != nullptr ? Ctx.getValue(LLVMRetVal) : nullptr;
 }
 
+FunctionType *CallBase::getFunctionType() const {
+  auto *LLVMFTy = cast<llvm::CallBase>(Val)->getFunctionType();
+  return cast<FunctionType>(Ctx.getType(LLVMFTy));
+}
+
 Value *CallBase::getCalledOperand() const {
   return Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledOperand());
 }
@@ -911,8 +919,9 @@ void CallBase::setCalledFunction(Function *F) {
   // Note: This may break if `setCalledFunction()` early returns if `F`
   // is already set, but we do have a unit test for it.
   setCalledOperand(F);
-  cast<llvm::CallBase>(Val)->setCalledFunction(F->getFunctionType(),
-                                               cast<llvm::Function>(F->Val));
+  cast<llvm::CallBase>(Val)->setCalledFunction(
+      cast<llvm::FunctionType>(F->getFunctionType()->LLVMTy),
+      cast<llvm::Function>(F->Val));
 }
 
 CallInst *CallInst::create(FunctionType *FTy, Value *Func,
@@ -928,7 +937,8 @@ CallInst *CallInst::create(FunctionType *FTy, Value *Func,
   LLVMArgs.reserve(Args.size());
   for (Value *Arg : Args)
     LLVMArgs.push_back(Arg->Val);
-  llvm::CallInst *NewCI = Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr);
+  llvm::CallInst *NewCI = Builder.CreateCall(
+      cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val, LLVMArgs, NameStr);
   return Ctx.createCallInst(NewCI);
 }
 
@@ -961,7 +971,8 @@ InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
   for (Value *Arg : Args)
     LLVMArgs.push_back(Arg->Val);
   llvm::InvokeInst *Invoke = Builder.CreateInvoke(
-      FTy, Func->Val, cast<llvm::BasicBlock>(IfNormal->Val),
+      cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val,
+      cast<llvm::BasicBlock>(IfNormal->Val),
       cast<llvm::BasicBlock>(IfException->Val), LLVMArgs, NameStr);
   return Ctx.createInvokeInst(Invoke);
 }
@@ -1032,9 +1043,10 @@ CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func,
   for (Value *Arg : Args)
     LLVMArgs.push_back(Arg->Val);
 
-  llvm::CallBrInst *CallBr = Builder.CreateCallBr(
-      FTy, Func->Val, cast<llvm::BasicBlock>(DefaultDest->Val),
-      LLVMIndirectDests, LLVMArgs, NameStr);
+  llvm::CallBrInst *CallBr =
+      Builder.CreateCallBr(cast<llvm::FunctionType>(FTy->LLVMTy), Func->Val,
+                           cast<llvm::BasicBlock>(DefaultDest->Val),
+                           LLVMIndirectDests, LLVMArgs, NameStr);
   return Ctx.createCallBrInst(CallBr);
 }
 
@@ -1107,7 +1119,7 @@ LandingPadInst *LandingPadInst::create(Type *RetTy, unsigned NumReservedClauses,
   else
     Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
   llvm::LandingPadInst *LLVMI =
-      Builder.CreateLandingPad(RetTy, NumReservedClauses, Name);
+      Builder.CreateLandingPad(RetTy->LLVMTy, NumReservedClauses, Name);
   return Ctx.createLandingPadInst(LLVMI);
 }
 
@@ -1288,7 +1300,8 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
   LLVMIdxList.reserve(IdxList.size());
   for (Value *Idx : IdxList)
     LLVMIdxList.push_back(Idx->Val);
-  llvm::Value *NewV = Builder.CreateGEP(Ty, Ptr->Val, LLVMIdxList, NameStr);
+  llvm::Value *NewV =
+      Builder.CreateGEP(Ty->LLVMTy, Ptr->Val, LLVMIdxList, NameStr);
   if (auto *NewGEP = dyn_cast<llvm::GetElementPtrInst>(NewV))
     return Ctx.createGetElementPtrInst(NewGEP);
   assert(isa<llvm::Constant>(NewV) && "Expected constant");
@@ -1312,10 +1325,25 @@ Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
                                    InsertAtEnd, Ctx, NameStr);
 }
 
+Type *GetElementPtrInst::getSourceElementType() const {
+  auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val);
+  return Ctx.getType(LLVMGEP->getSourceElementType());
+}
+
+Type *GetElementPtrInst::getResultElementType() const {
+  auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val);
+  return Ctx.getType(LLVMGEP->getResultElementType());
+}
+
 Value *GetElementPtrInst::getPointerOperand() const {
   return Ctx.getValue(cast<llvm::GetElementPtrInst>(Val)->getPointerOperand());
 }
 
+Type *GetElementPtrInst::getPointerOperandType() const {
+  auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val);
+  return Ctx.getType(LLVMGEP->getPointerOperandType());
+}
+
 BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
   return cast<BasicBlock>(Ctx.getValue(LLVMBB));
 }
@@ -1323,8 +1351,9 @@ BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
+/// Creates a PHINode of type \p Ty by building an llvm::PHINode right before
+/// the topmost LLVM instruction of \p InsertBefore and wrapping it in \p Ctx.
 PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues,
                          Instruction *InsertBefore, Context &Ctx,
                          const Twine &Name) {
-  llvm::PHINode *NewPHI = llvm::PHINode::Create(
-      Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction());
+  llvm::PHINode *NewPHI =
+      llvm::PHINode::Create(Ty->LLVMTy, NumReservedValues, Name,
+                            InsertBefore->getTopmostLLVMInstruction());
   return Ctx.createPHINode(NewPHI);
 }
 
@@ -1943,7 +1972,8 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace, BBIterator WhereIt,
     Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
   else
     Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
-  auto *NewAlloca = Builder.CreateAlloca(Ty, AddrSpace, ArraySize->Val, Name);
+  auto *NewAlloca =
+      Builder.CreateAlloca(Ty->LLVMTy, AddrSpace, ArraySize->Val, Name);
   return Ctx.createAllocaInst(NewAlloca);
 }
 
@@ -1961,11 +1991,15 @@ AllocaInst *AllocaInst::create(Type *Ty, unsigned AddrSpace,
                 Name);
 }
 
+Type *AllocaInst::getAllocatedType() const {
+  auto *LLVMAlloca = cast<llvm::AllocaInst>(Val);
+  return Ctx.getType(LLVMAlloca->getAllocatedType());
+}
+
+/// Sets the allocated type to \p Ty. When the Tracker is tracking, a
+/// GenericSetter entry built from getAllocatedType()/setAllocatedType() is
+/// recorded first (presumably so the old type can be restored on revert).
 void AllocaInst::setAllocatedType(Type *Ty) {
   Ctx.getTracker()
       .emplaceIfTracking<GenericSetter<&AllocaInst::getAllocatedType,
                                        &AllocaInst::setAllocatedType>>(this);
-  cast<llvm::AllocaInst>(Val)->setAllocatedType(Ty);
+  cast<llvm::AllocaInst>(Val)->setAllocatedType(Ty->LLVMTy);
 }
 
 void AllocaInst::setAlignment(Align Align) {
@@ -1987,6 +2021,10 @@ Value *AllocaInst::getArraySize() {
   return Ctx.getValue(cast<llvm::AllocaInst>(Val)->getArraySize());
 }
 
+PointerType *AllocaInst::getType() const {
+  auto *LLVMAlloca = cast<llvm::AllocaInst>(Val);
+  return cast<PointerType>(Ctx.getType(LLVMAlloca->getType()));
+}
+
 Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
                         BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
                         const Twine &Name) {
@@ -1997,7 +2035,7 @@ Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
   else
     Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
   auto *NewV =
-      Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy, Name);
+      Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy->LLVMTy, Name);
   if (auto *NewCI = dyn_cast<llvm::CastInst>(NewV))
     return Ctx.createCastInst(NewCI);
   assert(isa<llvm::Constant>(NewV) && "Expected constant");
@@ -2022,6 +2060,14 @@ bool CastInst::classof(const Value *From) {
   return From->getSubclassID() == ClassID::Cast;
 }
 
+Type *CastInst::getSrcTy() const {
+  auto *LLVMCast = cast<llvm::CastInst>(Val);
+  return Ctx.getType(LLVMCast->getSrcTy());
+}
+
+Type *CastInst::getDestTy() const {
+  auto *LLVMCast = cast<llvm::CastInst>(Val);
+  return Ctx.getType(LLVMCast->getDestTy());
+}
+
 void PossiblyNonNegInst::setNonNeg(bool B) {
   Ctx.getTracker()
       .emplaceIfTracking<GenericSetter<&PossiblyNonNegInst::hasNonNeg,
@@ -2134,15 +2180,25 @@ void ShuffleVectorInst::setShuffleMask(ArrayRef<int> Mask) {
   cast<llvm::ShuffleVectorInst>(Val)->setShuffleMask(Mask);
 }
 
+VectorType *ShuffleVectorInst::getType() const {
+  auto *LLVMShuffle = cast<llvm::ShuffleVectorInst>(Val);
+  return cast<VectorType>(Ctx.getType(LLVMShuffle->getType()));
+}
+
 Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const {
   return Ctx.getOrCreateConstant(
       cast<llvm::ShuffleVectorInst>(Val)->getShuffleMaskForBitcode());
 }
 
-Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(
-    llvm::ArrayRef<int> Mask, llvm::Type *ResultTy, Context &Ctx) {
-  return Ctx.getOrCreateConstant(
-      llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy));
+Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef<int> Mask,
+                                                          Type *ResultTy) {
+  // The Context is reachable through the result type, so no explicit Context
+  // argument is needed.
+  auto *LLVMMask = llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(
+      Mask, ResultTy->LLVMTy);
+  return ResultTy->getContext().getOrCreateConstant(LLVMMask);
+}
+
+VectorType *ExtractElementInst::getVectorOperandType() const {
+  // Value::getType() already returns the Context-owned sandboxir::Type, so
+  // there is no need to unwrap it to LLVMTy and look it up in the Context
+  // map a second time.
+  return cast<VectorType>(getVectorOperand()->getType());
 }
 
 #ifndef NDEBUG
@@ -2152,10 +2208,14 @@ void Constant::dumpOS(raw_ostream &OS) const {
 }
 #endif // NDEBUG
 
-ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, Context &Ctx,
-                              bool IsSigned) {
-  auto *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
-  return cast<ConstantInt>(Ctx.getOrCreateConstant(LLVMC));
+ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, bool IsSigned) {
+  // The owning Context is derived from the type, so callers need not pass it.
+  Context &TyCtx = Ty->getContext();
+  auto *LLVMC = llvm::ConstantInt::get(Ty->LLVMTy, V, IsSigned);
+  return cast<ConstantInt>(TyCtx.getOrCreateConstant(LLVMC));
+}
+
+FunctionType *Function::getFunctionType() const {
+  auto *LLVMF = cast<llvm::Function>(Val);
+  return cast<FunctionType>(Ctx.getType(LLVMF->getFunctionType()));
 }
 
 #ifndef NDEBUG
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
new file mode 100644
index 00000000000000..6f850b82d2e996
--- /dev/null
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -0,0 +1,48 @@
+//===- Type.cpp - Sandbox IR Type -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SandboxIR/Type.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+
+using namespace llvm::sandboxir;
+
+Type *Type::getScalarType() const {
+  return Ctx.getType(LLVMTy->getScalarType());
+}
+
+Type *Type::getInt64Ty(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getInt64Ty(Ctx.LLVMCtx));
+}
+Type *Type::getInt32Ty(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getInt32Ty(Ctx.LLVMCtx));
+}
+Type *Type::getInt16Ty(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getInt16Ty(Ctx.LLVMCtx));
+}
+Type *Type::getInt8Ty(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getInt8Ty(Ctx.LLVMCtx));
+}
+Type *Type::getInt1Ty(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getInt1Ty(Ctx.LLVMCtx));
+}
+Type *Type::getDoubleTy(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getDoubleTy(Ctx.LLVMCtx));
+}
+Type *Type::getFloatTy(Context &Ctx) {
+  return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx));
+}
+
+PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) {
+  return cast<PointerType>(ElementType->getContext().getType(
+      llvm::PointerType::get(ElementType->LLVMTy, AddressSpace)));
+}
+
+PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
+  return cast<PointerType>(
+      Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
+}
diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt
index 3f43f6337b919b..2da936bffa02bf 100644
--- a/llvm/unittests/SandboxIR/CMakeLists.txt
+++ b/llvm/unittests/SandboxIR/CMakeLists.txt
@@ -7,4 +7,5 @@ set(LLVM_LINK_COMPONENTS
 add_llvm_unittest(SandboxIRTests
   SandboxIRTest.cpp
   TrackerTest.cpp
+  TypesTest.cpp
   )
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bc3fddf9e163dc..d1217f56ec29e7 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -121,10 +121,12 @@ define void @foo(i32 %v0) {
   auto *FortyTwo = cast<sandboxir::ConstantInt>(Add0->getOperand(1));
 
   // Check that creating an identical constant gives us the same object.
-  auto *NewCI = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+  auto *NewCI =
+      sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
   EXPECT_EQ(NewCI, FortyTwo);
   // Check new constant.
-  auto *FortyThree = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 43, Ctx);
+  auto *FortyThree =
+      sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 43);
   EXPECT_NE(FortyThree, FortyTwo);
 }
 
@@ -603,7 +605,7 @@ define void @foo(ptr %va) {
   EXPECT_EQ(sandboxir::VAArgInst::getPointerOperandIndex(),
             llvm::VAArgInst::getPointerOperandIndex());
   // Check create().
-  auto *NewVATy = Type::getInt8Ty(C);
+  auto *NewVATy = sandboxir::Type::getInt8Ty(Ctx);
   auto *NewVA = sandboxir::VAArgInst::create(Arg, NewVATy, Ret->getIterator(),
                                              Ret->getParent(), Ctx, "NewVA");
   EXPECT_EQ(NewVA->getNextNode(), Ret);
@@ -743,10 +745,10 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   }
   {
     // Check SelectInst::create() Folded.
-    auto *False = sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 0, Ctx,
-                                              /*IsSigned=*/false);
+    auto *False = sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx),
+                                              0, /*IsSigned=*/false);
     auto *FortyTwo =
-        sandboxir::ConstantInt::get(llvm::Type::getInt1Ty(C), 42, Ctx,
+        sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx), 42,
                                     /*IsSigned=*/false);
     auto *NewSel =
         sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx);
@@ -838,7 +840,7 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
 
   auto *LLVMArg0 = LLVMF.getArg(0);
   auto *LLVMArgVec = LLVMF.getArg(2);
-  auto *Zero = sandboxir::ConstantInt::get(Type::getInt8Ty(C), 0, Ctx);
+  auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0);
   auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0);
   EXPECT_EQ(
       sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero),
@@ -950,7 +952,7 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
   // convertShuffleMaskForBitcode
   {
     auto *C = sandboxir::ShuffleVectorInst::convertShuffleMaskForBitcode(
-        ArrayRef<int>({2, 3}), ArgV1->getType(), Ctx);
+        ArrayRef<int>({2, 3}), ArgV1->getType());
     SmallVector<int, 2> Result;
     sandboxir::ShuffleVectorInst::getShuffleMask(C, Result);
     EXPECT_THAT(Result, testing::ElementsAre(2, 3));
@@ -1607,7 +1609,8 @@ define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) {
     // Check classof(Value *).
     EXPECT_TRUE(isa<sandboxir::CallBase>((sandboxir::Value *)Call));
     // Check getFunctionType().
-    EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType());
+    EXPECT_EQ(Call->getFunctionType(),
+              Ctx.getType(LLVMCall->getFunctionType()));
     // Check data_ops().
     EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops()));
     auto DataOpIt = Call->data_operands_begin();
@@ -1738,7 +1741,7 @@ define i8 @foo(i8 %arg) {
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
   EXPECT_EQ(Call->getNumOperands(), 2u);
   EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
-  FunctionType *FTy = F.getFunctionType();
+  sandboxir::FunctionType *FTy = F.getFunctionType();
   SmallVector<sandboxir::Value *, 1> Args;
   Args.push_back(Arg0);
   {
@@ -2027,8 +2030,8 @@ define void @foo() {
   auto *BBRet = &*BB->begin();
   auto *NewLPad =
       cast<sandboxir::LandingPadInst>(sandboxir::LandingPadInst::create(
-          Type::getInt8Ty(C), 0, BBRet->getIterator(), BBRet->getParent(), Ctx,
-          "NewLPad"));
+          sandboxir::Type::getInt8Ty(Ctx), 0, BBRet->getIterator(),
+          BBRet->getParent(), Ctx, "NewLPad"));
   EXPECT_EQ(NewLPad->getNextNode(), BBRet);
   EXPECT_FALSE(NewLPad->isCleanup());
 #ifndef NDEBUG
@@ -2287,9 +2290,11 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
     // Check classof().
     auto *GEP = cast<sandboxir::GetElementPtrInst>(Ctx.getValue(LLVMGEP));
     // Check getSourceElementType().
-    EXPECT_EQ(GEP->getSourceElementType(), LLVMGEP->getSourceElementType());
+    EXPECT_EQ(GEP->getSourceElementType(),
+              Ctx.getType(LLVMGEP->getSourceElementType()));
     // Check getResultElementType().
-    EXPECT_EQ(GEP->getResultElementType(), LLVMGEP->getResultElementType());
+    EXPECT_EQ(GEP->getResultElementType(),
+              Ctx.getType(LLVMGEP->getResultElementType()));
     // Check getAddressSpace().
     EXPECT_EQ(GEP->getAddressSpace(), LLVMGEP->getAddressSpace());
     // Check indices().
@@ -2305,7 +2310,8 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
     // Check getPointerOperandIndex().
     EXPECT_EQ(GEP->getPointerOperandIndex(), LLVMGEP->getPointerOperandIndex());
     // Check getPointerOperandType().
-    EXPECT_EQ(GEP->getPointerOperandType(), LLVMGEP->getPointerOperandType());
+    EXPECT_EQ(GEP->getPointerOperandType(),
+              Ctx.getType(LLVMGEP->getPointerOperandType()));
     // Check getPointerAddressSpace().
     EXPECT_EQ(GEP->getPointerAddressSpace(), LLVMGEP->getPointerAddressSpace());
     // Check getNumIndices().
@@ -2666,8 +2672,8 @@ define void @foo(i32 %cond0, i32 %cond1) {
   Switch->setSuccessor(0, OrigSucc);
   EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
   // Check case_begin(), case_end(), CaseIt.
-  auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
-  auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+  auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
+  auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1);
   auto CaseIt = Switch->case_begin();
   {
     sandboxir::SwitchInst::CaseHandle Case = *CaseIt++;
@@ -2704,7 +2710,7 @@ define void @foo(i32 %cond0, i32 %cond1) {
   EXPECT_EQ(Switch->findCaseDest(BB1), One);
   EXPECT_EQ(Switch->findCaseDest(Entry), nullptr);
   // Check addCase().
-  auto *Two = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 2, Ctx);
+  auto *Two = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 2);
   Switch->addCase(Two, Entry);
   auto CaseTwoIt = Switch->findCaseValue(Two);
   auto CaseTwo = *CaseTwoIt;
@@ -2969,7 +2975,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
   }
   {
     // Check create() when it gets folded.
-    auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+    auto *FortyTwo =
+        sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
     auto *NewV = sandboxir::BinaryOperator::create(
         sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo,
         /*InsertBefore=*/Ret, Ctx, "Folded");
@@ -3025,7 +3032,8 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
   }
   {
     // Check createWithCopiedFlags() when it gets folded.
-    auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+    auto *FortyTwo =
+        sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
     auto *NewV = sandboxir::BinaryOperator::createWithCopiedFlags(
         sandboxir::Instruction::Opcode::Add, FortyTwo, FortyTwo, CopyFrom,
         /*InsertBefore=*/Ret, Ctx, "Folded");
@@ -3447,8 +3455,8 @@ define void @foo() {
   EXPECT_EQ(AllocaArray->getArraySize(),
             Ctx.getValue(LLVMAllocaArray->getArraySize()));
   // Check getType().
-  EXPECT_EQ(AllocaScalar->getType(), LLVMAllocaScalar->getType());
-  EXPECT_EQ(AllocaArray->getType(), LLVMAllocaArray->getType());
+  EXPECT_EQ(AllocaScalar->getType(), Ctx.getType(LLVMAllocaScalar->getType()));
+  EXPECT_EQ(AllocaArray->getType(), Ctx.getType(LLVMAllocaArray->getType()));
   // Check getAddressSpace().
   EXPECT_EQ(AllocaScalar->getAddressSpace(),
             LLVMAllocaScalar->getAddressSpace());
@@ -3465,12 +3473,12 @@ define void @foo() {
             LLVMAllocaArray->getAllocationSizeInBits(DL));
   // Check getAllocatedType().
   EXPECT_EQ(AllocaScalar->getAllocatedType(),
-            LLVMAllocaScalar->getAllocatedType());
+            Ctx.getType(LLVMAllocaScalar->getAllocatedType()));
   EXPECT_EQ(AllocaArray->getAllocatedType(),
-            LLVMAllocaArray->getAllocatedType());
+            Ctx.getType(LLVMAllocaArray->getAllocatedType()));
   // Check setAllocatedType().
   auto *OrigType = AllocaScalar->getAllocatedType();
-  auto *NewType = PointerType::get(C, 0);
+  auto *NewType = sandboxir::PointerType::get(Ctx, 0);
   EXPECT_NE(NewType, OrigType);
   AllocaScalar->setAllocatedType(NewType);
   EXPECT_EQ(AllocaScalar->getAllocatedType(), NewType);
@@ -3501,10 +3509,10 @@ define void @foo() {
   AllocaScalar->setUsedWithInAlloca(OrigUsedWithInAlloca);
   EXPECT_EQ(AllocaScalar->isUsedWithInAlloca(), OrigUsedWithInAlloca);
 
-  auto *Ty = Type::getInt32Ty(C);
+  auto *Ty = sandboxir::Type::getInt32Ty(Ctx);
   unsigned AddrSpace = 42;
-  auto *PtrTy = PointerType::get(C, AddrSpace);
-  auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43, Ctx);
+  auto *PtrTy = sandboxir::PointerType::get(Ctx, AddrSpace);
+  auto *ArraySize = sandboxir::ConstantInt::get(Ty, 43);
   {
     // Check create() WhereIt, WhereBB.
     auto *NewI = cast<sandboxir::AllocaInst>(sandboxir::AllocaInst::create(
@@ -3581,13 +3589,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
 
-  Type *Ti64 = Type::getInt64Ty(C);
-  Type *Ti32 = Type::getInt32Ty(C);
-  Type *Ti16 = Type::getInt16Ty(C);
-  Type *Tdouble = Type::getDoubleTy(C);
-  Type *Tfloat = Type::getFloatTy(C);
-  Type *Tptr = Tfloat->getPointerTo();
-  Type *Tptr1 = Tfloat->getPointerTo(1);
+  auto *Ti64 = sandboxir::Type::getInt64Ty(Ctx);
+  auto *Ti32 = sandboxir::Type::getInt32Ty(Ctx);
+  auto *Ti16 = sandboxir::Type::getInt16Ty(Ctx);
+  auto *Tdouble = sandboxir::Type::getDoubleTy(Ctx);
+  auto *Tfloat = sandboxir::Type::getFloatTy(Ctx);
+  auto *Tptr = sandboxir::PointerType::get(Tfloat, 0);
+  auto *Tptr1 = sandboxir::PointerType::get(Tfloat, 1);
 
   // Check classof(), getOpcode(), getSrcTy(), getDstTy()
   auto *ZExt = cast<sandboxir::CastInst>(&*It++);
@@ -3799,10 +3807,13 @@ define void @foo(i32 %arg, float %farg, double %darg, ptr %ptr) {
 /// CastInst's subclasses are very similar so we can use a common test function
 /// for them.
 template <typename SubclassT, sandboxir::Instruction::Opcode OpcodeT>
-void testCastInst(llvm::Module &M, Type *SrcTy, Type *DstTy) {
+void testCastInst(llvm::Module &M, llvm::Type *LLVMSrcTy,
+                  llvm::Type *LLVMDstTy) {
   Function &LLVMF = *M.getFunction("foo");
   sandboxir::Context Ctx(M.getContext());
   sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  sandboxir::Type *SrcTy = Ctx.getType(LLVMSrcTy);
+  sandboxir::Type *DstTy = Ctx.getType(LLVMDstTy);
   unsigned ArgIdx = 0;
   auto *Arg = F->getArg(ArgIdx++);
   auto *BB = &*F->begin();
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index ca6effb727bf37..c189100fbd6947 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -938,9 +938,10 @@ define void @foo(i32 %cond0, i32 %cond1) {
   Ctx.revert();
   EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
   // Check addCase().
-  auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
-  auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
-  auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+  auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
+  auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1);
+  auto *FortyTwo =
+      sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
   Ctx.save();
   Switch->addCase(FortyTwo, Entry);
   EXPECT_EQ(Switch->getNumCases(), 3u);
@@ -1187,7 +1188,7 @@ define void @foo(i8 %arg) {
   // Check setAllocatedType().
   Ctx.save();
   auto *OrigTy = Alloca->getAllocatedType();
-  auto *NewTy = Type::getInt64Ty(C);
+  auto *NewTy = sandboxir::Type::getInt64Ty(Ctx);
   EXPECT_NE(NewTy, OrigTy);
   Alloca->setAllocatedType(NewTy);
   EXPECT_EQ(Alloca->getAllocatedType(), NewTy);
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
new file mode 100644
index 00000000000000..d72725883a7dd7
--- /dev/null
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -0,0 +1,242 @@
+//===- TypesTest.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct SandboxTypeTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("SandboxTypeTest", errs());
+  }
+  BasicBlock *getBasicBlockByName(Function &F, StringRef Name) {
+    for (BasicBlock &BB : F)
+      if (BB.getName() == Name)
+        return &BB;
+    llvm_unreachable("Expected to find basic block!");
+  }
+};
+
+TEST_F(SandboxTypeTest, Type) {
+  parseIR(C, R"IR(
+define void @foo(i32 %v0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  sandboxir::Type *I32Ty = F->getArg(0)->getType();
+
+  auto *LLVMInt32Ty = llvm::Type::getInt32Ty(C);
+  auto *LLVMFloatTy = llvm::Type::getFloatTy(C);
+  auto *LLVMInt8Ty = llvm::Type::getInt8Ty(C);
+
+  auto *Int32Ty = Ctx.getType(LLVMInt32Ty);
+  auto *FloatTy = Ctx.getType(LLVMFloatTy);
+
+  // Check getContext().
+  EXPECT_EQ(&I32Ty->getContext(), &Ctx);
+
+#define CHK(LLVMCreate, SBCheck)                                               \
+  Ctx.getType(llvm::Type::LLVMCreate(C))->SBCheck()
+  // Check isVoidTy().
+  EXPECT_TRUE(Ctx.getType(llvm::Type::getVoidTy(C))->isVoidTy());
+  EXPECT_FALSE(CHK(getInt32Ty, isVoidTy));
+  // Check isHalfTy().
+  EXPECT_TRUE(CHK(getHalfTy, isHalfTy));
+  // Check isBFloatTy().
+  EXPECT_TRUE(CHK(getBFloatTy, isBFloatTy));
+  // Check is16bitFPTy().
+  EXPECT_TRUE(CHK(getHalfTy, is16bitFPTy));
+  // Check isFloatTy().
+  EXPECT_TRUE(CHK(getFloatTy, isFloatTy));
+  // Check isDoubleTy().
+  EXPECT_TRUE(CHK(getDoubleTy, isDoubleTy));
+  // Check isX86_FP80Ty().
+  EXPECT_TRUE(CHK(getX86_FP80Ty, isX86_FP80Ty));
+  // Check isFP128Ty().
+  EXPECT_TRUE(CHK(getFP128Ty, isFP128Ty));
+  // Check isPPC_FP128Ty().
+  EXPECT_TRUE(CHK(getPPC_FP128Ty, isPPC_FP128Ty));
+  // Check isIEEELikeFPTy().
+  EXPECT_TRUE(CHK(getFloatTy, isIEEELikeFPTy));
+  // Check isFloatingPointTy().
+  EXPECT_TRUE(CHK(getFloatTy, isFloatingPointTy));
+  EXPECT_TRUE(CHK(getDoubleTy, isFloatingPointTy));
+  // Check isMultiUnitFPType().
+  EXPECT_TRUE(CHK(getPPC_FP128Ty, isMultiUnitFPType));
+  EXPECT_FALSE(CHK(getFloatTy, isMultiUnitFPType));
+  // Check getFltSemantics().
+  EXPECT_EQ(&sandboxir::Type::getFloatTy(Ctx)->getFltSemantics(),
+            &llvm::Type::getFloatTy(C)->getFltSemantics());
+  // Check isX86_AMXTy().
+  EXPECT_TRUE(CHK(getX86_AMXTy, isX86_AMXTy));
+  // Check isTargetExtTy().
+  EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "foo"))->isTargetExtTy());
+  // Check isScalableTargetExtTy().
+  EXPECT_TRUE(Ctx.getType(llvm::TargetExtType::get(C, "aarch64.svcount"))
+                  ->isScalableTargetExtTy());
+  // Check isScalableTy().
+  EXPECT_TRUE(Ctx.getType(llvm::ScalableVectorType::get(LLVMInt32Ty, 2u))
+                  ->isScalableTy());
+  // Check isFPOrFPVectorTy().
+  EXPECT_TRUE(CHK(getFloatTy, isFPOrFPVectorTy));
+  EXPECT_FALSE(CHK(getInt32Ty, isFPOrFPVectorTy));
+  // Check isLabelTy().
+  EXPECT_TRUE(CHK(getLabelTy, isLabelTy));
+  // Check isMetadataTy().
+  EXPECT_TRUE(CHK(getMetadataTy, isMetadataTy));
+  // Check isTokenTy().
+  EXPECT_TRUE(CHK(getTokenTy, isTokenTy));
+  // Check isIntegerTy().
+  EXPECT_TRUE(CHK(getInt32Ty, isIntegerTy));
+  EXPECT_FALSE(CHK(getFloatTy, isIntegerTy));
+  // Check isIntegerTy(Bitwidth).
+  EXPECT_TRUE(Int32Ty->isIntegerTy(32u));
+  EXPECT_FALSE(Int32Ty->isIntegerTy(31u));
+  EXPECT_FALSE(FloatTy->isIntegerTy(32u));
+  // Check isIntOrIntVectorTy().
+  EXPECT_TRUE(Int32Ty->isIntOrIntVectorTy());
+  EXPECT_TRUE(Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8))
+                  ->isIntOrIntVectorTy());
+  EXPECT_FALSE(FloatTy->isIntOrIntVectorTy());
+  EXPECT_FALSE(Ctx.getType(llvm::FixedVectorType::get(LLVMFloatTy, 8))
+                   ->isIntOrIntVectorTy());
+  // Check isIntOrPtrTy().
+  EXPECT_TRUE(Int32Ty->isIntOrPtrTy());
+  EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isIntOrPtrTy());
+  EXPECT_FALSE(FloatTy->isIntOrPtrTy());
+  // Check isFunctionTy().
+  EXPECT_TRUE(Ctx.getType(llvm::FunctionType::get(LLVMInt32Ty, {}, false))
+                  ->isFunctionTy());
+  // Check isStructTy().
+  EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isStructTy());
+  // Check isArrayTy().
+  EXPECT_TRUE(Ctx.getType(llvm::ArrayType::get(LLVMInt32Ty, 10))->isArrayTy());
+  // Check isPointerTy().
+  EXPECT_TRUE(Ctx.getType(llvm::PointerType::get(C, 0u))->isPointerTy());
+  // Check isPtrOrPtrVectorTy().
+  EXPECT_TRUE(
+      Ctx.getType(llvm::FixedVectorType::get(llvm::PointerType::get(C, 0u), 8u))
+          ->isPtrOrPtrVectorTy());
+  // Check isVectorTy().
+  EXPECT_TRUE(
+      Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->isVectorTy());
+  // Check canLosslesslyBitCastTo().
+  auto *VecTy32x4 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 4u));
+  auto *VecTy32x2 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 2u));
+  auto *VecTy8x16 = Ctx.getType(llvm::FixedVectorType::get(LLVMInt8Ty, 16u));
+  EXPECT_TRUE(VecTy32x4->canLosslesslyBitCastTo(VecTy8x16));
+  EXPECT_FALSE(VecTy32x4->canLosslesslyBitCastTo(VecTy32x2));
+  // Check isEmptyTy().
+  EXPECT_TRUE(Ctx.getType(llvm::StructType::get(C))->isEmptyTy());
+  // Check isFirstClassType().
+  EXPECT_TRUE(Int32Ty->isFirstClassType());
+  // Check isSingleValueType().
+  EXPECT_TRUE(Int32Ty->isSingleValueType());
+  // Check isAggregateType().
+  EXPECT_FALSE(Int32Ty->isAggregateType());
+  // Check isSized().
+  SmallPtrSet<sandboxir::Type *, 1> Visited;
+  EXPECT_TRUE(Int32Ty->isSized(&Visited));
+  // Check getPrimitiveSizeInBits().
+  EXPECT_EQ(VecTy32x2->getPrimitiveSizeInBits(), 32u * 2);
+  // Check getScalarSizeInBits().
+  EXPECT_EQ(VecTy32x2->getScalarSizeInBits(), 32u);
+  // Check getFPMantissaWidth().
+  EXPECT_EQ(FloatTy->getFPMantissaWidth(), LLVMFloatTy->getFPMantissaWidth());
+  // Check isIEEE().
+  EXPECT_EQ(FloatTy->isIEEE(), LLVMFloatTy->isIEEE());
+  // Check getScalarType().
+  EXPECT_EQ(
+      Ctx.getType(llvm::FixedVectorType::get(LLVMInt32Ty, 8u))->getScalarType(),
+      Int32Ty);
+
+#define CHK_GET(TY)                                                            \
+  EXPECT_EQ(Ctx.getType(llvm::Type::get##TY##Ty(C)),                           \
+            sandboxir::Type::get##TY##Ty(Ctx))
+  // Check getInt64Ty().
+  CHK_GET(Int64);
+  // Check getInt32Ty().
+  CHK_GET(Int32);
+  // Check getInt16Ty().
+  CHK_GET(Int16);
+  // Check getInt8Ty().
+  CHK_GET(Int8);
+  // Check getInt1Ty().
+  CHK_GET(Int1);
+  // Check getDoubleTy().
+  CHK_GET(Double);
+  // Check getFloatTy().
+  CHK_GET(Float);
+}
+
+TEST_F(SandboxTypeTest, PointerType) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation.
+  auto *PtrTy = cast<sandboxir::PointerType>(F->getArg(0)->getType());
+  // Check get(ElementType, AddressSpace).
+  auto *NewPtrTy =
+      sandboxir::PointerType::get(sandboxir::Type::getInt32Ty(Ctx), 0u);
+  EXPECT_EQ(NewPtrTy, PtrTy);
+  // Check get(Ctx, AddressSpace).
+  auto *NewPtrTy2 = sandboxir::PointerType::get(Ctx, 0u);
+  EXPECT_EQ(NewPtrTy2, PtrTy);
+}
+
+TEST_F(SandboxTypeTest, VectorType) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %v0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation.
+  [[maybe_unused]] auto *VecTy =
+      cast<sandboxir::VectorType>(F->getArg(0)->getType());
+}
+
+TEST_F(SandboxTypeTest, FunctionType) {
+  parseIR(C, R"IR(
+define void @foo() {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation.
+  [[maybe_unused]] auto *FTy =
+      cast<sandboxir::FunctionType>(F->getFunctionType());
+}



More information about the llvm-commits mailing list