[llvm-branch-commits] [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops (PR #104252)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 14 14:29:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`.
There's some complexity here due to translating from a vector return type to a
named struct and trying to avoid excessive IR coming out of that.

Note that this change includes a bit of a hack in how it deals with
`getOverloadKind` for the `dx.ResRet` types: we need to adjust how we deal
with operation overloads to generate a table directly rather than proxy through
the `OverloadKind` enum, but that is left for a later change.


---
Full diff: https://github.com/llvm/llvm-project/pull/104252.diff


7 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+15-1) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+25-6) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+57) 
- (added) llvm/test/CodeGen/DirectX/BufferLoad.ll (+102) 
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+5-1) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ca3682fa47767..d817b610fa71a 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -30,6 +30,10 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+def int_dx_typedBufferLoad
+    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                            [llvm_any_ty, llvm_i32_ty]>;
+
 // Cast between target extension handle types and dxil-style opaque handles
 def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 31fee04d82158..b114148f84e84 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
 def HalfTy : DXILOpParamType;
 def FloatTy : DXILOpParamType;
 def DoubleTy : DXILOpParamType;
-def ResRetTy : DXILOpParamType;
+def ResRetHalfTy : DXILOpParamType;
+def ResRetFloatTy : DXILOpParamType;
+def ResRetInt16Ty : DXILOpParamType;
+def ResRetInt32Ty : DXILOpParamType;
 def HandleTy : DXILOpParamType;
 def ResBindTy : DXILOpParamType;
 def ResPropsTy : DXILOpParamType;
@@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def BufferLoad : DXILOp<68, bufferLoad> {
+  let Doc = "reads from a TypedBuffer";
+  // Handle, Coord0, Coord1
+  let arguments = [HandleTy, Int32Ty, Int32Ty];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0,
+                 [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def ThreadId :  DXILOp<93, threadId> {
   let Doc = "Reads the thread ID";
   let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 692af1b359ced..246e32c264dc9 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) {
   }
   case Type::PointerTyID:
     return OverloadKind::UserDefineType;
-  case Type::StructTyID:
+  case Type::StructTyID: {
+    // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
+    // how we're handling overloads and remove the `OverloadKind` proxy enum.
+    StructType *ST = cast<StructType>(Ty);
+    if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
+      return getOverloadKind(ST->getElementType(0));
+
     return OverloadKind::ObjectType;
+  }
   default:
     llvm_unreachable("invalid overload type");
     return OverloadKind::VOID;
@@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name,
   return StructType::create(Ctx, EltTys, Name);
 }
 
-static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
-  OverloadKind Kind = getOverloadKind(OverloadTy);
+static StructType *getResRetType(Type *ElementTy) {
+  LLVMContext &Ctx = ElementTy->getContext();
+  OverloadKind Kind = getOverloadKind(ElementTy);
   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
-  Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+  Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
                          Type::getInt32Ty(Ctx)};
   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
 }
@@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return Type::getInt64Ty(Ctx);
   case OpParamType::OverloadTy:
     return OverloadTy;
-  case OpParamType::ResRetTy:
-    return getResRetType(OverloadTy, Ctx);
+  case OpParamType::ResRetHalfTy:
+    return getResRetType(Type::getHalfTy(Ctx));
+  case OpParamType::ResRetFloatTy:
+    return getResRetType(Type::getFloatTy(Ctx));
+  case OpParamType::ResRetInt16Ty:
+    return getResRetType(Type::getInt16Ty(Ctx));
+  case OpParamType::ResRetInt32Ty:
+    return getResRetType(Type::getInt32Ty(Ctx));
   case OpParamType::HandleTy:
     return getHandleType(Ctx);
   case OpParamType::ResBindTy:
@@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
       return makeOpError(OpCode, "Wrong number of arguments");
     OverloadTy = Args[ArgIndex]->getType();
   }
+
   FunctionType *DXILOpFT =
       getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
 
@@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
   return *Result;
 }
 
+StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
+  return ::getResRetType(ElementTy);
+}
+
 StructType *DXILOpBuilder::getHandleType() {
   return ::getHandleType(IRB.getContext());
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 4a55a8ac9eadb..a68f0c43f67af 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -46,6 +46,8 @@ class DXILOpBuilder {
   Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
                                    Type *RetTy = nullptr);
 
+  /// Get a `%dx.types.ResRet` type with the given element type.
+  StructType *getResRetType(Type *ElementTy);
   /// Get the `%dx.types.Handle` type.
   StructType *getHandleType();
 
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index ab18c57efa307..46dfc905b5875 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -236,6 +236,59 @@ class OpLowerer {
       lowerToBindAndAnnotateHandle(F);
   }
 
+  void lowerTypedBufferLoad(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      Value *Handle =
+          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+      Value *Index0 = CI->getArgOperand(1);
+      Value *Index1 = UndefValue::get(Int32Ty);
+      Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+
+      std::array<Value *, 3> Args{Handle, Index0, Index1};
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
+      if (Error E = OpCall.takeError())
+        return E;
+
+      std::array<Value *, 4> Extracts = {};
+
+      // We've switched the return type from a vector to a struct, but at this
+      // point most vectors have probably already been scalarized. Try to
+      // forward arguments directly rather than inserting into and immediately
+      // extracting from a vector.
+      for (Use &U : make_early_inc_range(CI->uses()))
+        if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
+          if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+            size_t IndexVal = Index->getZExtValue();
+            assert(IndexVal < 4 && "Index into buffer load out of range");
+            if (!Extracts[IndexVal])
+              Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
+            EEI->replaceAllUsesWith(Extracts[IndexVal]);
+            EEI->eraseFromParent();
+          }
+
+      // If there are still uses then we need to create a vector.
+      if (!CI->use_empty()) {
+        for (int I = 0, E = 4; I != E; ++I)
+          if (!Extracts[I])
+            Extracts[I] = IRB.CreateExtractValue(*OpCall, I);
+
+        Value *Vec = UndefValue::get(CI->getType());
+        for (int I = 0, E = 4; I != E; ++I)
+          Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+        CI->replaceAllUsesWith(Vec);
+      }
+
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
 
@@ -253,6 +306,10 @@ class OpLowerer {
 #include "DXILOperation.inc"
       case Intrinsic::dx_handle_fromBinding:
         lowerHandleFromBinding(F);
+        break;
+      case Intrinsic::dx_typedBufferLoad:
+        lowerTypedBufferLoad(F);
+        break;
       }
       Updated = true;
     }
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
new file mode 100644
index 0000000000000..c3bb96dbdf909
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -0,0 +1,102 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+
+define void @loadfloats() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+  ; The extract order depends on the users, so don't enforce that here.
+  ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+  %data0_0 = extractelement <4 x float> %data0, i32 0
+  ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+  %data0_2 = extractelement <4 x float> %data0, i32 2
+
+  ; If all of the uses are extracts, we skip creating a vector
+  ; CHECK-NOT: insertelement
+  call void @scalar_user(float %data0_0)
+  call void @scalar_user(float %data0_2)
+
+  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+  %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
+
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+  ; CHECK: insertelement <4 x float> undef
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  call void @vector_user(<4 x float> %data4)
+
+  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+  %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
+
+  ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+  %data12_3 = extractelement <4 x float> %data12, i32 3
+
+  ; If there are a mix of users we need the vector, but extracts are direct
+  ; CHECK: call void @scalar_user(float [[DATA12_3]])
+  call void @scalar_user(float %data12_3)
+  call void @vector_user(<4 x float> %data12)
+
+  ret void
+}
+
+define void @loadint() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
+
+define void @loadhalf() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
+
+define void @loadi16() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+             target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 9cc1b5ccb8acb..332706f7e3e57 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
       .Case("Int8Ty", "OverloadKind::I8")
       .Case("Int16Ty", "OverloadKind::I16")
       .Case("Int32Ty", "OverloadKind::I32")
-      .Case("Int64Ty", "OverloadKind::I64");
+      .Case("Int64Ty", "OverloadKind::I64")
+      .Case("ResRetHalfTy", "OverloadKind::HALF")
+      .Case("ResRetFloatTy", "OverloadKind::FLOAT")
+      .Case("ResRetInt16Ty", "OverloadKind::I16")
+      .Case("ResRetInt32Ty", "OverloadKind::I32");
 }
 
 /// Return a string representation of valid overload information denoted

``````````

</details>


https://github.com/llvm/llvm-project/pull/104252


More information about the llvm-branch-commits mailing list