[llvm] e6f44a3 - Add PointerType analysis for DirectX backend

Chris Bieneman via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 25 15:55:46 PDT 2022


Author: Chris Bieneman
Date: 2022-04-25T17:49:43-05:00
New Revision: e6f44a3cd2735e92987f51ea59ae44f959807df4

URL: https://github.com/llvm/llvm-project/commit/e6f44a3cd2735e92987f51ea59ae44f959807df4
DIFF: https://github.com/llvm/llvm-project/commit/e6f44a3cd2735e92987f51ea59ae44f959807df4.diff

LOG: Add PointerType analysis for DirectX backend

As implemented this patch assumes that Typed pointer support remains in
the llvm::PointerType class, however this could be modified to use a
different subclass of llvm::Type that could be disallowed from use in
other contexts.

This does not rely on inserting typed pointers into the Module, it just
uses the llvm::PointerType class to track and unique types.

Fixes #54918

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D122268

Added: 
    llvm/lib/Target/DirectX/DXILPointerType.cpp
    llvm/lib/Target/DirectX/DXILPointerType.h
    llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
    llvm/lib/Target/DirectX/PointerTypeAnalysis.h
    llvm/unittests/Target/DirectX/CMakeLists.txt
    llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp

Modified: 
    llvm/include/llvm/IR/LLVMContext.h
    llvm/include/llvm/IR/Type.h
    llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/IR/Core.cpp
    llvm/lib/IR/LLVMContext.cpp
    llvm/lib/IR/LLVMContextImpl.h
    llvm/lib/Target/DirectX/CMakeLists.txt
    llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h
index 792e59a303623..91712df153a0b 100644
--- a/llvm/include/llvm/IR/LLVMContext.h
+++ b/llvm/include/llvm/IR/LLVMContext.h
@@ -24,6 +24,7 @@
 
 namespace llvm {
 
+class Any;
 class DiagnosticInfo;
 enum DiagnosticSeverity : char;
 class Function;
@@ -322,6 +323,10 @@ class LLVMContext {
   /// Whether typed pointers are supported. If false, all pointers are opaque.
   bool supportsTypedPointers() const;
 
+  /// Optionally, target-specific data can be attached to the context for
+  /// lifetime management and to bypass layering restrictions.
+  llvm::Any &getTargetData() const;
+
 private:
   // Module needs access to the add/removeModule methods.
   friend class Module;

diff  --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index f998eeb7dc7c2..51263c6b8fccc 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -68,13 +68,14 @@ class Type {
     TokenTyID,     ///< Tokens
 
     // Derived types... see DerivedTypes.h file.
-    IntegerTyID,       ///< Arbitrary bit width integers
-    FunctionTyID,      ///< Functions
-    PointerTyID,       ///< Pointers
-    StructTyID,        ///< Structures
-    ArrayTyID,         ///< Arrays
-    FixedVectorTyID,   ///< Fixed width SIMD vector type
-    ScalableVectorTyID ///< Scalable SIMD vector type
+    IntegerTyID,        ///< Arbitrary bit width integers
+    FunctionTyID,       ///< Functions
+    PointerTyID,        ///< Pointers
+    StructTyID,         ///< Structures
+    ArrayTyID,          ///< Arrays
+    FixedVectorTyID,    ///< Fixed width SIMD vector type
+    ScalableVectorTyID, ///< Scalable SIMD vector type
+    DXILPointerTyID,    ///< DXIL typed pointer used by DirectX target
   };
 
 private:

diff  --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8131c1cfb6f86..e773ba8607faa 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1018,6 +1018,8 @@ void ModuleBitcodeWriter::writeTypeTable() {
         TypeVals.push_back(true);
       break;
     }
+    case Type::DXILPointerTyID:
+      llvm_unreachable("DXIL pointers cannot be added to IR modules");
     }
 
     // Emit the finished record.

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index ae1ed2ddaf6b0..66595bcfb9aab 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -612,6 +612,11 @@ void TypePrinting::print(Type *Ty, raw_ostream &OS) {
     OS << '>';
     return;
   }
+  case Type::DXILPointerTyID:
+    // DXIL pointer types are only handled by the DirectX backend. To avoid
+    // extra dependencies we just print the pointer's address here.
+    OS << "dxil-ptr (" << Ty << ")";
+    return;
   }
   llvm_unreachable("Invalid TypeID");
 }

diff  --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index 182be54c844bd..f651e3dd20ad6 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -534,6 +534,8 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
     return LLVMTokenTypeKind;
   case Type::ScalableVectorTyID:
     return LLVMScalableVectorTypeKind;
+  case Type::DXILPointerTyID:
+    llvm_unreachable("DXIL pointers are unsupported via the C API");
   }
   llvm_unreachable("Unhandled TypeID.");
 }

diff  --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp
index 09ab4b9f75fe5..9a0e9ac85c26f 100644
--- a/llvm/lib/IR/LLVMContext.cpp
+++ b/llvm/lib/IR/LLVMContext.cpp
@@ -374,3 +374,7 @@ void LLVMContext::setOpaquePointers(bool Enable) const {
 bool LLVMContext::supportsTypedPointers() const {
   return !pImpl->getOpaquePointers();
 }
+
+Any &LLVMContext::getTargetData() const {
+  // The slot is owned by the LLVMContextImpl, which keeps llvm::Any out of
+  // the public LLVMContext.h (it only forward-declares the class).
+  Any &Storage = pImpl->TargetDataStorage;
+  return Storage;
+}

diff  --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index e6a6a61038808..8db171527b6a0 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -17,6 +17,7 @@
 #include "ConstantsContext.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Any.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseMapInfo.h"
@@ -1567,6 +1568,8 @@ class LLVMContextImpl {
   bool hasOpaquePointersValue();
   void setOpaquePointers(bool OP);
 
+  llvm::Any TargetDataStorage;
+
 private:
   Optional<bool> OpaquePointers;
 };

diff  --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 2b119a2c8bad3..d76eb976c57ba 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -9,7 +9,9 @@ add_public_tablegen_target(DirectXCommonTableGen)
 add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILPointerType.cpp
   DXILPrepare.cpp
+  PointerTypeAnalysis.cpp
 
   LINK_COMPONENTS
   Core

diff  --git a/llvm/lib/Target/DirectX/DXILPointerType.cpp b/llvm/lib/Target/DirectX/DXILPointerType.cpp
new file mode 100644
index 0000000000000..1e67f1a30ec4a
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILPointerType.cpp
@@ -0,0 +1,66 @@
+//===- Target/DirectX/DXILPointerType.cpp - DXIL Typed Pointer Type -------===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILPointerType.h"
+#include "llvm/ADT/Any.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/LLVMContext.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+namespace {
+/// Per-LLVMContext uniquing tables for TypedPointerType. One instance is
+/// stored (behind a shared_ptr) in the context's target-data slot, and the
+/// tables own every TypedPointerType created for that context.
+/// File-local: nothing outside this translation unit needs the definition.
+struct TypedPointerTracking {
+  /// Pointers in address space 0 (the common case), keyed by element type.
+  DenseMap<Type *, std::unique_ptr<TypedPointerType>> PointerTypes;
+  /// Pointers in any other address space, keyed by (element type, AS).
+  DenseMap<std::pair<Type *, unsigned>, std::unique_ptr<TypedPointerType>>
+      ASPointerTypes;
+};
+} // anonymous namespace
+
+TypedPointerType *TypedPointerType::get(Type *EltTy, unsigned AddressSpace) {
+  assert(EltTy && "Can't get a pointer to <null> type!");
+  assert(isValidElementType(EltTy) && "Invalid type for pointer element!");
+
+  // The uniquing tables hang off the context's target-data slot and are
+  // created lazily on first use. llvm::Any requires a copy-constructible
+  // payload, hence the shared_ptr wrapper around the (move-only) tables.
+  llvm::Any &TargetData = EltTy->getContext().getTargetData();
+  if (!TargetData.hasValue())
+    TargetData = Any{std::make_shared<TypedPointerTracking>()};
+
+  // Borrow the stored shared_ptr in place instead of copying it out, which
+  // would cost an atomic reference-count round trip on every lookup. The
+  // pointer form of any_cast yields null if another owner claimed the slot.
+  auto *Tracking =
+      any_cast<std::shared_ptr<TypedPointerTracking>>(&TargetData);
+  assert(Tracking && "Unexpected target data type");
+
+  // Since AddressSpace #0 is the common case, we special case it.
+  std::unique_ptr<TypedPointerType> &Entry =
+      AddressSpace == 0
+          ? (*Tracking)->PointerTypes[EltTy]
+          : (*Tracking)->ASPointerTypes[std::make_pair(EltTy, AddressSpace)];
+
+  // The constructor is private, so make_unique cannot be used here.
+  if (!Entry)
+    Entry = std::unique_ptr<TypedPointerType>(
+        new TypedPointerType(EltTy, AddressSpace));
+  return Entry.get();
+}
+
+TypedPointerType::TypedPointerType(Type *E, unsigned AddrSpace)
+    : Type(E->getContext(), DXILPointerTyID), PointeeTy(E) {
+  // The address space lives in Type's subclass-data bits; getAddressSpace()
+  // reads it back from there.
+  setSubclassData(AddrSpace);
+  // Expose the pointee as the one and only contained type.
+  NumContainedTys = 1;
+  ContainedTys = &PointeeTy;
+}
+
+bool TypedPointerType::isValidElementType(Type *ElemTy) {
+  // Reject every kind of type that a pointer may never point at.
+  if (ElemTy->isVoidTy() || ElemTy->isLabelTy() || ElemTy->isMetadataTy())
+    return false;
+  return !(ElemTy->isTokenTy() || ElemTy->isX86_AMXTy());
+}

diff  --git a/llvm/lib/Target/DirectX/DXILPointerType.h b/llvm/lib/Target/DirectX/DXILPointerType.h
new file mode 100644
index 0000000000000..52cf2dbc40b04
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILPointerType.h
@@ -0,0 +1,52 @@
+//===- Target/DirectX/DXILPointerType.h - DXIL Typed Pointer Type ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H
+#define LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H
+
+#include "llvm/IR/Type.h"
+
+namespace llvm {
+namespace dxil {
+
+/// DXIL has typed pointers. This Type subclass records a pointee type and an
+/// address space so that typed pointers can be tracked alongside the opaque
+/// pointers in the IR; it is used by PointerTypeAnalysis and by the bitcode
+/// ValueEnumerator, and is not meant to be added to an IR module.
+class TypedPointerType : public Type {
+  /// The pointee; also serves as this type's single contained type.
+  Type *PointeeTy;
+
+  explicit TypedPointerType(Type *ElType, unsigned AddrSpace);
+
+public:
+  TypedPointerType(const TypedPointerType &) = delete;
+  TypedPointerType &operator=(const TypedPointerType &) = delete;
+
+  /// Return the uniqued pointer to \p ElementType in the numbered address
+  /// space \p AddressSpace, creating it on first request.
+  static TypedPointerType *get(Type *ElementType, unsigned AddressSpace);
+
+  /// Return true if \p ElemTy may be used as an element (pointee) type.
+  static bool isValidElementType(Type *ElemTy);
+
+  /// Return the address space of this pointer type.
+  unsigned getAddressSpace() const { return getSubclassData(); }
+
+  /// Return the type this pointer points to.
+  Type *getElementType() const { return PointeeTy; }
+
+  /// Implement support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Type *T) {
+    return T->getTypeID() == DXILPointerTyID;
+  }
+};
+
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H

diff  --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
index 5dbfe6f74c86f..c6f9308d8603a 100644
--- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
+++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
@@ -1049,6 +1049,7 @@ void DXILBitcodeWriter::writeTypeTable() {
     case Type::BFloatTyID:
     case Type::X86_AMXTyID:
     case Type::TokenTyID:
+    case Type::DXILPointerTyID:
       llvm_unreachable("These should never be used!!!");
       break;
     case Type::VoidTyID:

diff  --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
new file mode 100644
index 0000000000000..1d536bbd00114
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
@@ -0,0 +1,119 @@
+//===- Target/DirectX/PointerTypeAnalysis.cpp - PointerType analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PointerTypeAnalysis.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+namespace {
+
+// Classifies the type of the value passed in by walking the value's users to
+// find a typed instruction to materialize a type from.
+//
+// The pointee type is seeded from the value itself where that is available
+// (non-pointer GEP result element type, alloca allocated type, non-pointer
+// global value type) and is then refined by every load, store-through, and GEP
+// that dereferences the value. Conflicting evidence, or no evidence at all,
+// degrades the pointee to i8.
+TypedPointerType *classifyPointerType(const Value *V) {
+  assert(V->getType()->isOpaquePointerTy() &&
+         "classifyPointerType called with non-opaque pointer");
+  const unsigned AddrSpace = V->getType()->getPointerAddressSpace();
+  Type *PointeeTy = nullptr;
+  if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
+    if (!Inst->getResultElementType()->isOpaquePointerTy())
+      PointeeTy = Inst->getResultElementType();
+  } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
+    PointeeTy = Inst->getAllocatedType();
+  } else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
+    // A global's address points at its value type.
+    if (!GV->getValueType()->isOpaquePointerTy())
+      PointeeTy = GV->getValueType();
+  }
+  for (const auto *User : V->users()) {
+    Type *NewPointeeTy = nullptr;
+    if (const auto *Inst = dyn_cast<LoadInst>(User)) {
+      NewPointeeTy = Inst->getType();
+    } else if (const auto *Inst = dyn_cast<StoreInst>(User)) {
+      // Only a store *through* V says anything about V's pointee; V merely
+      // being the stored value does not.
+      if (Inst->getPointerOperand() != V)
+        continue;
+      NewPointeeTy = Inst->getValueOperand()->getType();
+      // Storing a pointer gives no further information, and the store has no
+      // pointer-typed result to recurse into (its type is void).
+      if (NewPointeeTy->isOpaquePointerTy())
+        continue;
+    } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
+      NewPointeeTy = Inst->getSourceElementType();
+    }
+    if (!NewPointeeTy)
+      continue;
+    if (NewPointeeTy->isOpaquePointerTy()) {
+      // HLSL doesn't support pointers, so it is unlikely to get more than one
+      // or two levels of indirection in the IR. Because of this, recursion is
+      // pretty safe.
+      //
+      // A GEP over a pointer element type can only step over whole pointers,
+      // so its result has exactly V's typed pointer type. A load of a pointer
+      // yields the pointee itself, so V points at the loaded value's type.
+      if (isa<GetElementPtrInst>(User))
+        return classifyPointerType(User);
+      return TypedPointerType::get(classifyPointerType(User), AddrSpace);
+    }
+    if (!PointeeTy)
+      PointeeTy = NewPointeeTy;
+    else if (PointeeTy != NewPointeeTy)
+      PointeeTy = Type::getInt8Ty(V->getContext());
+  }
+  // If we were unable to determine the pointee type, set to i8
+  if (!PointeeTy)
+    PointeeTy = Type::getInt8Ty(V->getContext());
+  return TypedPointerType::get(PointeeTy, AddrSpace);
+}
+
+// This function constructs a function type accepting typed pointers. It only
+// handles function arguments and return types, and assigns the function type to
+// the function's value in the type map. Every opaque-pointer argument is also
+// recorded in the map. Functions whose signature contains no opaque pointer
+// (neither a parameter nor the return type) are left out of the map.
+void classifyFunctionType(const Function &F, PointerTypeMap &Map) {
+  SmallVector<Type *, 8> NewArgs;
+  LLVMContext &Ctx = F.getContext();
+  Type *RetTy = F.getReturnType();
+  bool HasOpaqueTy = RetTy->isOpaquePointerTy();
+  if (HasOpaqueTy) {
+    const unsigned RetAS = RetTy->getPointerAddressSpace();
+    RetTy = nullptr;
+    // A ret is always a block terminator, so only terminators need checking.
+    for (const auto &B : F) {
+      const auto *RetInst = dyn_cast_or_null<ReturnInst>(B.getTerminator());
+      if (!RetInst)
+        continue;
+      Type *NewRetTy = classifyPointerType(RetInst->getReturnValue());
+      if (!RetTy)
+        RetTy = NewRetTy;
+      else if (RetTy != NewRetTy)
+        RetTy = TypedPointerType::get(Type::getInt8Ty(Ctx), RetAS);
+    }
+    // Declarations (and bodies that never return) provide no evidence; fall
+    // back to i8 like classifyPointerType does instead of building a
+    // FunctionType with a null return type.
+    if (!RetTy)
+      RetTy = TypedPointerType::get(Type::getInt8Ty(Ctx), RetAS);
+  }
+  for (const auto &A : F.args()) {
+    Type *ArgTy = A.getType();
+    if (ArgTy->isOpaquePointerTy()) {
+      TypedPointerType *NewTy = classifyPointerType(&A);
+      Map[&A] = NewTy;
+      ArgTy = NewTy;
+      HasOpaqueTy = true;
+    }
+    NewArgs.push_back(ArgTy);
+  }
+  if (!HasOpaqueTy)
+    return;
+  // Preserve the variadic flag of the original signature.
+  Map[&F] = FunctionType::get(RetTy, NewArgs, F.isVarArg());
+}
+} // anonymous namespace
+
+PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
+  PointerTypeMap Map;
+
+  // Globals first: every opaque-pointer global gets its own typed pointer.
+  for (const auto &GV : M.globals())
+    if (GV.getType()->isOpaquePointerTy())
+      Map[&GV] = classifyPointerType(&GV);
+
+  // Then each function: its signature (which also records its arguments),
+  // followed by every instruction that produces an opaque pointer.
+  for (const auto &Fn : M) {
+    classifyFunctionType(Fn, Map);
+    for (const auto &BB : Fn)
+      for (const auto &Inst : BB)
+        if (Inst.getType()->isOpaquePointerTy())
+          Map[&Inst] = classifyPointerType(&Inst);
+  }
+
+  return Map;
+}

diff  --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.h b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
new file mode 100644
index 0000000000000..c4164b6bf359b
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
@@ -0,0 +1,43 @@
+//===- Target/DirectX/PointerTypeAnalysis.h - PointerType analysis --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+#define LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+
+#include "DXILPointerType.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+namespace dxil {
+
+// Maps opaque-pointer-typed values to the TypedPointerType inferred for them.
+// Functions that get classified map to a FunctionType in which opaque-pointer
+// parameters (and return type) are replaced by TypedPointerTypes.
+using PointerTypeMap = DenseMap<const Value *, Type *>;
+
+/// An analysis to compute the \c PointerTypes for pointers in a \c Module.
+/// Since this analysis is only run during codegen and the new pass manager
+/// doesn't support codegen passes, this is written as a function in a
+/// namespace. It is very simple to transform it into a proper analysis pass.
+/// Typed pointers are modeled with dxil::TypedPointerType rather than
+/// llvm::PointerType, and the analyzed module itself is not modified.
+namespace PointerTypeAnalysis {
+
+/// Compute the \c PointerTypeMap for the module \c M.
+PointerTypeMap run(const Module &M);
+} // namespace PointerTypeAnalysis
+
+} // namespace dxil
+
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H

diff  --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
new file mode 100644
index 0000000000000..621b00ae65d3a
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Let the tests include the DirectX target's private headers
+# (e.g. PointerTypeAnalysis.h) from its source and build directories.
+include_directories(
+  ${LLVM_MAIN_SRC_DIR}/lib/Target/DirectX
+  ${LLVM_BINARY_DIR}/lib/Target/DirectX
+)
+
+set(LLVM_LINK_COMPONENTS
+  AsmParser
+  Core
+  DirectXCodeGen
+  Support
+)
+
+add_llvm_target_unittest(DirectXTests
+  PointerTypeAnalysisTests.cpp
+)

diff  --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
new file mode 100644
index 0000000000000..22ec2bda90beb
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
@@ -0,0 +1,185 @@
+//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILPointerType.h"
+#include "PointerTypeAnalysis.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+/// Equality helper for gmock: a `const Value *` compares equal to IsA<T>
+/// exactly when the value is an instance of T. The tests below use it as the
+/// key half of Pair() so map entries can be matched by the kind of Value.
+template <typename T> struct IsA {
+  friend bool operator==(const Value *Val, const IsA &) {
+    return llvm::isa<T>(Val);
+  }
+};
+
+TEST(DXILPointerType, PrintTest) {
+  // The generic IR printer has no DirectX knowledge, so a DXIL pointer is
+  // printed with a fixed "dxil-ptr (" prefix followed by its address.
+  LLVMContext Context;
+  std::string Printed;
+  raw_string_ostream Stream(Printed);
+
+  TypedPointerType::get(Type::getInt8Ty(Context), 0)->print(Stream);
+  EXPECT_TRUE(StringRef(Stream.str()).startswith("dxil-ptr ("));
+}
+
+TEST(PointerTypeAnalysis, DigressToi8) {
+  // %p is stored through as i32 but loaded as i64; the conflicting uses must
+  // degrade its pointee to i8.
+  StringRef IR = R"(
+    define i64 @test(ptr %p) {
+      store i32 0, ptr %p
+      %v = load i64, ptr %p
+      ret i64 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
+  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 2u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I8Ptr)));
+}
+
+TEST(PointerTypeAnalysis, DiscoverStore) {
+  // A single i32 store through %p is enough to type it as an i32 pointer.
+  StringRef IR = R"(
+    define i32 @test(ptr %p) {
+      store i32 0, ptr %p
+      ret i32 0
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I32Ty = Type::getInt32Ty(Context);
+  Type *I32Ptr = TypedPointerType::get(I32Ty, 0);
+  Type *FnTy = FunctionType::get(I32Ty, {I32Ptr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 2u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I32Ptr)));
+}
+
+TEST(PointerTypeAnalysis, DiscoverLoad) {
+  // A single i32 load from %p is enough to type it as an i32 pointer.
+  StringRef IR = R"(
+    define i32 @test(ptr %p) {
+      %v = load i32, ptr %p
+      ret i32 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I32Ty = Type::getInt32Ty(Context);
+  Type *I32Ptr = TypedPointerType::get(I32Ty, 0);
+  Type *FnTy = FunctionType::get(I32Ty, {I32Ptr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 2u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I32Ptr)));
+}
+
+TEST(PointerTypeAnalysis, DiscoverGEP) {
+  // The GEP's i64 source element type types both %p and the GEP result, and
+  // the returned GEP result in turn types the function's return value.
+  StringRef IR = R"(
+    define ptr @test(ptr %p) {
+      %p2 = getelementptr i64, ptr %p, i64 1
+      ret ptr %p2
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
+  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 3u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I64Ptr)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
+}
+
+TEST(PointerTypeAnalysis, TraceIndirect) {
+  // %p is only loaded as a pointer, so its type has to be traced through the
+  // loaded %p2, which is itself dereferenced as i64.
+  StringRef IR = R"(
+    define i64 @test(ptr %p) {
+      %p2 = load ptr, ptr %p
+      %v = load i64, ptr %p2
+      ret i64 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I64Ty = Type::getInt64Ty(Context);
+  Type *I64Ptr = TypedPointerType::get(I64Ty, 0);
+  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
+  Type *FnTy = FunctionType::get(I64Ty, {I64PtrPtr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 3u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
+}
+
+TEST(PointerTypeAnalysis, WithNoOpCasts) {
+  // Each no-op bitcast is typed by its own use, while the argument only
+  // feeds bitcasts (which are not inspected) and so falls back to i8.
+  StringRef IR = R"(
+    define i64 @test(ptr %p) {
+      %1 = bitcast ptr %p to ptr
+      %2 = bitcast ptr %p to ptr
+      store i32 0, ptr %1, align 4
+      %3 = load i64, ptr %2, align 8
+      ret i64 %3
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Diag;
+  auto M = parseAssemblyString(IR, Diag, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  Type *I64Ty = Type::getInt64Ty(Context);
+  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
+  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
+  Type *I64Ptr = TypedPointerType::get(I64Ty, 0);
+  Type *FnTy = FunctionType::get(I64Ty, {I8Ptr}, false);
+
+  PointerTypeMap Types = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Types.size(), 4u);
+  EXPECT_THAT(Types, Contains(Pair(IsA<Function>(), FnTy)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<Argument>(), I8Ptr)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
+  EXPECT_THAT(Types, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));
+}


        


More information about the llvm-commits mailing list