[llvm] 3f22756 - [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 9 13:21:26 PDT 2024
Author: Justin Bogner
Date: 2024-09-09T13:21:22-07:00
New Revision: 3f22756f391e20040fa3581206b77c409433bd9f
URL: https://github.com/llvm/llvm-project/commit/3f22756f391e20040fa3581206b77c409433bd9f
DIFF: https://github.com/llvm/llvm-project/commit/3f22756f391e20040fa3581206b77c409433bd9f.diff
LOG: [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`.
There's some complexity here in translating to scalarized IR, which I've
abstracted out into a function that should be useful for samples, gathers, and
CBuffer loads.
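As a rough before/after sketch of that translation (condensed from the new
BufferLoad.ll test in this patch; the value and handle names are illustrative):

  ; Before lowering: the intrinsic returns the buffer's contained type and
  ; users extract scalar elements from it.
  %data = call <4 x float> @llvm.dx.typedBufferLoad(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
  %x = extractelement <4 x float> %data, i32 0

  ; After lowering: the DXIL op returns a %dx.types.ResRet.f32 struct and each
  ; constant-index extractelement becomes an extractvalue on that struct.
  %ret = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(
      i32 68, %dx.types.Handle %handle, i32 %index, i32 undef)
  %x = extractvalue %dx.types.ResRet.f32 %ret, 0
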
I've also updated the DXILResources.rst docs to match what I'm doing here and
the proposal in llvm/wg-hlsl#59. I've removed the content about stores and raw
buffers for now, with the expectation that it will be added back along with the
work to implement them.
Note that this change includes a bit of a hack in how it deals with
`getOverloadKind` for the `dx.ResRet` types - we need to adjust how we handle
operation overloads to generate a table directly rather than proxying through
the OverloadKind enum, but that's left for a later change.
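For reference, the four `dx.ResRet` overloads added here look like this in the
lowered IR (a sketch of what getResRetType builds; the trailing i32 is the
status slot consumed by CheckAccessFullyMapped), and the hack simply derives
the overload suffix from the struct's first element type:

  %dx.types.ResRet.f16 = type { half, half, half, half, i32 }
  %dx.types.ResRet.f32 = type { float, float, float, float, i32 }
  %dx.types.ResRet.i16 = type { i16, i16, i16, i16, i32 }
  %dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 }
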
Part of #91367
Pull Request: https://github.com/llvm/llvm-project/pull/104252
Added:
llvm/test/CodeGen/DirectX/BufferLoad.ll
Modified:
llvm/docs/DirectX/DXILResources.rst
llvm/include/llvm/IR/IntrinsicsDirectX.td
llvm/lib/Target/DirectX/DXIL.td
llvm/lib/Target/DirectX/DXILOpBuilder.cpp
llvm/lib/Target/DirectX/DXILOpBuilder.h
llvm/lib/Target/DirectX/DXILOpLowering.cpp
llvm/utils/TableGen/DXILEmitter.cpp
Removed:
################################################################################
diff --git a/llvm/docs/DirectX/DXILResources.rst b/llvm/docs/DirectX/DXILResources.rst
index a6ec80ce4329b2..8e43bfaaaf32ea 100644
--- a/llvm/docs/DirectX/DXILResources.rst
+++ b/llvm/docs/DirectX/DXILResources.rst
@@ -267,45 +267,38 @@ Examples:
@llvm.dx.handle.fromHeap.tdx.RawBuffer_v4f32_1_0(
i32 2, i1 false)
-Buffer Loads and Stores
------------------------
-
-*relevant types: Buffers*
-
-We need to treat buffer loads and stores from "dx.TypedBuffer" and
-"dx.RawBuffer" separately. For TypedBuffer, we have ``llvm.dx.typedBufferLoad``
-and ``llvm.dx.typedBufferStore``, which load and store 16-byte "rows" of data
-via a simple index. For RawBuffer, we have ``llvm.dx.rawBufferPtr``, which
-return a pointer that can be indexed, loaded, and stored to as needed.
-
-The typed load and store operations always operate on exactly 16 bytes of data,
-so there are only a few valid overloads. For types that are 32-bits or smaller,
-we operate on 4-element vectors, such as ``<4 x i32>``, ``<4 x float>``, or
-``<4 x half>``. Note that in 16-bit cases each 16-bit value occupies 32-bits of
-storage. For 64-bit types we operate on 2-element vectors - ``<2 x double>`` or
-``<2 x i64>``. When a type like `Buffer<float>` is used at the HLSL level, it
-is expected that this will operate on a single float in each 16 byte row - that
-is, a load would use the ``<4 x float>`` variant and then extract the first
-element.
-
-.. note:: In DXC, trying to operate on a ``Buffer<double4>`` crashes the
- compiler. We should probably just reject this in the frontend.
-
-The TypedBuffer intrinsics are lowered to the `bufferLoad`_ and `bufferStore`_
-operations, and the operations on the memory accessed by RawBufferPtr are
-lowered to `rawBufferLoad`_ and `rawBufferStore`_. Note that if we want to
-support DXIL versions prior to 1.2 we'll need to lower the RawBuffer loads and
-stores to the non-raw operations as well.
-
-.. note:: TODO: We need to account for `CheckAccessFullyMapped`_ here.
-
- In DXIL the load operations always return an ``i32`` status value, but this
- isn't very ergonomic when it isn't used. We can (1) bite the bullet and have
- the loads return `{%ret_type, %i32}` all the time, (2) create a variant or
- update the signature iff the status is used, or (3) hide this in a sideband
- channel somewhere. I'm leaning towards (2), but could probably be convinced
- that the ugliness of (1) is worth the simplicity.
-
+16-byte Loads, Samples, and Gathers
+-----------------------------------
+
+*relevant types: TypedBuffer, CBuffer, and Textures*
+
+TypedBuffer, CBuffer, and Texture loads, as well as samples and gathers, can
+return 1 to 4 elements from the given resource, to a maximum of 16 bytes of
+data. DXIL's modeling of this is influenced by DirectX and DXBC's history and
+it generally treats these operations as returning 4 32-bit values. For 16-bit
+elements the values are 16-bit values, and for 64-bit elements the operations
+return 4 32-bit integers and emit further code to construct the doubles.
+
+In DXIL, these operations return `ResRet`_ and `CBufRet`_ values, which are
+structs containing 4 elements of the same type and, in the case of `ResRet`, a
+5th element that is used by the `CheckAccessFullyMapped`_ operation.
+
+In LLVM IR the intrinsics will return the contained type of the resource
+instead. That is, ``llvm.dx.typedBufferLoad`` from a ``Buffer<float>`` would
+return a single float, from ``Buffer<float4>`` a vector of 4 floats, and from
+``Buffer<double2>`` a vector of two doubles, etc. The operations are then
+expanded out to match DXIL's format during lowering.
+
+In cases where we need ``CheckAccessFullyMapped``, we have a second intrinsic
+that returns an anonymous struct with element-0 being the contained type, and
+element-1 being the ``i1`` result of a ``CheckAccessFullyMapped`` call. We
+don't have a separate call to ``CheckAccessFullyMapped`` at all, since that's
+the only operation that can possibly be done on this value. In practice this
+may mean we insert a DXIL operation for the check when this was missing in the
+HLSL source, but this actually matches DXC's behaviour.
+
+.. _ResRet: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#resource-operation-return-types
+.. _CBufRet: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#cbufferloadlegacy
.. _CheckAccessFullyMapped: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/checkaccessfullymapped
.. list-table:: ``@llvm.dx.typedBufferLoad``
@@ -317,7 +310,7 @@ stores to the non-raw operations as well.
- Description
* - Return value
-
- - A 4- or 2-element vector of the type of the buffer
+ - The contained type of the buffer
- The data loaded from the buffer
* - ``%buffer``
- 0
@@ -332,16 +325,23 @@ Examples:
.. code-block:: llvm
- %ret = call <4 x float> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f32_0_0t(
- target("dx.TypedBuffer", f32, 0, 0) %buffer, i32 %index)
- %ret = call <4 x i32> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_i32_0_0t(
- target("dx.TypedBuffer", i32, 0, 0) %buffer, i32 %index)
- %ret = call <4 x half> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f16_0_0t(
- target("dx.TypedBuffer", f16, 0, 0) %buffer, i32 %index)
- %ret = call <2 x double> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f64_0_0t(
- target("dx.TypedBuffer", double, 0, 0) %buffer, i32 %index)
-
-.. list-table:: ``@llvm.dx.typedBufferStore``
+ %ret = call <4 x float>
+ @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
+ %ret = call float
+ @llvm.dx.typedBufferLoad.f32.tdx.TypedBuffer_f32_0_0_0t(
+ target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 %index)
+ %ret = call <4 x i32>
+ @llvm.dx.typedBufferLoad.v4i32.tdx.TypedBuffer_v4i32_0_0_0t(
+ target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 %index)
+ %ret = call <4 x half>
+ @llvm.dx.typedBufferLoad.v4f16.tdx.TypedBuffer_v4f16_0_0_0t(
+ target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 %index)
+ %ret = call <2 x double>
+ @llvm.dx.typedBufferLoad.v2f64.tdx.TypedBuffer_v2f64_0_0_0t(
+ target("dx.TypedBuffer", <2 x double>, 0, 0, 0) %buffer, i32 %index)
+
+.. list-table:: ``@llvm.dx.typedBufferLoad.checkbit``
:header-rows: 1
* - Argument
@@ -350,46 +350,11 @@ Examples:
- Description
* - Return value
-
- - ``void``
- -
+ - A structure of the contained type and the check bit
+ - The data loaded from the buffer and the check bit
* - ``%buffer``
- 0
- ``target(dx.TypedBuffer, ...)``
- - The buffer to store into
- * - ``%index``
- - 1
- - ``i32``
- - Index into the buffer
- * - ``%data``
- - 2
- - A 4- or 2-element vector of the type of the buffer
- - The data to store
-
-Examples:
-
-.. code-block:: llvm
-
- call void @llvm.dx.bufferStore.tdx.Buffer_f32_1_0t(
- target("dx.TypedBuffer", f32, 1, 0) %buf, i32 %index, <4 x f32> %data)
- call void @llvm.dx.bufferStore.tdx.Buffer_f16_1_0t(
- target("dx.TypedBuffer", f16, 1, 0) %buf, i32 %index, <4 x f16> %data)
- call void @llvm.dx.bufferStore.tdx.Buffer_f64_1_0t(
- target("dx.TypedBuffer", f64, 1, 0) %buf, i32 %index, <2 x f64> %data)
-
-.. list-table:: ``@llvm.dx.rawBufferPtr``
- :header-rows: 1
-
- * - Argument
- -
- - Type
- - Description
- * - Return value
- -
- - ``ptr``
- - Pointer to an element of the buffer
- * - ``%buffer``
- - 0
- - ``target(dx.RawBuffer, ...)``
- The buffer to load from
* - ``%index``
- 1
@@ -400,37 +365,7 @@ Examples:
.. code-block:: llvm
- ; Load a float4 from a buffer
- %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_v4f32_0_0t(
- target("dx.RawBuffer", <4 x f32>, 0, 0) %buffer, i32 %index)
- %val = load <4 x float>, ptr %buf, align 16
-
- ; Load the double from a struct containing an int, a float, and a double
- %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_sl_i32f32f64s_0_0t(
- target("dx.RawBuffer", {i32, f32, f64}, 0, 0) %buffer, i32 %index)
- %val = getelementptr inbounds {i32, f32, f64}, ptr %buf, i32 0, i32 2
- %d = load double, ptr %val, align 8
-
- ; Load a float from a byte address buffer
- %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_i8_0_0t(
- target("dx.RawBuffer", i8, 0, 0) %buffer, i32 %index)
- %val = getelementptr inbounds float, ptr %buf, i64 0
- %f = load float, ptr %val, align 4
-
- ; Store to a buffer containing float4
- %addr = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_v4f32_0_0t(
- target("dx.RawBuffer", <4 x f32>, 0, 0) %buffer, i32 %index)
- store <4 x float> %val, ptr %addr
-
- ; Store the double in a struct containing an int, a float, and a double
- %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_sl_i32f32f64s_0_0t(
- target("dx.RawBuffer", {i32, f32, f64}, 0, 0) %buffer, i32 %index)
- %addr = getelementptr inbounds {i32, f32, f64}, ptr %buf, i32 0, i32 2
- store double %d, ptr %addr
-
- ; Store a float into a byte address buffer
- %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_i8_0_0t(
- target("dx.RawBuffer", i8, 0, 0) %buffer, i32 %index)
- %addr = getelementptr inbounds float, ptr %buf, i64 0
- store float %f, ptr %val
+ %ret = call {<4 x float>, i1}
+ @llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f089d51fa1b459..40c9ac3f0da346 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -30,6 +30,9 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;
+def int_dx_typedBufferLoad
+ : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
+
// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 4e3ecf4300d825..67a9b9d02bb6a1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
def HalfTy : DXILOpParamType;
def FloatTy : DXILOpParamType;
def DoubleTy : DXILOpParamType;
-def ResRetTy : DXILOpParamType;
+def ResRetHalfTy : DXILOpParamType;
+def ResRetFloatTy : DXILOpParamType;
+def ResRetInt16Ty : DXILOpParamType;
+def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
@@ -693,6 +696,17 @@ def CreateHandle : DXILOp<57, createHandle> {
let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
}
+def BufferLoad : DXILOp<68, bufferLoad> {
+ let Doc = "reads from a TypedBuffer";
+ // Handle, Coord0, Coord1
+ let arguments = [HandleTy, Int32Ty, Int32Ty];
+ let result = OverloadTy;
+ let overloads =
+ [Overloads<DXIL1_0,
+ [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index efe019a07acaa9..3b2a5f5061eb83 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -120,8 +120,12 @@ static OverloadKind getOverloadKind(Type *Ty) {
}
case Type::PointerTyID:
return OverloadKind::UserDefineType;
- case Type::StructTyID:
- return OverloadKind::ObjectType;
+ case Type::StructTyID: {
+ // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
+ // how we're handling overloads and remove the `OverloadKind` proxy enum.
+ StructType *ST = cast<StructType>(Ty);
+ return getOverloadKind(ST->getElementType(0));
+ }
default:
return OverloadKind::UNDEFINED;
}
@@ -194,10 +198,11 @@ static StructType *getOrCreateStructType(StringRef Name,
return StructType::create(Ctx, EltTys, Name);
}
-static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
- OverloadKind Kind = getOverloadKind(OverloadTy);
+static StructType *getResRetType(Type *ElementTy) {
+ LLVMContext &Ctx = ElementTy->getContext();
+ OverloadKind Kind = getOverloadKind(ElementTy);
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
- Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+ Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
Type::getInt32Ty(Ctx)};
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
}
@@ -247,8 +252,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return Type::getInt64Ty(Ctx);
case OpParamType::OverloadTy:
return OverloadTy;
- case OpParamType::ResRetTy:
- return getResRetType(OverloadTy, Ctx);
+ case OpParamType::ResRetHalfTy:
+ return getResRetType(Type::getHalfTy(Ctx));
+ case OpParamType::ResRetFloatTy:
+ return getResRetType(Type::getFloatTy(Ctx));
+ case OpParamType::ResRetInt16Ty:
+ return getResRetType(Type::getInt16Ty(Ctx));
+ case OpParamType::ResRetInt32Ty:
+ return getResRetType(Type::getInt32Ty(Ctx));
case OpParamType::HandleTy:
return getHandleType(Ctx);
case OpParamType::ResBindTy:
@@ -390,6 +401,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
return makeOpError(OpCode, "Wrong number of arguments");
OverloadTy = Args[ArgIndex]->getType();
}
+
FunctionType *DXILOpFT =
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
@@ -450,6 +462,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
return *Result;
}
+StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
+ return ::getResRetType(ElementTy);
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 4a55a8ac9eadb5..a68f0c43f67afb 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -46,6 +46,8 @@ class DXILOpBuilder {
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
Type *RetTy = nullptr);
+ /// Get a `%dx.types.ResRet` type with the given element type.
+ StructType *getResRetType(Type *ElementTy);
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 1f6d37087bc9f4..df2751d99576a8 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -259,6 +259,115 @@ class OpLowerer {
lowerToBindAndAnnotateHandle(F);
}
+ /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
+ /// Since we expect to be post-scalarization, make an effort to avoid vectors.
+ Error replaceResRetUses(CallInst *Intrin, CallInst *Op) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+
+ Type *OldTy = Intrin->getType();
+
+ // For scalars, we just extract the first element.
+ if (!isa<FixedVectorType>(OldTy)) {
+ Value *EVI = IRB.CreateExtractValue(Op, 0);
+ Intrin->replaceAllUsesWith(EVI);
+ Intrin->eraseFromParent();
+ return Error::success();
+ }
+
+ std::array<Value *, 4> Extracts = {};
+ SmallVector<ExtractElementInst *> DynamicAccesses;
+
+ // The users of the operation should all be scalarized, so we attempt to
+ // replace the extractelements with extractvalues directly.
+ for (Use &U : make_early_inc_range(Intrin->uses())) {
+ if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
+ if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+ size_t IndexVal = IndexOp->getZExtValue();
+ assert(IndexVal < 4 && "Index into buffer load out of range");
+ if (!Extracts[IndexVal])
+ Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
+ EEI->replaceAllUsesWith(Extracts[IndexVal]);
+ EEI->eraseFromParent();
+ } else {
+ DynamicAccesses.push_back(EEI);
+ }
+ }
+ }
+
+ const auto *VecTy = cast<FixedVectorType>(OldTy);
+ const unsigned N = VecTy->getNumElements();
+
+ // If there's a dynamic access we need to round trip through stack memory so
+ // that we don't leave vectors around.
+ if (!DynamicAccesses.empty()) {
+ Type *Int32Ty = IRB.getInt32Ty();
+ Constant *Zero = ConstantInt::get(Int32Ty, 0);
+
+ Type *ElTy = VecTy->getElementType();
+ Type *ArrayTy = ArrayType::get(ElTy, N);
+ Value *Alloca = IRB.CreateAlloca(ArrayTy);
+
+ for (int I = 0, E = N; I != E; ++I) {
+ if (!Extracts[I])
+ Extracts[I] = IRB.CreateExtractValue(Op, I);
+ Value *GEP = IRB.CreateInBoundsGEP(
+ ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
+ IRB.CreateStore(Extracts[I], GEP);
+ }
+
+ for (ExtractElementInst *EEI : DynamicAccesses) {
+ Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
+ {Zero, EEI->getIndexOperand()});
+ Value *Load = IRB.CreateLoad(ElTy, GEP);
+ EEI->replaceAllUsesWith(Load);
+ EEI->eraseFromParent();
+ }
+ }
+
+ // If we still have uses, then we're not fully scalarized and need to
+ // recreate the vector. This should only happen for things like exported
+ // functions from libraries.
+ if (!Intrin->use_empty()) {
+ for (int I = 0, E = N; I != E; ++I)
+ if (!Extracts[I])
+ Extracts[I] = IRB.CreateExtractValue(Op, I);
+
+ Value *Vec = UndefValue::get(OldTy);
+ for (int I = 0, E = N; I != E; ++I)
+ Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+ Intrin->replaceAllUsesWith(Vec);
+ }
+
+ Intrin->eraseFromParent();
+ return Error::success();
+ }
+
+ void lowerTypedBufferLoad(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+
+ replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+
+ Value *Handle =
+ createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+ Value *Index0 = CI->getArgOperand(1);
+ Value *Index1 = UndefValue::get(Int32Ty);
+
+ Type *NewRetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+
+ std::array<Value *, 3> Args{Handle, Index0, Index1};
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ if (Error E = replaceResRetUses(CI, *OpCall))
+ return E;
+
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
@@ -276,6 +385,10 @@ class OpLowerer {
#include "DXILOperation.inc"
case Intrinsic::dx_handle_fromBinding:
lowerHandleFromBinding(F);
+ break;
+ case Intrinsic::dx_typedBufferLoad:
+ lowerTypedBufferLoad(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
new file mode 100644
index 00000000000000..4b9fb52f0b5299
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -0,0 +1,171 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+
+define void @loadv4f32() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; The temporary casts should all have been cleaned up
+ ; CHECK-NOT: %dx.cast_handle
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+ ; The extract order depends on the users, so don't enforce that here.
+ ; CHECK-DAG: [[VAL0_0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+ %data0_0 = extractelement <4 x float> %data0, i32 0
+ ; CHECK-DAG: [[VAL0_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+ %data0_2 = extractelement <4 x float> %data0, i32 2
+
+ ; If all of the uses are extracts, we skip creating a vector
+ ; CHECK-NOT: insertelement
+ ; CHECK-DAG: call void @scalar_user(float [[VAL0_0]])
+ ; CHECK-DAG: call void @scalar_user(float [[VAL0_2]])
+ call void @scalar_user(float %data0_0)
+ call void @scalar_user(float %data0_2)
+
+ ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+ %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
+
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+ ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+ ; CHECK: insertelement <4 x float> undef
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ ; CHECK: insertelement <4 x float>
+ call void @vector_user(<4 x float> %data4)
+
+ ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+ %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
+
+ ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+ %data12_3 = extractelement <4 x float> %data12, i32 3
+
+ ; If there are a mix of users we need the vector, but extracts are direct
+ ; CHECK: call void @scalar_user(float [[DATA12_3]])
+ call void @scalar_user(float %data12_3)
+ call void @vector_user(<4 x float> %data12)
+
+ ret void
+}
+
+define void @index_dynamic(i32 %bufindex, i32 %elemindex) {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef)
+ %load = call <4 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %bufindex)
+
+ ; CHECK: [[ALLOCA:%.*]] = alloca [4 x float]
+ ; CHECK: [[V0:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 0
+ ; CHECK: [[A0:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 0
+ ; CHECK: store float [[V0]], ptr [[A0]]
+ ; CHECK: [[V1:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 1
+ ; CHECK: [[A1:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 1
+ ; CHECK: store float [[V1]], ptr [[A1]]
+ ; CHECK: [[V2:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 2
+ ; CHECK: [[A2:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 2
+ ; CHECK: store float [[V2]], ptr [[A2]]
+ ; CHECK: [[V3:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 3
+ ; CHECK: [[A3:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 3
+ ; CHECK: store float [[V3]], ptr [[A3]]
+ ;
+ ; CHECK: [[PTR:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 %elemindex
+ ; CHECK: [[X:%.*]] = load float, ptr [[PTR]]
+ %data = extractelement <4 x float> %load, i32 %elemindex
+
+ ; CHECK: call void @scalar_user(float [[X]])
+ call void @scalar_user(float %data)
+
+ ret void
+}
+
+define void @loadf32() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", float, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call float @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 0)
+
+ ; CHECK: [[VAL0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+ ; CHECK: call void @scalar_user(float [[VAL0]])
+ call void @scalar_user(float %data0)
+
+ ret void
+}
+
+define void @loadv2f32() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <2 x float>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v2f32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <2 x float> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <2 x float>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadv4i32() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadv4f16() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
+
+define void @loadv4i16() {
+ ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+ ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+ %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
+ @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
+ i32 0, i32 0, i32 1, i32 0, i1 false)
+
+ ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+ %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+ target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
+
+ ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 39b4a3ac375ed6..20164e1368ee9c 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
.Case("Int8Ty", "OverloadKind::I8")
.Case("Int16Ty", "OverloadKind::I16")
.Case("Int32Ty", "OverloadKind::I32")
- .Case("Int64Ty", "OverloadKind::I64");
+ .Case("Int64Ty", "OverloadKind::I64")
+ .Case("ResRetHalfTy", "OverloadKind::HALF")
+ .Case("ResRetFloatTy", "OverloadKind::FLOAT")
+ .Case("ResRetInt16Ty", "OverloadKind::I16")
+ .Case("ResRetInt32Ty", "OverloadKind::I32");
}
/// Return a string representation of valid overload information denoted