[llvm] a80a888 - [DirectX backend] Support global ctor for DXILBitcodeWriter.

Xiang Li via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 30 11:27:34 PDT 2022


Author: Xiang Li
Date: 2022-09-30T11:27:23-07:00
New Revision: a80a888de5b1a96c64ed308dc858cb1820108178

URL: https://github.com/llvm/llvm-project/commit/a80a888de5b1a96c64ed308dc858cb1820108178
DIFF: https://github.com/llvm/llvm-project/commit/a80a888de5b1a96c64ed308dc858cb1820108178.diff

LOG: [DirectX backend] Support global ctor for DXILBitcodeWriter.

1. Save the typed pointer type for a GlobalVariable/Function instead of its
   object type. This allows a GlobalVariable/Function to be used as a value.
2. Save the target type of constants used in global ctors.
3. In DXILBitcodeWriter::getTypeID, check PointerMap first for constants.

Reviewed By: beanz

Differential Revision: https://reviews.llvm.org/D133283

Added: 
    llvm/test/tools/dxil-dis/global_ctor.ll

Modified: 
    llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
    llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
    llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
index 4b9447340b34a..09909a252ee9f 100644
--- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
+++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
@@ -346,7 +346,11 @@ class DXILBitcodeWriter {
   unsigned getEncodedAlign(MaybeAlign Alignment) { return encode(Alignment); }
 
   unsigned getTypeID(Type *T, const Value *V = nullptr);
-  unsigned getTypeID(Type *T, const Function *F);
+  /// getGlobalObjectValueTypeID - returns the element type for a GlobalObject
+  ///
+  /// GlobalObject types are saved by PointerTypeAnalysis as pointers to the
+  /// GlobalObject, but in the bitcode writer we need the pointer element type.
+  unsigned getGlobalObjectValueTypeID(Type *T, const GlobalObject *G);
 };
 
 } // namespace dxil
@@ -551,18 +555,30 @@ unsigned DXILBitcodeWriter::getEncodedBinaryOpcode(unsigned Opcode) {
 }
 
+// Returns the enumerated bitcode type ID to emit for a value of type T.
+//
+// A type that is not an opaque pointer is enumerated as-is, unless V is a
+// Constant. Otherwise the PointerTypeAnalysis result (PointerMap) is consulted
+// for V. A Constant with no entry falls back to T, except a null pointer. That
+// case, like every other value without an entry (including a null V), falls
+// back to i8*.
 unsigned DXILBitcodeWriter::getTypeID(Type *T, const Value *V) {
-  if (!T->isOpaquePointerTy())
+  if (!T->isOpaquePointerTy() &&
+      // For Constant, always check PointerMap to make sure OpaquePointer in
+      // things like constant struct/array works.
+      (!V || !isa<Constant>(V)))
     return VE.getTypeID(T);
   auto It = PointerMap.find(V);
   if (It != PointerMap.end())
     return VE.getTypeID(It->second);
+  // For a Constant, fall back to T when it has no PointerMap entry.
+  // FIXME: support ConstantPointerNull which could map to more than one
+  // TypedPointerType.
+  // See https://github.com/llvm/llvm-project/issues/57942.
+  if (V && isa<Constant>(V) && !isa<ConstantPointerNull>(V))
+    return VE.getTypeID(T);
   return VE.getTypeID(I8PtrTy);
 }
 
-unsigned DXILBitcodeWriter::getTypeID(Type *T, const Function *F) {
-  auto It = PointerMap.find(F);
-  if (It != PointerMap.end())
-    return VE.getTypeID(It->second);
+// The PointerMap entry for a GlobalObject is a typed pointer *to* the object.
+// The bitcode records that use this helper need the object's own type, so one
+// pointer level is unwrapped here. The cast<> asserts if the entry is not a
+// TypedPointerType.
+// If G has no entry, T is enumerated unchanged. G may be null here, because
+// the call-instruction path passes CI.getCalledFunction(); presumably no
+// entry exists for a null key.
+unsigned DXILBitcodeWriter::getGlobalObjectValueTypeID(Type *T,
+                                                       const GlobalObject *G) {
+  auto It = PointerMap.find(G);
+  if (It != PointerMap.end()) {
+    // Unwrap the pointer-to-object to get the object's value type.
+    TypedPointerType *PtrTy = cast<TypedPointerType>(It->second);
+    return VE.getTypeID(PtrTy->getElementType());
+  }
   return VE.getTypeID(T);
 }
 
@@ -1209,7 +1225,10 @@ void DXILBitcodeWriter::writeModuleInfo() {
   };
   for (const GlobalVariable &GV : M.globals()) {
     UpdateMaxAlignment(GV.getAlign());
-    MaxGlobalType = std::max(MaxGlobalType, getTypeID(GV.getValueType(), &GV));
+    // Use getGlobalObjectValueTypeID to look up the enumerated type ID for
+    // Global Variable types.
+    MaxGlobalType = std::max(
+        MaxGlobalType, getGlobalObjectValueTypeID(GV.getValueType(), &GV));
     if (GV.hasSection()) {
       // Give section names unique ID's.
       unsigned &Entry = SectionMap[std::string(GV.getSection())];
@@ -1281,7 +1300,7 @@ void DXILBitcodeWriter::writeModuleInfo() {
     //             linkage, alignment, section, visibility, threadlocal,
     //             unnamed_addr, externally_initialized, dllstorageclass,
     //             comdat]
-    Vals.push_back(getTypeID(GV.getValueType(), &GV));
+    Vals.push_back(getGlobalObjectValueTypeID(GV.getValueType(), &GV));
     Vals.push_back(
         GV.getType()->getAddressSpace() << 2 | 2 |
         (GV.isConstant() ? 1 : 0)); // HLSL Change - bitwise | was used with
@@ -1317,7 +1336,7 @@ void DXILBitcodeWriter::writeModuleInfo() {
     // FUNCTION:  [type, callingconv, isproto, linkage, paramattrs, alignment,
     //             section, visibility, gc, unnamed_addr, prologuedata,
     //             dllstorageclass, comdat, prefixdata, personalityfn]
-    Vals.push_back(getTypeID(F.getFunctionType(), &F));
+    Vals.push_back(getGlobalObjectValueTypeID(F.getFunctionType(), &F));
     Vals.push_back(F.getCallingConv());
     Vals.push_back(F.isDeclaration());
     Vals.push_back(getEncodedLinkage(F));
@@ -1971,7 +1990,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     // If we need to switch types, do so now.
     if (V->getType() != LastTy) {
       LastTy = V->getType();
-      Record.push_back(getTypeID(LastTy));
+      Record.push_back(getTypeID(LastTy, V));
       Stream.EmitRecord(bitc::CST_CODE_SETTYPE, Record,
                         CONSTANTS_SETTYPE_ABBREV);
       Record.clear();
@@ -2106,7 +2125,8 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
         if (Instruction::isCast(CE->getOpcode())) {
           Code = bitc::CST_CODE_CE_CAST;
           Record.push_back(getEncodedCastOpcode(CE->getOpcode()));
-          Record.push_back(getTypeID(C->getOperand(0)->getType()));
+          Record.push_back(
+              getTypeID(C->getOperand(0)->getType(), C->getOperand(0)));
           Record.push_back(VE.getValueID(C->getOperand(0)));
           AbbrevToUse = CONSTANTS_CE_CAST_Abbrev;
         } else {
@@ -2127,7 +2147,8 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
           Code = bitc::CST_CODE_CE_INBOUNDS_GEP;
         Record.push_back(getTypeID(GO->getSourceElementType()));
         for (unsigned i = 0, e = CE->getNumOperands(); i != e; ++i) {
-          Record.push_back(getTypeID(C->getOperand(i)->getType()));
+          Record.push_back(
+              getTypeID(C->getOperand(i)->getType(), C->getOperand(i)));
           Record.push_back(VE.getValueID(C->getOperand(i)));
         }
         break;
@@ -2529,7 +2550,7 @@ void DXILBitcodeWriter::writeInstruction(const Instruction &I, unsigned InstID,
     Vals.push_back(VE.getAttributeListID(CI.getAttributes()));
     Vals.push_back((CI.getCallingConv() << 1) | unsigned(CI.isTailCall()) |
                    unsigned(CI.isMustTailCall()) << 14 | 1 << 15);
-    Vals.push_back(getTypeID(FTy, CI.getCalledFunction()));
+    Vals.push_back(getGlobalObjectValueTypeID(FTy, CI.getCalledFunction()));
     pushValueAndType(CI.getCalledOperand(), InstID, Vals); // Callee
 
     // Emit value #'s for the fixed parameters.

diff  --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
index 1d536bbd00114..eea89941983bb 100644
--- a/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PointerTypeAnalysis.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 
 using namespace llvm;
@@ -20,22 +21,32 @@ namespace {
 
 // Classifies the type of the value passed in by walking the value's users to
 // find a typed instruction to materialize a type from.
-TypedPointerType *classifyPointerType(const Value *V) {
+Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
   assert(V->getType()->isOpaquePointerTy() &&
          "classifyPointerType called with non-opaque pointer");
+  auto It = Map.find(V);
+  if (It != Map.end())
+    return It->second;
+
   Type *PointeeTy = nullptr;
   if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
     if (!Inst->getResultElementType()->isOpaquePointerTy())
       PointeeTy = Inst->getResultElementType();
   } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
     PointeeTy = Inst->getAllocatedType();
+  } else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
+    PointeeTy = GV->getValueType();
   }
+
   for (const auto *User : V->users()) {
     Type *NewPointeeTy = nullptr;
     if (const auto *Inst = dyn_cast<LoadInst>(User)) {
       NewPointeeTy = Inst->getType();
     } else if (const auto *Inst = dyn_cast<StoreInst>(User)) {
       NewPointeeTy = Inst->getValueOperand()->getType();
+      // When store value is ptr type, cannot get more type info.
+      if (NewPointeeTy->isOpaquePointerTy())
+        continue;
     } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
       NewPointeeTy = Inst->getSourceElementType();
     }
@@ -43,9 +54,10 @@ TypedPointerType *classifyPointerType(const Value *V) {
       // HLSL doesn't support pointers, so it is unlikely to get more than one
       // or two levels of indirection in the IR. Because of this, recursion is
       // pretty safe.
-      if (NewPointeeTy->isOpaquePointerTy())
-        return TypedPointerType::get(classifyPointerType(User),
-                                     V->getType()->getPointerAddressSpace());
+      if (NewPointeeTy->isOpaquePointerTy()) {
+        PointeeTy = classifyPointerType(User, Map);
+        break;
+      }
       if (!PointeeTy)
         PointeeTy = NewPointeeTy;
       else if (PointeeTy != NewPointeeTy)
@@ -55,65 +67,143 @@ TypedPointerType *classifyPointerType(const Value *V) {
   // If we were unable to determine the pointee type, set to i8
   if (!PointeeTy)
     PointeeTy = Type::getInt8Ty(V->getContext());
-  return TypedPointerType::get(PointeeTy,
-                               V->getType()->getPointerAddressSpace());
+  auto *TypedPtrTy =
+      TypedPointerType::get(PointeeTy, V->getType()->getPointerAddressSpace());
+
+  Map[V] = TypedPtrTy;
+  return TypedPtrTy;
 }
 
 // This function constructs a function type accepting typed pointers. It only
 // handles function arguments and return types, and assigns the function type to
 // the function's value in the type map.
-void classifyFunctionType(const Function &F, PointerTypeMap &Map) {
+Type *classifyFunctionType(const Function &F, PointerTypeMap &Map) {
+  // Memoized: F may already have been classified, e.g. when it was reached
+  // through a constant via classifyConstantWithOpaquePtr.
+  auto It = Map.find(&F);
+  if (It != Map.end())
+    return It->second;
+
   SmallVector<Type *, 8> NewArgs;
-  bool HasOpaqueTy = false;
   Type *RetTy = F.getReturnType();
+  LLVMContext &Ctx = F.getContext();
   if (RetTy->isOpaquePointerTy()) {
     RetTy = nullptr;
+    // Derive the pointee from the value returned by each block that ends in a
+    // ReturnInst. If the blocks disagree, fall back to i8*.
     for (const auto &B : F) {
-      for (const auto &I : B) {
-        if (const auto *RetInst = dyn_cast_or_null<ReturnInst>(&I)) {
-          Type *NewRetTy = classifyPointerType(RetInst->getReturnValue());
-          if (!RetTy)
-            RetTy = NewRetTy;
-          else if (RetTy != NewRetTy)
-            RetTy = TypedPointerType::get(
-                Type::getInt8Ty(I.getContext()),
-                F.getReturnType()->getPointerAddressSpace());
-        }
-      }
+      const auto *RetInst = dyn_cast_or_null<ReturnInst>(B.getTerminator());
+      if (!RetInst)
+        continue;
+
+      Type *NewRetTy = classifyPointerType(RetInst->getReturnValue(), Map);
+      if (!RetTy)
+        RetTy = NewRetTy;
+      else if (RetTy != NewRetTy)
+        RetTy = TypedPointerType::get(
+            Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
     }
+    // No ReturnInst was found (a declaration, or a body that never returns):
+    // default to i8*.
+    if (!RetTy)
+      RetTy = TypedPointerType::get(
+          Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
   }
+  // Opaque-pointer arguments are classified, and recorded in Map, one by one.
   for (auto &A : F.args()) {
     Type *ArgTy = A.getType();
-    if (ArgTy->isOpaquePointerTy()) {
-      TypedPointerType *NewTy = classifyPointerType(&A);
-      Map[&A] = NewTy;
-      ArgTy = NewTy;
-      HasOpaqueTy = true;
-    }
+    if (ArgTy->isOpaquePointerTy())
+      ArgTy = classifyPointerType(&A, Map);
     NewArgs.push_back(ArgTy);
   }
-  if (!HasOpaqueTy)
-    return;
-  Map[&F] = FunctionType::get(RetTy, NewArgs, false);
+  // Record and return a typed pointer to the rebuilt (non-vararg) function
+  // type. The address space is hard-coded to 0 rather than taken from F's type.
+  auto *TypedPtrTy =
+      TypedPointerType::get(FunctionType::get(RetTy, NewArgs, false), 0);
+  Map[&F] = TypedPtrTy;
+  return TypedPtrTy;
 }
 } // anonymous namespace
 
+// Computes the type a constant has once its opaque pointers are replaced by
+// typed pointers. The result is recorded in Map when it differs from the
+// constant's own type.
+//
+// - Null pointers become i8* and are not recorded (see the FIXME below).
+// - Other ConstantData is returned with its own type.
+// - Functions and global variables are classified with the per-object helpers.
+// - Struct, array and vector aggregates are rebuilt element by element. Note
+//   that StructType::get yields a literal struct, so a named struct type is
+//   not preserved.
+// - Any other constant (e.g. a ConstantExpr) is unsupported. It asserts when
+//   assertions are enabled, and is otherwise left unmapped.
+static Type *classifyConstantWithOpaquePtr(const Constant *C,
+                                           PointerTypeMap &Map) {
+  // FIXME: support ConstantPointerNull which could map to more than one
+  // TypedPointerType.
+  // See https://github.com/llvm/llvm-project/issues/57942.
+  if (isa<ConstantPointerNull>(C))
+    return TypedPointerType::get(Type::getInt8Ty(C->getContext()),
+                                 C->getType()->getPointerAddressSpace());
+
+  // The remaining ConstantData (ints, floats, zeroinitializer, undef, ...) is
+  // returned with its own type.
+  if (isa<ConstantData>(C))
+    return C->getType();
+
+  auto It = Map.find(C);
+  if (It != Map.end())
+    return It->second;
+
+  if (const auto *F = dyn_cast<Function>(C))
+    return classifyFunctionType(*F, Map);
+
+  // A global variable can be referenced before PointerTypeAnalysis::run has
+  // visited it, e.g. a global defined after llvm.global_ctors. Classify it on
+  // demand instead of falling through to the unsupported-constant assert.
+  if (const auto *GV = dyn_cast<GlobalVariable>(C)) {
+    if (GV->getType()->isOpaquePointerTy())
+      return classifyPointerType(GV, Map);
+    return GV->getType();
+  }
+
+  Type *Ty = C->getType();
+  Type *TargetTy = nullptr;
+  if (const auto *CS = dyn_cast<ConstantStruct>(C)) {
+    SmallVector<Type *> EltTys;
+    for (const Use &Op : CS->operands())
+      EltTys.push_back(
+          classifyConstantWithOpaquePtr(cast<Constant>(Op.get()), Map));
+    // Keep the packed layout of the original struct type.
+    TargetTy = StructType::get(C->getContext(), EltTys,
+                               CS->getType()->isPacked());
+  } else if (const auto *CA = dyn_cast<ConstantAggregate>(C)) {
+    // Array or vector: every element must classify to the same type.
+    Type *TargetEltTy = nullptr;
+    for (const Use &Op : CA->operands()) {
+      Type *EltTy =
+          classifyConstantWithOpaquePtr(cast<Constant>(Op.get()), Map);
+      assert((!TargetEltTy || TargetEltTy == EltTy) &&
+             "array/vector elements classified to different types");
+      TargetEltTy = EltTy;
+    }
+
+    if (auto *AT = dyn_cast<ArrayType>(Ty)) {
+      TargetTy = ArrayType::get(TargetEltTy, AT->getNumElements());
+    } else {
+      // Not struct, not array, must be vector here.
+      auto *VT = cast<VectorType>(Ty);
+      TargetTy = VectorType::get(TargetEltTy, VT);
+    }
+  }
+  // Must have a target type before mapping.
+  assert(TargetTy && "PointerTypeAnalysis failed to identify target type");
+  // Without assertions, never record a null type for an unsupported constant.
+  // Leave it unmapped so the writer falls back to the constant's own type.
+  if (!TargetTy)
+    return Ty;
+
+  // Same type, no need to map.
+  if (TargetTy == Ty)
+    return Ty;
+
+  Map[C] = TargetTy;
+  return TargetTy;
+}
+
+// Re-types llvm.global_ctors. Its initializer is expected to be an array of
+// { i32, ptr, ptr } entries. Their opaque pointers are classified so that the
+// global maps to a typed pointer to the re-typed array, e.g.
+// [N x { i32, void ()*, i8* }].
+static void classifyGlobalCtorPointerType(const GlobalVariable &GV,
+                                          PointerTypeMap &Map) {
+  // Nothing to re-type in two cases. A declaration has no initializer. An
+  // empty ctor list is folded to zeroinitializer, which is a
+  // ConstantAggregateZero rather than a ConstantArray. In both cases any entry
+  // already recorded for GV is kept.
+  if (!GV.hasInitializer())
+    return;
+  const auto *CA = dyn_cast<ConstantArray>(GV.getInitializer());
+  if (!CA)
+    return;
+
+  // Type for global ctor should be array of { i32, void ()*, i8* }.
+  Type *CtorArrayTy = classifyConstantWithOpaquePtr(CA, Map);
+
+  // Map the global to a typed pointer to the re-typed array.
+  Map[&GV] = TypedPointerType::get(CtorArrayTy,
+                                   GV.getType()->getPointerAddressSpace());
+}
+
 PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
   PointerTypeMap Map;
+  // Globals first. Each opaque-pointer global is mapped to a typed pointer to
+  // its value type.
   for (auto &G : M.globals()) {
     if (G.getType()->isOpaquePointerTy())
-      Map[&G] = classifyPointerType(&G);
+      classifyPointerType(&G, Map);
+    // llvm.global_ctors is then re-classified from its initializer. This
+    // overwrites the generic entry recorded just above.
+    if (G.getName() == "llvm.global_ctors")
+      classifyGlobalCtorPointerType(G, Map);
   }
+
+  // Then each function's signature and every opaque-pointer-typed
+  // instruction. Results already cached in Map are reused.
   for (auto &F : M) {
     classifyFunctionType(F, Map);

     for (const auto &B : F) {
       for (const auto &I : B) {
         if (I.getType()->isOpaquePointerTy())
-          Map[&I] = classifyPointerType(&I);
+          classifyPointerType(&I, Map);
       }
     }
   }
-
   return Map;
 }

diff  --git a/llvm/test/tools/dxil-dis/global_ctor.ll b/llvm/test/tools/dxil-dis/global_ctor.ll
new file mode 100644
index 0000000000000..55c4f7e8709dc
--- /dev/null
+++ b/llvm/test/tools/dxil-dis/global_ctor.ll
@@ -0,0 +1,54 @@
+; RUN: llc --filetype=obj %s -o - 2>&1 | dxil-dis -o - | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-unknown-shadermodel6.7-library"
+; Make sure the opaque ptr fields of the global ctor array are emitted as typed
+; pointers: the constructor as void ()* and the associated data as i8*.
+; CHECK:@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* @_GLOBAL__sub_I_static_global.hlsl, i8* null }]
+
+@f = internal unnamed_addr global float 0.000000e+00, align 4
+@llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_static_global.hlsl, ptr null }]
+
+declare float @"?init@@YAMXZ"() local_unnamed_addr #0
+
+; Function Attrs: nounwind
+define float @"?foo@@YAMXZ"() local_unnamed_addr #1 {
+entry:
+  %0 = load float, ptr @f, align 4, !tbaa !4
+  %inc = fadd float %0, 1.000000e+00
+  store float %inc, ptr @f, align 4, !tbaa !4
+  ret float %0
+}
+
+; Function Attrs: nounwind
+define float @"?bar@@YAMXZ"() local_unnamed_addr #1 {
+entry:
+  %0 = load float, ptr @f, align 4, !tbaa !4
+  %dec = fadd float %0, -1.000000e+00
+  store float %dec, ptr @f, align 4, !tbaa !4
+  ret float %0
+}
+
+; Function Attrs: nounwind
+define internal void @_GLOBAL__sub_I_static_global.hlsl() #1 {
+entry:
+  %call.i = tail call float @"?init@@YAMXZ"() #2
+  store float %call.i, ptr @f, align 4, !tbaa !4
+  ret void
+}
+
+attributes #0 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { nounwind }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+!dx.valver = !{!3}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 7, !"frame-pointer", i32 2}
+!2 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git c5dfff0e58cc66d74e666c31368f6d44328dd2f7)"}
+!3 = !{i32 1, i32 7}
+!4 = !{!5, !5, i64 0}
+!5 = !{!"float", !6, i64 0}
+!6 = !{!"omnipotent char", !7, i64 0}
+!7 = !{!"Simple C++ TBAA"}

diff  --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
index ee6c4f90ac259..7b1e4bfb96ae1 100644
--- a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
+++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
@@ -46,8 +46,9 @@ TEST(PointerTypeAnalysis, DigressToi8) {
   Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
   Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
-  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));  
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
+  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
 }
 
 TEST(PointerTypeAnalysis, DiscoverStore) {
@@ -68,7 +69,8 @@ TEST(PointerTypeAnalysis, DiscoverStore) {
   Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
   Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
   EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
 }
 
@@ -90,7 +92,8 @@ TEST(PointerTypeAnalysis, DiscoverLoad) {
   Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
   Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
   EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
 }
 
@@ -113,7 +116,8 @@ TEST(PointerTypeAnalysis, DiscoverGEP) {
   Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
   Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
   EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
   EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
 }
@@ -139,7 +143,8 @@ TEST(PointerTypeAnalysis, TraceIndirect) {
   Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
   Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
   EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
   EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
 }
@@ -168,7 +173,8 @@ TEST(PointerTypeAnalysis, WithNoOpCasts) {
   Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
   Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
 
-  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Map,
+              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
   EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
   EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
   EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));


        


More information about the llvm-commits mailing list