[llvm] e56ad22 - [DirectX] Encapsulate DXILOpLowering's state into a class. NFC

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 10:51:35 PDT 2024


Author: Justin Bogner
Date: 2024-08-20T10:51:32-07:00
New Revision: e56ad22b4a41e65984d6997b2c2496a20f906d1d

URL: https://github.com/llvm/llvm-project/commit/e56ad22b4a41e65984d6997b2c2496a20f906d1d
DIFF: https://github.com/llvm/llvm-project/commit/e56ad22b4a41e65984d6997b2c2496a20f906d1d.diff

LOG: [DirectX] Encapsulate DXILOpLowering's state into a class. NFC

This introduces a class "OpLowerer", in an anonymous namespace, to help with
lowering DXIL ops, and moves the DXILOpBuilder there instead of creating a new
one for every intrinsic being lowered. DXILOpBuilder is also changed to own its
IRBuilder, since that makes it simpler to ensure that it isn't misused.

Pull Request: https://github.com/llvm/llvm-project/pull/104248

Added: 
    

Modified: 
    llvm/lib/Target/DirectX/DXILOpBuilder.cpp
    llvm/lib/Target/DirectX/DXILOpBuilder.h
    llvm/lib/Target/DirectX/DXILOpLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index ca372026141fbf..8e26483d675c89 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -11,7 +11,6 @@
 
 #include "DXILOpBuilder.h"
 #include "DXILConstants.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -335,7 +334,7 @@ namespace dxil {
 // Triple is well-formed or that the target is supported since these checks
 // would have been done at the time the module M is constructed in the earlier
 // stages of compilation.
-DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {
-  Triple TT(Triple(M.getTargetTriple()));
+DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
+  Triple TT(M.getTargetTriple());
   DXILVersion = TT.getDXILVersion();
   ShaderStage = TT.getEnvironment();
@@ -417,10 +416,11 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
 
   // We need to inject the opcode as the first argument.
   SmallVector<Value *> OpArgs;
-  OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode)));
+  OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
   OpArgs.append(Args.begin(), Args.end());
 
-  return B.CreateCall(DXILFn, OpArgs);
+  // The call is inserted at IRB's current insertion point (see getIRB()).
+  return IRB.CreateCall(DXILFn, OpArgs);
 }
 
 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 5d83357f7a2e94..483d5ddc8b6197 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -14,8 +14,9 @@
 
 #include "DXILConstants.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/TargetParser/Triple.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/Error.h"
+#include "llvm/TargetParser/Triple.h"
 
 namespace llvm {
 class Module;
@@ -29,7 +30,9 @@ namespace dxil {
 
 class DXILOpBuilder {
 public:
-  DXILOpBuilder(Module &M, IRBuilderBase &B);
+  explicit DXILOpBuilder(Module &M);
+
+  IRBuilder<> &getIRB() { return IRB; }
 
   /// Create a call instruction for the given DXIL op. The arguments
   /// must be valid for an overload of the operation.
@@ -51,7 +54,7 @@ class DXILOpBuilder {
                                   Type *OverloadType = nullptr);
 
   Module &M;
-  IRBuilderBase &B;
+  IRBuilder<> IRB;
   VersionTuple DXILVersion;
   Triple::EnvironmentType ShaderStage;
 };

diff  --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 5f84cdcfda6dea..e458720fcd6e9f 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -73,67 +73,93 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
-static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
-  IRBuilder<> B(M.getContext());
-  DXILOpBuilder OpBuilder(M, B);
-  for (User *U : make_early_inc_range(F.users())) {
-    CallInst *CI = dyn_cast<CallInst>(U);
-    if (!CI)
-      continue;
-
-    SmallVector<Value *> Args;
-    B.SetInsertPoint(CI);
-    if (isVectorArgExpansion(F)) {
-      SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
-      Args.append(NewArgs.begin(), NewArgs.end());
-    } else
-      Args.append(CI->arg_begin(), CI->arg_end());
-
-    Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args,
-                                                             F.getReturnType());
-    if (Error E = OpCallOrErr.takeError()) {
-      std::string Message(toString(std::move(E)));
-      DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
-                                     CI->getDebugLoc());
-      M.getContext().diagnose(Diag);
-      continue;
+namespace {
+/// Lowers calls to intrinsics that have a DXIL op mapping into DXIL op calls,
+/// sharing one DXILOpBuilder across every replacement in the module.
+class OpLowerer {
+  Module &M;
+  DXILOpBuilder OpBuilder;
+
+public:
+  explicit OpLowerer(Module &M) : M(M), OpBuilder(M) {}
+
+  /// Invokes \p ReplaceCall on every call to \p F. Any Error it returns is
+  /// reported as a DiagnosticInfoUnsupported on the module's context.
+  /// Afterwards, \p F is erased if it has no users left.
+  void replaceFunction(Function &F,
+                       llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
+    for (User *U : make_early_inc_range(F.users())) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI)
+        continue;
+
+      if (Error E = ReplaceCall(CI)) {
+        std::string Message(toString(std::move(E)));
+        DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
+                                       CI->getDebugLoc());
+        M.getContext().diagnose(Diag);
+      }
     }
-    CallInst *OpCall = *OpCallOrErr;
+    if (F.user_empty())
+      F.eraseFromParent();
+  }
 
-    CI->replaceAllUsesWith(OpCall);
-    CI->eraseFromParent();
+  /// Replaces each call to \p F with a call to DXIL op \p DXILOp. Arguments
+  /// go through argVectorFlatten first when isVectorArgExpansion(F) is true.
+  void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
+    bool IsVectorArgExpansion = isVectorArgExpansion(F);
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      SmallVector<Value *> Args;
+      OpBuilder.getIRB().SetInsertPoint(CI);
+      if (IsVectorArgExpansion) {
+        SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+        Args.append(NewArgs.begin(), NewArgs.end());
+      } else
+        Args.append(CI->arg_begin(), CI->arg_end());
+
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType());
+      if (Error E = OpCall.takeError())
+        return E;
+
+      CI->replaceAllUsesWith(*OpCall);
+      CI->eraseFromParent();
+      return Error::success();
+    });
   }
-  if (F.user_empty())
-    F.eraseFromParent();
-}
 
-static bool lowerIntrinsics(Module &M) {
-  bool Updated = false;
+  /// Lowers every intrinsic declaration in the module that has a
+  /// DXIL_OP_INTRINSIC mapping. Returns true if any such declaration was
+  /// found, including when some of its calls only produced diagnostics.
+  bool lowerIntrinsics() {
+    bool Updated = false;
 
-  for (Function &F : make_early_inc_range(M.functions())) {
-    if (!F.isDeclaration())
-      continue;
-    Intrinsic::ID ID = F.getIntrinsicID();
-    switch (ID) {
-    default:
-      continue;
+    for (Function &F : make_early_inc_range(M.functions())) {
+      if (!F.isDeclaration())
+        continue;
+      Intrinsic::ID ID = F.getIntrinsicID();
+      switch (ID) {
+      default:
+        continue;
 #define DXIL_OP_INTRINSIC(OpCode, Intrin)                                      \
   case Intrin:                                                                 \
-    lowerIntrinsic(OpCode, F, M);                                              \
+    replaceFunctionWithOp(F, OpCode);                                          \
     break;
 #include "DXILOperation.inc"
+      }
+      Updated = true;
     }
-    Updated = true;
+    return Updated;
   }
-  return Updated;
-}
+};
+} // namespace
 
 namespace {
-/// A pass that transforms external global definitions into declarations.
+/// A pass that lowers LLVM intrinsic calls to DXIL op calls.
 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
-    if (lowerIntrinsics(M))
+    if (OpLowerer(M).lowerIntrinsics())
       return PreservedAnalyses::none();
     return PreservedAnalyses::all();
   }
@@ -143,7 +169,9 @@ class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
 namespace {
 class DXILOpLoweringLegacy : public ModulePass {
 public:
-  bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
+  bool runOnModule(Module &M) override {
+    return OpLowerer(M).lowerIntrinsics();
+  }
   StringRef getPassName() const override { return "DXIL Op Lowering"; }
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 


        


More information about the llvm-commits mailing list