[llvm] 85285be - [DirectX backend] Add pass to lower llvm intrinsic into dxil op function.

Xiang Li via llvm-commits llvm-commits at lists.llvm.org
Wed May 11 00:03:12 PDT 2022


Author: Xiang Li
Date: 2022-05-11T00:03:05-07:00
New Revision: 85285be9c37ad0b6e3dabe82248d8917a6ebd5ec

URL: https://github.com/llvm/llvm-project/commit/85285be9c37ad0b6e3dabe82248d8917a6ebd5ec
DIFF: https://github.com/llvm/llvm-project/commit/85285be9c37ad0b6e3dabe82248d8917a6ebd5ec.diff

LOG: [DirectX backend] Add pass to lower llvm intrinsic into dxil op function.

A new pass DXILOpLowering was added.
It scans all LLVM intrinsics and creates a DXIL op function for each intrinsic that maps to one.
It then translates call instructions on the intrinsic into calls on the DXIL op function.
The DXIL op function takes an extra i32 argument for the DXIL opcode, prepended to the original arguments.
So setCalledFunction cannot be used to update the call instruction on the intrinsic.

This commit only supports sin to start the work.

Reviewed By: kuhar, beanz

Differential Revision: https://reviews.llvm.org/D124805

Added: 
    llvm/lib/Target/DirectX/DXILConstants.h
    llvm/lib/Target/DirectX/DXILOpLowering.cpp
    llvm/test/CodeGen/DirectX/sin.ll

Modified: 
    llvm/lib/Target/DirectX/CMakeLists.txt
    llvm/lib/Target/DirectX/DirectX.h
    llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
    llvm/tools/opt/opt.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index d76eb976c57ba..f2bcbf445bdf9 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -9,6 +9,7 @@ add_public_tablegen_target(DirectXCommonTableGen)
 add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILOpLowering.cpp
   DXILPointerType.cpp
   DXILPrepare.cpp
   PointerTypeAnalysis.cpp

diff  --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
new file mode 100644
index 0000000000000..c7b2be615d0c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -0,0 +1,29 @@
+//===- DXILConstants.h - Essential DXIL constants -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains essential DXIL constants.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+
+namespace llvm {
+namespace DXIL {
+/// Enumeration for operations specified by DXIL.
+///
+/// Each value is passed verbatim as the leading i32 opcode argument of the
+/// lowered dx.op.* call, so it must equal the opcode number DXIL assigns to
+/// that operation; do not renumber.
+enum class OpCode : unsigned {
+  Sin = 13, // returns sine(theta) for theta in radians.
+};
+/// Groups for DXIL operations with equivalent function templates.
+///
+/// The enumerator value is used as an index into the op-class name table in
+/// DXILOpLowering.cpp, so the enumerator order must match that table.
+enum class OpCodeClass : unsigned {
+  Unary,
+};
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H

diff  --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
new file mode 100644
index 0000000000000..f7925b594dcaf
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -0,0 +1,279 @@
+//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains passes and utilities to lower LLVM intrinsic calls
+/// to DXILOp function calls.
+//===----------------------------------------------------------------------===//
+
+#include "DXILConstants.h"
+#include "DirectX.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-op-lower"
+
+using namespace llvm;
+using namespace llvm::DXIL;
+
+/// Prefix shared by every DXIL op function name, e.g. "dx.op.unary.f32".
+constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
+
+namespace {
+/// Overload types a DXIL op can be instantiated with. Each kind is a single
+/// bit so an op's supported overloads can be OR-ed into one bitmask
+/// (OpCodeProperty::OverloadTys).
+///
+/// This is a file-local type, so it lives in an anonymous namespace: at global
+/// scope it would have external linkage and could violate the ODR against
+/// another translation unit that defines its own OverloadKind.
+enum OverloadKind : uint16_t {
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+} // namespace
+
+/// Returns the type suffix used in DXIL op function names for a scalar
+/// overload kind, e.g. "f32" for FLOAT.
+///
+/// VOID, ObjectType and UserDefineType have no fixed suffix; passing them is a
+/// programming error.
+static const char *getOverloadTypeName(OverloadKind Kind) {
+  switch (Kind) {
+  case OverloadKind::HALF:
+    return "f16";
+  case OverloadKind::FLOAT:
+    return "f32";
+  case OverloadKind::DOUBLE:
+    return "f64";
+  case OverloadKind::I1:
+    return "i1";
+  case OverloadKind::I8:
+    return "i8";
+  case OverloadKind::I16:
+    return "i16";
+  case OverloadKind::I32:
+    return "i32";
+  case OverloadKind::I64:
+    return "i64";
+  case OverloadKind::VOID:
+  case OverloadKind::ObjectType:
+  case OverloadKind::UserDefineType:
+    llvm_unreachable("invalid overload type for name");
+  }
+  // Every enumerator is handled above, but Kind can still hold a
+  // non-enumerator value (e.g. an OR of several kinds), and GCC warns that
+  // control may reach the end of this non-void function without this.
+  llvm_unreachable("invalid overload kind");
+}
+
+/// Classifies \p Ty as a DXIL overload kind.
+///
+/// void, half, float, double and i1/i8/i16/i32/i64 map to their own kinds,
+/// pointers map to UserDefineType and structs to ObjectType. Any other type
+/// (for example a vector, or an integer of another width) is rejected with a
+/// fatal error.
+static OverloadKind getOverloadKind(Type *Ty) {
+  switch (Ty->getTypeID()) {
+  case Type::VoidTyID:
+    return OverloadKind::VOID;
+  case Type::HalfTyID:
+    return OverloadKind::HALF;
+  case Type::FloatTyID:
+    return OverloadKind::FLOAT;
+  case Type::DoubleTyID:
+    return OverloadKind::DOUBLE;
+  case Type::IntegerTyID:
+    switch (cast<IntegerType>(Ty)->getBitWidth()) {
+    case 1:
+      return OverloadKind::I1;
+    case 8:
+      return OverloadKind::I8;
+    case 16:
+      return OverloadKind::I16;
+    case 32:
+      return OverloadKind::I32;
+    case 64:
+      return OverloadKind::I64;
+    default:
+      break;
+    }
+    break;
+  case Type::PointerTyID:
+    return OverloadKind::UserDefineType;
+  case Type::StructTyID:
+    return OverloadKind::ObjectType;
+  default:
+    break;
+  }
+  // Callers pass types taken from input IR (createDXILOpFunction passes the
+  // intrinsic's return type), and overloaded intrinsics such as llvm.sin also
+  // accept vector types. Reaching this point is therefore possible with valid
+  // IR, so diagnose it: llvm_unreachable is undefined behavior in release
+  // builds.
+  report_fatal_error("invalid overload type");
+}
+
+/// Returns the type suffix of a DXIL op function name for an overload of
+/// kind \p Kind and type \p Ty.
+///
+/// Scalar kinds use the fixed names from getOverloadTypeName, so \p Kind must
+/// not be VOID. Named struct types use their struct name. Every other type
+/// falls back to its printed IR spelling.
+static std::string getTypeName(OverloadKind Kind, Type *Ty) {
+  if (Kind < OverloadKind::UserDefineType)
+    return getOverloadTypeName(Kind);
+
+  // UserDefineType and ObjectType are named after their struct when possible.
+  // getOverloadKind also maps pointer types to UserDefineType, and an opaque
+  // pointer has no pointee struct to name, so cast<StructType> would fail
+  // there. Literal (unnamed) structs have no name either, and getStructName
+  // must not be called on them.
+  if (auto *ST = dyn_cast<StructType>(Ty))
+    if (ST->hasName())
+      return ST->getStructName().str();
+
+  std::string Str;
+  raw_string_ostream OS(Str);
+  Ty->print(OS);
+  return OS.str();
+}
+
+namespace {
+/// Static properties of one DXIL operation.
+///
+/// This is a file-local type, so it lives in an anonymous namespace to get
+/// internal linkage and avoid an ODR clash with a same-named type in another
+/// translation unit.
+struct OpCodeProperty {
+  /// The DXIL opcode. getOpCodeProperty binary-searches its table by this
+  /// field, so table entries must be sorted by it.
+  DXIL::OpCode OpCode;
+  // FIXME: change OpCodeName into index to a large string constant when move to
+  // tableGen.
+  const char *OpCodeName;
+  /// Group of ops sharing one function template; it provides the <class>
+  /// part of the "dx.op.<class>..." function name.
+  DXIL::OpCodeClass OpCodeClass;
+  /// Bitmask of OverloadKind flags this op accepts.
+  uint16_t OverloadTys;
+  /// Function attribute associated with the op (e.g. ReadNone for Sin).
+  llvm::Attribute::AttrKind FuncAttr;
+};
+} // namespace
+
+/// Returns the op-class part of a DXIL op function name, e.g. "unary" in
+/// "dx.op.unary.f32".
+static const char *getOpCodeClassName(const OpCodeProperty &Prop) {
+  // FIXME: generate this table with tableGen.
+  // Indexed by the numeric value of DXIL::OpCodeClass, so the entries must
+  // follow the enumerator order.
+  static const char *ClassNames[] = {
+      "unary",
+  };
+  const unsigned ClassIdx = static_cast<unsigned>(Prop.OpCodeClass);
+  assert(ClassIdx < sizeof(ClassNames) / sizeof(*ClassNames) &&
+         "Out of bound OpCodeClass");
+  return ClassNames[ClassIdx];
+}
+
+/// Builds the DXIL op function name "dx.op.<class>.<type>". For VOID
+/// overloads the ".<type>" suffix is omitted.
+static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
+                                         const OpCodeProperty &Prop) {
+  std::string Name =
+      (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
+  if (Kind != OverloadKind::VOID) {
+    Name += ".";
+    Name += getTypeName(Kind, Ty);
+  }
+  return Name;
+}
+
+/// Returns the static properties of \p DXILOp.
+///
+/// The table is searched with lower_bound, so it must stay sorted by OpCode.
+/// If \p DXILOp has no entry, this reports a fatal error rather than
+/// returning the table's end pointer or a neighbouring op's entry.
+static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) {
+  // FIXME: generate this table with tableGen.
+  static const OpCodeProperty OpCodeProps[] = {
+      {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary,
+       OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
+  };
+  // FIXME: change search to indexing with
+  // DXILOp once all DXIL op is added.
+  // Compare entries against the opcode directly, so no dummy (and otherwise
+  // uninitialized) OpCodeProperty is needed as the search key.
+  const OpCodeProperty *Prop = llvm::lower_bound(
+      OpCodeProps, DXILOp, [](const OpCodeProperty &A, DXIL::OpCode B) {
+        return A.OpCode < B;
+      });
+  // lower_bound only yields the first entry not less than DXILOp: it may be
+  // past the end, or a different op when DXILOp is missing from the table.
+  if (Prop == std::end(OpCodeProps) || Prop->OpCode != DXILOp)
+    report_fatal_error("DXIL op has no entry in OpCodeProps");
+  return Prop;
+}
+
+/// Returns the DXIL op function declaration that calls to the intrinsic \p F
+/// are rewritten to, inserting it into \p M if it is not already there.
+///
+/// The declaration is named "dx.op.<class>.<overload>" and takes \p F's
+/// parameters with an extra leading i32 for the DXIL opcode. A fatal error is
+/// reported if \p F's return type is not an overload \p DXILOp supports.
+static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
+                                           Module &M) {
+  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+
+  // Get return type as overload type for DXILOp.
+  // Only simple mapping case here, so return type is good enough.
+  Type *OverloadTy = F.getReturnType();
+
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // FIXME: find the issue and report error in clang instead of check it in
+  // backend.
+  // The overload comes from input IR (e.g. llvm.sin.f64 is valid IR, but Sin
+  // only supports half and float), so this must be a real diagnostic;
+  // llvm_unreachable would be undefined behavior in release builds.
+  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
+    report_fatal_error(Twine("invalid overload for DXIL op ") +
+                       Prop->OpCodeName);
+
+  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
+
+  auto &Ctx = M.getContext();
+  Type *OpCodeTy = Type::getInt32Ty(Ctx);
+
+  SmallVector<Type *> ArgTypes;
+  // DXIL has i32 opcode as first arg.
+  ArgTypes.emplace_back(OpCodeTy);
+  FunctionType *FT = F.getFunctionType();
+  ArgTypes.append(FT->param_begin(), FT->param_end());
+  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
+
+  // The name depends only on the op class and overload, not on the opcode
+  // (which is passed as an argument). Different ops of the same class
+  // therefore share one declaration, which may already exist in M;
+  // getOrInsertFunction returns the existing one in that case.
+  FunctionCallee DXILOpFn = M.getOrInsertFunction(FnName, DXILOpFT);
+  // Attach the op's function attribute (ReadNone for Sin) to the declaration.
+  if (auto *Fn = dyn_cast<Function>(DXILOpFn.getCallee()))
+    Fn->addFnAttr(Prop->FuncAttr);
+  return DXILOpFn;
+}
+
+/// Rewrites every call to the intrinsic \p F into a call to the DXIL op
+/// function for \p DXILOp, passing the opcode as a new leading i32 argument.
+/// \p F is erased from \p M once nothing uses it any more.
+static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
+  FunctionCallee OpFn = createDXILOpFunction(DXILOp, F, M);
+  IRBuilder<> Builder(M.getContext());
+  // The same opcode constant is prepended to every rewritten call.
+  Value *OpCodeArg = Builder.getInt32(static_cast<unsigned>(DXILOp));
+  // Calls are erased while walking F's users, so advance the iterator first.
+  for (User *Usr : make_early_inc_range(F.users())) {
+    auto *OldCall = dyn_cast<CallInst>(Usr);
+    if (OldCall == nullptr)
+      continue;
+
+    SmallVector<Value *> NewArgs{OpCodeArg};
+    NewArgs.append(OldCall->arg_begin(), OldCall->arg_end());
+    Builder.SetInsertPoint(OldCall);
+    OldCall->replaceAllUsesWith(Builder.CreateCall(OpFn, NewArgs));
+    OldCall->eraseFromParent();
+  }
+  if (F.user_empty())
+    F.eraseFromParent();
+}
+
+/// Lowers every supported LLVM intrinsic declared in \p M into DXIL op
+/// function calls. Returns true if the module was changed.
+static bool lowerIntrinsics(Module &M) {
+  // Supported LLVM intrinsics and the DXIL op each one is lowered to.
+  static const SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
+      {Intrinsic::sin, DXIL::OpCode::Sin}};
+  bool Changed = false;
+  // lowerIntrinsic may erase the current function, hence the early-inc range.
+  for (Function &F : make_early_inc_range(M.functions())) {
+    // Only function declarations are candidates for lowering.
+    if (F.isDeclaration()) {
+      auto Entry = LowerMap.find(F.getIntrinsicID());
+      if (Entry != LowerMap.end()) {
+        lowerIntrinsic(Entry->second, F, M);
+        Changed = true;
+      }
+    }
+  }
+  return Changed;
+}
+
+namespace {
+/// New pass manager pass that lowers supported LLVM intrinsic calls into
+/// DXIL op function calls (see lowerIntrinsics).
+///
+/// NOTE(review): nothing in this file registers this class with the new pass
+/// manager; only DXILOpLoweringLegacy is initialized. Confirm whether a
+/// registration is intended.
+class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
+public:
+  /// Preserves all analyses when nothing was lowered, and none otherwise.
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    if (lowerIntrinsics(M))
+      return PreservedAnalyses::none();
+    return PreservedAnalyses::all();
+  }
+};
+} // namespace
+
+namespace {
+/// Legacy pass manager wrapper that runs lowerIntrinsics over a module.
+class DXILOpLoweringLegacy : public ModulePass {
+public:
+  static char ID; // Pass identification.
+
+  DXILOpLoweringLegacy() : ModulePass(ID) {}
+
+  StringRef getPassName() const override { return "DXIL Op Lowering"; }
+
+  /// Returns true if any intrinsic call in \p M was lowered.
+  bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
+};
+char DXILOpLoweringLegacy::ID = 0;
+
+} // end anonymous namespace
+
+// Registers the legacy pass under the DEBUG_TYPE command-line name. It has no
+// pass dependencies and is neither CFG-only nor an analysis.
+INITIALIZE_PASS(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
+                false)
+
+/// Returns a newly heap-allocated DXILOpLoweringLegacy pass.
+ModulePass *llvm::createDXILOpLoweringLegacyPass() {
+  return new DXILOpLoweringLegacy();
+}

diff  --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 73932aea24fbb..73cd5538ffd32 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -23,6 +23,13 @@ void initializeDXILPrepareModulePass(PassRegistry &);
 
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
+
+/// Initializer for the legacy DXILOpLowering pass.
+void initializeDXILOpLoweringLegacyPass(PassRegistry &);
+
+/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
+ModulePass *createDXILOpLoweringLegacyPass();
+
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H

diff  --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 98adfbf89bae5..a12d87f8adc8b 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -34,6 +34,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
   initializeDXILPrepareModulePass(*PR);
+  initializeDXILOpLoweringLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -84,6 +85,7 @@ bool DirectXTargetMachine::addPassesToEmitFile(
     PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType, bool DisableVerify,
     MachineModuleInfoWrapperPass *MMIWP) {
+  PM.add(createDXILOpLoweringLegacyPass());
   PM.add(createDXILPrepareModulePass());
   switch (FileType) {
   case CGFT_AssemblyFile:

diff  --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
new file mode 100644
index 0000000000000..bb31d28bfcfee
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -0,0 +1,43 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for sin are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; float overload: the call to llvm.sin.f32 below is expected to be rewritten
+; to dx.op.unary.f32, with the Sin opcode (13) as the new first argument.
+; Function Attrs: noinline nounwind optnone
+define noundef float @_Z3foof(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.sin.f32(float) #1
+
+; half overload: the call to llvm.sin.f16 below is expected to be rewritten
+; to dx.op.unary.f16 with the same opcode.
+; Function Attrs: noinline nounwind optnone
+define noundef half @_Z3barDh(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare half @llvm.sin.f16(half) #1
+
+attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}

diff  --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp
index bd4738ded693b..e43590db594de 100644
--- a/llvm/tools/opt/opt.cpp
+++ b/llvm/tools/opt/opt.cpp
@@ -476,7 +476,7 @@ static bool shouldPinPassToLegacyPM(StringRef Pass) {
       "x86-",    "xcore-", "wasm-",  "systemz-", "ppc-",    "nvvm-",
       "nvptx-",  "mips-",  "lanai-", "hexagon-", "bpf-",    "avr-",
       "thumb2-", "arm-",   "si-",    "gcn-",     "amdgpu-", "aarch64-",
-      "amdgcn-", "polly-", "riscv-"};
+      "amdgcn-", "polly-", "riscv-", "dxil-"};
   std::vector<StringRef> PassNameContain = {"ehprepare"};
   std::vector<StringRef> PassNameExact = {
       "safe-stack",           "cost-model",


        


More information about the llvm-commits mailing list