[llvm-branch-commits] [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops (PR #104252)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 14 14:29:53 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: Justin Bogner (bogner)
<details>
<summary>Changes</summary>
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`.
There's some complexity here because the vector return type has to be translated
to a named struct, and we try to avoid emitting excessive IR in the process.
Note that this change includes a bit of a hack in how `getOverloadKind` handles
the `dx.ResRet` types: we need to rework how operation overloads are handled so
that a table is generated directly rather than proxied through the `OverloadKind`
enum, but that is left for a later change.
---
Full diff: https://github.com/llvm/llvm-project/pull/104252.diff
7 Files Affected:
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4)
- (modified) llvm/lib/Target/DirectX/DXIL.td (+15-1)
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+25-6)
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2)
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+57)
- (added) llvm/test/CodeGen/DirectX/BufferLoad.ll (+102)
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+5-1)
``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ca3682fa47767..d817b610fa71a 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -30,6 +30,10 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;
+def int_dx_typedBufferLoad // Read-only load: (buffer handle, i32 index) -> vector.
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [llvm_any_ty, llvm_i32_ty], [IntrReadMem]>;
+
// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 31fee04d82158..b114148f84e84 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
def HalfTy : DXILOpParamType;
def FloatTy : DXILOpParamType;
def DoubleTy : DXILOpParamType;
-def ResRetTy : DXILOpParamType;
+def ResRetHalfTy : DXILOpParamType; // %dx.types.ResRet.f16 (four half lanes)
+def ResRetFloatTy : DXILOpParamType; // %dx.types.ResRet.f32 (four float lanes)
+def ResRetInt16Ty : DXILOpParamType; // %dx.types.ResRet.i16 (four i16 lanes)
+def ResRetInt32Ty : DXILOpParamType; // %dx.types.ResRet.i32 (four i32 lanes)
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
@@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
+def BufferLoad : DXILOp<68, bufferLoad> {
+ let Doc = "reads from a TypedBuffer";
+ // Handle, Coord0, Coord1 (the typedBufferLoad lowering passes undef as Coord1)
+ let arguments = [HandleTy, Int32Ty, Int32Ty];
+ let result = OverloadTy; // one of the %dx.types.ResRet.* structs listed below
+ let overloads =
+ [Overloads<DXIL1_0,
+ [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 692af1b359ced..246e32c264dc9 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) {
}
case Type::PointerTyID:
return OverloadKind::UserDefineType;
- case Type::StructTyID:
+ case Type::StructTyID: {
+ // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
+ // how we're handling overloads and remove the `OverloadKind` proxy enum.
+ StructType *ST = cast<StructType>(Ty);
+ if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
+ return getOverloadKind(ST->getElementType(0));
+
return OverloadKind::ObjectType;
+ }
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
@@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name,
return StructType::create(Ctx, EltTys, Name);
}
-static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
- OverloadKind Kind = getOverloadKind(OverloadTy);
+static StructType *getResRetType(Type *ElementTy) { // {ElementTy x 4, i32}
+ LLVMContext &Ctx = ElementTy->getContext();
+ OverloadKind Kind = getOverloadKind(ElementTy);
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
- Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+ Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
Type::getInt32Ty(Ctx)};
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
}
@@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return Type::getInt64Ty(Ctx);
case OpParamType::OverloadTy:
return OverloadTy;
- case OpParamType::ResRetTy:
- return getResRetType(OverloadTy, Ctx);
+ case OpParamType::ResRetHalfTy:
+ return getResRetType(Type::getHalfTy(Ctx));
+ case OpParamType::ResRetFloatTy:
+ return getResRetType(Type::getFloatTy(Ctx));
+ case OpParamType::ResRetInt16Ty:
+ return getResRetType(Type::getInt16Ty(Ctx));
+ case OpParamType::ResRetInt32Ty:
+ return getResRetType(Type::getInt32Ty(Ctx));
case OpParamType::HandleTy:
return getHandleType(Ctx);
case OpParamType::ResBindTy:
@@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
return makeOpError(OpCode, "Wrong number of arguments");
OverloadTy = Args[ArgIndex]->getType();
}
+
FunctionType *DXILOpFT =
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
@@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
return *Result;
}
+StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
+ return ::getResRetType(ElementTy); // Forward to the file-static helper.
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 4a55a8ac9eadb..a68f0c43f67af 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -46,6 +46,8 @@ class DXILOpBuilder {
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
Type *RetTy = nullptr);
+ /// Get (or create) the `%dx.types.ResRet` struct: four \p ElementTy fields and an i32.
+ StructType *getResRetType(Type *ElementTy);
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index ab18c57efa307..46dfc905b5875 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -236,6 +236,59 @@ class OpLowerer {
lowerToBindAndAnnotateHandle(F);
}
+ /// Lower each @llvm.dx.typedBufferLoad call to @dx.op.bufferLoad and
+ /// rewrite users of the vector result in terms of the ResRet struct lanes.
+ void lowerTypedBufferLoad(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+ replaceFunction(F, [&](CallInst *CI) -> Error {
+ // A ResRet struct has 4 lanes, but the result may be narrower.
+ auto *VecTy = dyn_cast<FixedVectorType>(CI->getType());
+ if (!VecTy || VecTy->getNumElements() > 4)
+ return make_error<StringError>("Buffer loads return at most 4 elements",
+ inconvertibleErrorCode());
+ unsigned NumElts = VecTy->getNumElements();
+ IRB.SetInsertPoint(CI);
+ Value *Handle =
+ createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+ Value *Index0 = CI->getArgOperand(1);
+ Value *Index1 = UndefValue::get(Int32Ty);
+ Type *RetTy = OpBuilder.getResRetType(VecTy->getElementType());
+ std::array<Value *, 3> Args{Handle, Index0, Index1};
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ // The result is now a struct, not a vector. Forward constant in-range
+ // extracts (usually every use, post-scalarization) to extractvalue;
+ // out-of-range ones yield poison and are left for the vector path.
+ std::array<Value *, 4> Extracts = {};
+ for (Use &U : make_early_inc_range(CI->uses()))
+ if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
+ if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand()))
+ if (Index->getValue().ult(NumElts)) {
+ unsigned IndexVal = Index->getZExtValue();
+ if (!Extracts[IndexVal])
+ Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
+ EEI->replaceAllUsesWith(Extracts[IndexVal]);
+ EEI->eraseFromParent();
+ }
+
+ // If there are still uses then we need to create a vector.
+ if (!CI->use_empty()) {
+ for (unsigned I = 0; I != NumElts; ++I)
+ if (!Extracts[I])
+ Extracts[I] = IRB.CreateExtractValue(*OpCall, I);
+ Value *Vec = UndefValue::get(VecTy);
+ for (unsigned I = 0; I != NumElts; ++I)
+ Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+ CI->replaceAllUsesWith(Vec);
+ }
+ CI->eraseFromParent();
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
@@ -253,6 +306,10 @@ class OpLowerer {
#include "DXILOperation.inc"
case Intrinsic::dx_handle_fromBinding:
lowerHandleFromBinding(F);
+ break;
+ case Intrinsic::dx_typedBufferLoad:
+ lowerTypedBufferLoad(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
new file mode 100644
index 0000000000000..c3bb96dbdf909
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -0,0 +1,102 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+
+define void @loadfloats() { ; f32 overload: extract forwarding and vector rebuild
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; The temporary casts should all have been cleaned up
+ ; CHECK-NOT: %dx.cast_handle
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+ ; The extract order depends on the users, so don't enforce that here.
+ ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+ %data0_0 = extractelement <4 x float> %data0, i32 0
+ ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+ %data0_2 = extractelement <4 x float> %data0, i32 2
+
+ ; If all of the uses are extracts, we skip creating a vector
+ ; CHECK-NOT: insertelement
+ call void @scalar_user(float %data0_0)
+ call void @scalar_user(float %data0_2)
+
+ ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+ %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
+
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+ ; CHECK: insertelement <4 x float> undef
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ call void @vector_user(<4 x float> %data4)
+
+ ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+ %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
+
+ ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+ %data12_3 = extractelement <4 x float> %data12, i32 3
+
+ ; If there is a mix of users we need the vector, but extracts are direct
+ ; CHECK: call void @scalar_user(float [[DATA12_3]])
+ call void @scalar_user(float %data12_3)
+ call void @vector_user(<4 x float> %data12)
+
+ ret void
+}
+
+define void @loadint() { ; i32 overload (dx.types.ResRet.i32)
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadhalf() { ; f16 overload (dx.types.ResRet.f16)
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadi16() { ; i16 overload (dx.types.ResRet.i16)
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 9cc1b5ccb8acb..332706f7e3e57 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
.Case("Int8Ty", "OverloadKind::I8")
.Case("Int16Ty", "OverloadKind::I16")
.Case("Int32Ty", "OverloadKind::I32")
- .Case("Int64Ty", "OverloadKind::I64");
+ .Case("Int64Ty", "OverloadKind::I64")
+ .Case("ResRetHalfTy", "OverloadKind::HALF")
+ .Case("ResRetFloatTy", "OverloadKind::FLOAT")
+ .Case("ResRetInt16Ty", "OverloadKind::I16")
+ .Case("ResRetInt32Ty", "OverloadKind::I32");
}
/// Return a string representation of valid overload information denoted
``````````
</details>
https://github.com/llvm/llvm-project/pull/104252
More information about the llvm-branch-commits mailing list