[llvm] 3f22756 - [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 9 13:21:26 PDT 2024


Author: Justin Bogner
Date: 2024-09-09T13:21:22-07:00
New Revision: 3f22756f391e20040fa3581206b77c409433bd9f

URL: https://github.com/llvm/llvm-project/commit/3f22756f391e20040fa3581206b77c409433bd9f
DIFF: https://github.com/llvm/llvm-project/commit/3f22756f391e20040fa3581206b77c409433bd9f.diff

LOG: [DirectX] Lower `@llvm.dx.typedBufferLoad` to DXIL ops

The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`.
There's some complexity here in translating to scalarized IR, which I've
abstracted out into a helper function (`replaceResRetUses`) that should also be
useful for samples, gathers, and CBuffer loads.
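
To illustrate the scalarization decision described above, here is a minimal, self-contained sketch that models it without any LLVM dependency. It is not the code in this patch: `UseKind`, `VectorUse`, `LoweringPlan`, and `planResRetLowering` are hypothetical names invented for this example. They only mirror the three cases the lowering distinguishes: constant-index extracts, dynamic-index extracts, and remaining whole-vector uses.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the ways a loaded vector can be used.
enum class UseKind { ConstExtract, DynamicExtract, WholeVector };

struct VectorUse {
  UseKind Kind;
  std::size_t Index; // Only meaningful for ConstExtract.
};

// What the lowering would need to emit for a single buffer load.
struct LoweringPlan {
  std::array<bool, 4> Extracted{}; // Is an extractvalue needed for element I?
  bool NeedsAlloca = false;        // Round trip through stack memory.
  bool NeedsVectorRebuild = false; // insertelement chain to recreate the vector.
};

// Classify the uses of a loaded vector with NumElts elements.
LoweringPlan planResRetLowering(const std::vector<VectorUse> &Uses,
                                std::size_t NumElts) {
  assert(NumElts >= 1 && NumElts <= 4 && "ResRet holds at most 4 elements");
  LoweringPlan Plan;
  for (const VectorUse &U : Uses) {
    switch (U.Kind) {
    case UseKind::ConstExtract:
      // A constant-index extractelement maps directly to one extractvalue.
      assert(U.Index < NumElts && "Index into buffer load out of range");
      Plan.Extracted[U.Index] = true;
      break;
    case UseKind::DynamicExtract:
      // A dynamic index forces the elements through an alloca'd array.
      Plan.NeedsAlloca = true;
      break;
    case UseKind::WholeVector:
      // Any other use means the vector has to be rebuilt.
      Plan.NeedsVectorRebuild = true;
      break;
    }
  }
  // Both fallbacks need every element of the loaded value.
  if (Plan.NeedsAlloca || Plan.NeedsVectorRebuild)
    for (std::size_t I = 0; I != NumElts; ++I)
      Plan.Extracted[I] = true;
  return Plan;
}
```

Under this model, a load whose only uses are constant extracts of elements 0 and 2 needs two extracts and nothing else, which corresponds to the `CHECK-NOT: insertelement` case in the new BufferLoad.ll test.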

I've also updated the DXILResources.rst docs to match what I'm doing here and
the proposal in llvm/wg-hlsl#59. I've removed the content about stores and raw
buffers for now, with the expectation that it will be added back along with the
work that implements them.

Note that this change includes a bit of a hack in how it deals with
`getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal
with operation overloads to generate a table directly rather than proxy through
the OverloadKind enum, but that's left for a later change.
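
As a rough illustration of that hack, the sketch below models the idea with a toy type representation instead of `llvm::Type`: a struct is classified by the overload kind of its first element, so a ResRet struct of floats is treated as the float overload. `ToyType`, `makeResRet`, and `resRetTypeName` are hypothetical names for this example only, and the empty-struct guard is an addition the patch itself does not have.

```cpp
#include <string>
#include <vector>

enum class OverloadKind { Undefined, Half, Float, I16, I32 };

// Toy type model; the real code operates on llvm::Type.
struct ToyType {
  enum class ID { Half, Float, Int16, Int32, Struct, Other } Id;
  std::vector<ToyType> Elements; // Only used when Id == ID::Struct.
};

OverloadKind getOverloadKind(const ToyType &Ty) {
  switch (Ty.Id) {
  case ToyType::ID::Half:
    return OverloadKind::Half;
  case ToyType::ID::Float:
    return OverloadKind::Float;
  case ToyType::ID::Int16:
    return OverloadKind::I16;
  case ToyType::ID::Int32:
    return OverloadKind::I32;
  case ToyType::ID::Struct:
    // The hack: classify a struct by its first element.
    // (Guarding against empty structs is an assumption of this sketch.)
    return Ty.Elements.empty() ? OverloadKind::Undefined
                               : getOverloadKind(Ty.Elements.front());
  case ToyType::ID::Other:
    break;
  }
  return OverloadKind::Undefined;
}

// Names follow the %dx.types.ResRet.* spellings checked in BufferLoad.ll.
std::string resRetTypeName(OverloadKind Kind) {
  switch (Kind) {
  case OverloadKind::Half:
    return "dx.types.ResRet.f16";
  case OverloadKind::Float:
    return "dx.types.ResRet.f32";
  case OverloadKind::I16:
    return "dx.types.ResRet.i16";
  case OverloadKind::I32:
    return "dx.types.ResRet.i32";
  case OverloadKind::Undefined:
    break;
  }
  return "";
}

// Four elements of the same type followed by an i32 status word.
ToyType makeResRet(ToyType::ID Elt) {
  ToyType Scalar{Elt, {}};
  return ToyType{ToyType::ID::Struct,
                 {Scalar, Scalar, Scalar, Scalar,
                  ToyType{ToyType::ID::Int32, {}}}};
}
```

The weakness this exposes is that any struct starting with a float would be classified the same way as a ResRet of floats, which is why generating an overload table directly is the intended follow-up.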

Part of #91367

Pull Request: https://github.com/llvm/llvm-project/pull/104252

Added: 
    llvm/test/CodeGen/DirectX/BufferLoad.ll

Modified: 
    llvm/docs/DirectX/DXILResources.rst
    llvm/include/llvm/IR/IntrinsicsDirectX.td
    llvm/lib/Target/DirectX/DXIL.td
    llvm/lib/Target/DirectX/DXILOpBuilder.cpp
    llvm/lib/Target/DirectX/DXILOpBuilder.h
    llvm/lib/Target/DirectX/DXILOpLowering.cpp
    llvm/utils/TableGen/DXILEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/docs/DirectX/DXILResources.rst b/llvm/docs/DirectX/DXILResources.rst
index a6ec80ce4329b2..8e43bfaaaf32ea 100644
--- a/llvm/docs/DirectX/DXILResources.rst
+++ b/llvm/docs/DirectX/DXILResources.rst
@@ -267,45 +267,38 @@ Examples:
                @llvm.dx.handle.fromHeap.tdx.RawBuffer_v4f32_1_0(
                    i32 2, i1 false)
 
-Buffer Loads and Stores
------------------------
-
-*relevant types: Buffers*
-
-We need to treat buffer loads and stores from "dx.TypedBuffer" and
-"dx.RawBuffer" separately. For TypedBuffer, we have ``llvm.dx.typedBufferLoad``
-and ``llvm.dx.typedBufferStore``, which load and store 16-byte "rows" of data
-via a simple index. For RawBuffer, we have ``llvm.dx.rawBufferPtr``, which
-return a pointer that can be indexed, loaded, and stored to as needed.
-
-The typed load and store operations always operate on exactly 16 bytes of data,
-so there are only a few valid overloads. For types that are 32-bits or smaller,
-we operate on 4-element vectors, such as ``<4 x i32>``, ``<4 x float>``, or
-``<4 x half>``. Note that in 16-bit cases each 16-bit value occupies 32-bits of
-storage. For 64-bit types we operate on 2-element vectors - ``<2 x double>`` or
-``<2 x i64>``. When a type like `Buffer<float>` is used at the HLSL level, it
-is expected that this will operate on a single float in each 16 byte row - that
-is, a load would use the ``<4 x float>`` variant and then extract the first
-element.
-
-.. note:: In DXC, trying to operate on a ``Buffer<double4>`` crashes the
-          compiler. We should probably just reject this in the frontend.
-
-The TypedBuffer intrinsics are lowered to the `bufferLoad`_ and `bufferStore`_
-operations, and the operations on the memory accessed by RawBufferPtr are
-lowered to `rawBufferLoad`_ and `rawBufferStore`_. Note that if we want to
-support DXIL versions prior to 1.2 we'll need to lower the RawBuffer loads and
-stores to the non-raw operations as well.
-
-.. note:: TODO: We need to account for `CheckAccessFullyMapped`_ here.
-
-   In DXIL the load operations always return an ``i32`` status value, but this
-   isn't very ergonomic when it isn't used. We can (1) bite the bullet and have
-   the loads return `{%ret_type, %i32}` all the time, (2) create a variant or
-   update the signature iff the status is used, or (3) hide this in a sideband
-   channel somewhere. I'm leaning towards (2), but could probably be convinced
-   that the ugliness of (1) is worth the simplicity.
-
+16-byte Loads, Samples, and Gathers
+-----------------------------------
+
+*relevant types: TypedBuffer, CBuffer, and Textures*
+
+TypedBuffer, CBuffer, and Texture loads, as well as samples and gathers, can
+return 1 to 4 elements from the given resource, to a maximum of 16 bytes of
+data. DXIL's modeling of this is influenced by DirectX and DXBC's history and
+it generally treats these operations as returning 4 32-bit values. For 16-bit
+elements the values are 16-bit values, and for 64-bit values the operations
+return 4 32-bit integers and emit further code to construct the double.
+
+In DXIL, these operations return `ResRet`_ and `CBufRet`_ values, which are
+structs containing 4 elements of the same type, and in the case of `ResRet` a
+5th element that is used by the `CheckAccessFullyMapped`_ operation.
+
+In LLVM IR the intrinsics will return the contained type of the resource
+instead. That is, ``llvm.dx.typedBufferLoad`` from a ``Buffer<float>`` would
+return a single float, from ``Buffer<float4>`` a vector of 4 floats, and from
+``Buffer<double2>`` a vector of two doubles, etc. The operations are then
+expanded out to match DXIL's format during lowering.
+
+In cases where we need ``CheckAccessFullyMapped``, we have a second intrinsic
+that returns an anonymous struct with element-0 being the contained type, and
+element-1 being the ``i1`` result of a ``CheckAccessFullyMapped`` call. We
+don't have a separate call to ``CheckAccessFullyMapped`` at all, since that's
+the only operation that can possibly be done on this value. In practice this
+may mean we insert a DXIL operation for the check when this was missing in the
+HLSL source, but this actually matches DXC's behaviour in practice.
+
+.. _ResRet: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#resource-operation-return-types
+.. _CBufRet: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#cbufferloadlegacy
 .. _CheckAccessFullyMapped: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/checkaccessfullymapped
 
 .. list-table:: ``@llvm.dx.typedBufferLoad``
@@ -317,7 +310,7 @@ stores to the non-raw operations as well.
      - Description
    * - Return value
      -
-     - A 4- or 2-element vector of the type of the buffer
+     - The contained type of the buffer
      - The data loaded from the buffer
    * - ``%buffer``
      - 0
@@ -332,16 +325,23 @@ Examples:
 
 .. code-block:: llvm
 
-   %ret = call <4 x float> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f32_0_0t(
-       target("dx.TypedBuffer", f32, 0, 0) %buffer, i32 %index)
-   %ret = call <4 x i32> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_i32_0_0t(
-       target("dx.TypedBuffer", i32, 0, 0) %buffer, i32 %index)
-   %ret = call <4 x half> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f16_0_0t(
-       target("dx.TypedBuffer", f16, 0, 0) %buffer, i32 %index)
-   %ret = call <2 x double> @llvm.dx.typedBufferLoad.tdx.TypedBuffer_f64_0_0t(
-       target("dx.TypedBuffer", double, 0, 0) %buffer, i32 %index)
-
-.. list-table:: ``@llvm.dx.typedBufferStore``
+   %ret = call <4 x float>
+       @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
+           target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
+   %ret = call float
+       @llvm.dx.typedBufferLoad.f32.tdx.TypedBuffer_f32_0_0_0t(
+           target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 %index)
+   %ret = call <4 x i32>
+       @llvm.dx.typedBufferLoad.v4i32.tdx.TypedBuffer_v4i32_0_0_0t(
+           target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 %index)
+   %ret = call <4 x half>
+       @llvm.dx.typedBufferLoad.v4f16.tdx.TypedBuffer_v4f16_0_0_0t(
+           target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 %index)
+   %ret = call <2 x double>
+       @llvm.dx.typedBufferLoad.v2f64.tdx.TypedBuffer_v2f64_0_0_0t(
+           target("dx.TypedBuffer", <2 x double>, 0, 0, 0) %buffer, i32 %index)
+
+.. list-table:: ``@llvm.dx.typedBufferLoad.checkbit``
    :header-rows: 1
 
    * - Argument
@@ -350,46 +350,11 @@ Examples:
      - Description
    * - Return value
      -
-     - ``void``
-     -
+     - A structure of the contained type and the check bit
+     - The data loaded from the buffer and the check bit
    * - ``%buffer``
      - 0
      - ``target(dx.TypedBuffer, ...)``
-     - The buffer to store into
-   * - ``%index``
-     - 1
-     - ``i32``
-     - Index into the buffer
-   * - ``%data``
-     - 2
-     - A 4- or 2-element vector of the type of the buffer
-     - The data to store
-
-Examples:
-
-.. code-block:: llvm
-
-   call void @llvm.dx.bufferStore.tdx.Buffer_f32_1_0t(
-       target("dx.TypedBuffer", f32, 1, 0) %buf, i32 %index, <4 x f32> %data)
-   call void @llvm.dx.bufferStore.tdx.Buffer_f16_1_0t(
-       target("dx.TypedBuffer", f16, 1, 0) %buf, i32 %index, <4 x f16> %data)
-   call void @llvm.dx.bufferStore.tdx.Buffer_f64_1_0t(
-       target("dx.TypedBuffer", f64, 1, 0) %buf, i32 %index, <2 x f64> %data)
-
-.. list-table:: ``@llvm.dx.rawBufferPtr``
-   :header-rows: 1
-
-   * - Argument
-     -
-     - Type
-     - Description
-   * - Return value
-     -
-     - ``ptr``
-     - Pointer to an element of the buffer
-   * - ``%buffer``
-     - 0
-     - ``target(dx.RawBuffer, ...)``
      - The buffer to load from
    * - ``%index``
      - 1
@@ -400,37 +365,7 @@ Examples:
 
 .. code-block:: llvm
 
-   ; Load a float4 from a buffer
-   %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_v4f32_0_0t(
-       target("dx.RawBuffer", <4 x f32>, 0, 0) %buffer, i32 %index)
-   %val = load <4 x float>, ptr %buf, align 16
-
-   ; Load the double from a struct containing an int, a float, and a double
-   %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_sl_i32f32f64s_0_0t(
-       target("dx.RawBuffer", {i32, f32, f64}, 0, 0) %buffer, i32 %index)
-   %val = getelementptr inbounds {i32, f32, f64}, ptr %buf, i32 0, i32 2
-   %d = load double, ptr %val, align 8
-
-   ; Load a float from a byte address buffer
-   %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_i8_0_0t(
-       target("dx.RawBuffer", i8, 0, 0) %buffer, i32 %index)
-   %val = getelementptr inbounds float, ptr %buf, i64 0
-   %f = load float, ptr %val, align 4
-
-   ; Store to a buffer containing float4
-   %addr = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_v4f32_0_0t(
-       target("dx.RawBuffer", <4 x f32>, 0, 0) %buffer, i32 %index)
-   store <4 x float> %val, ptr %addr
-
-   ; Store the double in a struct containing an int, a float, and a double
-   %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_sl_i32f32f64s_0_0t(
-       target("dx.RawBuffer", {i32, f32, f64}, 0, 0) %buffer, i32 %index)
-   %addr = getelementptr inbounds {i32, f32, f64}, ptr %buf, i32 0, i32 2
-   store double %d, ptr %addr
-
-   ; Store a float into a byte address buffer
-   %buf = call ptr @llvm.dx.rawBufferPtr.tdx.RawBuffer_i8_0_0t(
-       target("dx.RawBuffer", i8, 0, 0) %buffer, i32 %index)
-   %addr = getelementptr inbounds float, ptr %buf, i64 0
-   store float %f, ptr %val
+   %ret = call {<4 x float>, i1}
+       @llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
+           target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
 

diff  --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f089d51fa1b459..40c9ac3f0da346 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -30,6 +30,9 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+def int_dx_typedBufferLoad
+    : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
+
 // Cast between target extension handle types and dxil-style opaque handles
 def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 

diff  --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 4e3ecf4300d825..67a9b9d02bb6a1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
 def HalfTy : DXILOpParamType;
 def FloatTy : DXILOpParamType;
 def DoubleTy : DXILOpParamType;
-def ResRetTy : DXILOpParamType;
+def ResRetHalfTy : DXILOpParamType;
+def ResRetFloatTy : DXILOpParamType;
+def ResRetInt16Ty : DXILOpParamType;
+def ResRetInt32Ty : DXILOpParamType;
 def HandleTy : DXILOpParamType;
 def ResBindTy : DXILOpParamType;
 def ResPropsTy : DXILOpParamType;
@@ -693,6 +696,17 @@ def CreateHandle : DXILOp<57, createHandle> {
   let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
 }
 
+def BufferLoad : DXILOp<68, bufferLoad> {
+  let Doc = "reads from a TypedBuffer";
+  // Handle, Coord0, Coord1
+  let arguments = [HandleTy, Int32Ty, Int32Ty];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0,
+                 [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def ThreadId :  DXILOp<93, threadId> {
   let Doc = "Reads the thread ID";
   let LLVMIntrinsic = int_dx_thread_id;

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index efe019a07acaa9..3b2a5f5061eb83 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -120,8 +120,12 @@ static OverloadKind getOverloadKind(Type *Ty) {
   }
   case Type::PointerTyID:
     return OverloadKind::UserDefineType;
-  case Type::StructTyID:
-    return OverloadKind::ObjectType;
+  case Type::StructTyID: {
+    // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
+    // how we're handling overloads and remove the `OverloadKind` proxy enum.
+    StructType *ST = cast<StructType>(Ty);
+    return getOverloadKind(ST->getElementType(0));
+  }
   default:
     return OverloadKind::UNDEFINED;
   }
@@ -194,10 +198,11 @@ static StructType *getOrCreateStructType(StringRef Name,
   return StructType::create(Ctx, EltTys, Name);
 }
 
-static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
-  OverloadKind Kind = getOverloadKind(OverloadTy);
+static StructType *getResRetType(Type *ElementTy) {
+  LLVMContext &Ctx = ElementTy->getContext();
+  OverloadKind Kind = getOverloadKind(ElementTy);
   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
-  Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+  Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
                          Type::getInt32Ty(Ctx)};
   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
 }
@@ -247,8 +252,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return Type::getInt64Ty(Ctx);
   case OpParamType::OverloadTy:
     return OverloadTy;
-  case OpParamType::ResRetTy:
-    return getResRetType(OverloadTy, Ctx);
+  case OpParamType::ResRetHalfTy:
+    return getResRetType(Type::getHalfTy(Ctx));
+  case OpParamType::ResRetFloatTy:
+    return getResRetType(Type::getFloatTy(Ctx));
+  case OpParamType::ResRetInt16Ty:
+    return getResRetType(Type::getInt16Ty(Ctx));
+  case OpParamType::ResRetInt32Ty:
+    return getResRetType(Type::getInt32Ty(Ctx));
   case OpParamType::HandleTy:
     return getHandleType(Ctx);
   case OpParamType::ResBindTy:
@@ -390,6 +401,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
       return makeOpError(OpCode, "Wrong number of arguments");
     OverloadTy = Args[ArgIndex]->getType();
   }
+
   FunctionType *DXILOpFT =
       getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
 
@@ -450,6 +462,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
   return *Result;
 }
 
+StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
+  return ::getResRetType(ElementTy);
+}
+
 StructType *DXILOpBuilder::getHandleType() {
   return ::getHandleType(IRB.getContext());
 }

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 4a55a8ac9eadb5..a68f0c43f67afb 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -46,6 +46,8 @@ class DXILOpBuilder {
   Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
                                    Type *RetTy = nullptr);
 
+  /// Get a `%dx.types.ResRet` type with the given element type.
+  StructType *getResRetType(Type *ElementTy);
   /// Get the `%dx.types.Handle` type.
   StructType *getHandleType();
 

diff  --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 1f6d37087bc9f4..df2751d99576a8 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -259,6 +259,115 @@ class OpLowerer {
       lowerToBindAndAnnotateHandle(F);
   }
 
+  /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
+  /// Since we expect to be post-scalarization, make an effort to avoid vectors.
+  Error replaceResRetUses(CallInst *Intrin, CallInst *Op) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+
+    Type *OldTy = Intrin->getType();
+
+    // For scalars, we just extract the first element.
+    if (!isa<FixedVectorType>(OldTy)) {
+      Value *EVI = IRB.CreateExtractValue(Op, 0);
+      Intrin->replaceAllUsesWith(EVI);
+      Intrin->eraseFromParent();
+      return Error::success();
+    }
+
+    std::array<Value *, 4> Extracts = {};
+    SmallVector<ExtractElementInst *> DynamicAccesses;
+
+    // The users of the operation should all be scalarized, so we attempt to
+    // replace the extractelements with extractvalues directly.
+    for (Use &U : make_early_inc_range(Intrin->uses())) {
+      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
+        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+          size_t IndexVal = IndexOp->getZExtValue();
+          assert(IndexVal < 4 && "Index into buffer load out of range");
+          if (!Extracts[IndexVal])
+            Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
+          EEI->replaceAllUsesWith(Extracts[IndexVal]);
+          EEI->eraseFromParent();
+        } else {
+          DynamicAccesses.push_back(EEI);
+        }
+      }
+    }
+
+    const auto *VecTy = cast<FixedVectorType>(OldTy);
+    const unsigned N = VecTy->getNumElements();
+
+    // If there's a dynamic access we need to round trip through stack memory so
+    // that we don't leave vectors around.
+    if (!DynamicAccesses.empty()) {
+      Type *Int32Ty = IRB.getInt32Ty();
+      Constant *Zero = ConstantInt::get(Int32Ty, 0);
+
+      Type *ElTy = VecTy->getElementType();
+      Type *ArrayTy = ArrayType::get(ElTy, N);
+      Value *Alloca = IRB.CreateAlloca(ArrayTy);
+
+      for (int I = 0, E = N; I != E; ++I) {
+        if (!Extracts[I])
+          Extracts[I] = IRB.CreateExtractValue(Op, I);
+        Value *GEP = IRB.CreateInBoundsGEP(
+            ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
+        IRB.CreateStore(Extracts[I], GEP);
+      }
+
+      for (ExtractElementInst *EEI : DynamicAccesses) {
+        Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
+                                           {Zero, EEI->getIndexOperand()});
+        Value *Load = IRB.CreateLoad(ElTy, GEP);
+        EEI->replaceAllUsesWith(Load);
+        EEI->eraseFromParent();
+      }
+    }
+
+    // If we still have uses, then we're not fully scalarized and need to
+    // recreate the vector. This should only happen for things like exported
+    // functions from libraries.
+    if (!Intrin->use_empty()) {
+      for (int I = 0, E = N; I != E; ++I)
+        if (!Extracts[I])
+          Extracts[I] = IRB.CreateExtractValue(Op, I);
+
+      Value *Vec = UndefValue::get(OldTy);
+      for (int I = 0, E = N; I != E; ++I)
+        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+      Intrin->replaceAllUsesWith(Vec);
+    }
+
+    Intrin->eraseFromParent();
+    return Error::success();
+  }
+
+  void lowerTypedBufferLoad(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      Value *Handle =
+          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+      Value *Index0 = CI->getArgOperand(1);
+      Value *Index1 = UndefValue::get(Int32Ty);
+
+      Type *NewRetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+
+      std::array<Value *, 3> Args{Handle, Index0, Index1};
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
+      if (Error E = OpCall.takeError())
+        return E;
+      if (Error E = replaceResRetUses(CI, *OpCall))
+        return E;
+
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
 
@@ -276,6 +385,10 @@ class OpLowerer {
 #include "DXILOperation.inc"
       case Intrinsic::dx_handle_fromBinding:
         lowerHandleFromBinding(F);
+        break;
+      case Intrinsic::dx_typedBufferLoad:
+        lowerTypedBufferLoad(F);
+        break;
       }
       Updated = true;
     }

diff  --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
new file mode 100644
index 00000000000000..4b9fb52f0b5299
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -0,0 +1,171 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+
+define void @loadv4f32() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
+
+  ; The extract order depends on the users, so don't enforce that here.
+  ; CHECK-DAG: [[VAL0_0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+  %data0_0 = extractelement <4 x float> %data0, i32 0
+  ; CHECK-DAG: [[VAL0_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+  %data0_2 = extractelement <4 x float> %data0, i32 2
+
+  ; If all of the uses are extracts, we skip creating a vector
+  ; CHECK-NOT: insertelement
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_0]])
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_2]])
+  call void @scalar_user(float %data0_0)
+  call void @scalar_user(float %data0_2)
+
+  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+  %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
+
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+  ; CHECK: insertelement <4 x float> undef
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  call void @vector_user(<4 x float> %data4)
+
+  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+  %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
+
+  ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+  %data12_3 = extractelement <4 x float> %data12, i32 3
+
+  ; If there are a mix of users we need the vector, but extracts are direct
+  ; CHECK: call void @scalar_user(float [[DATA12_3]])
+  call void @scalar_user(float %data12_3)
+  call void @vector_user(<4 x float> %data12)
+
+  ret void
+}
+
+define void @index_dynamic(i32 %bufindex, i32 %elemindex) {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef)
+  %load = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %bufindex)
+
+  ; CHECK: [[ALLOCA:%.*]] = alloca [4 x float]
+  ; CHECK: [[V0:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 0
+  ; CHECK: [[A0:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 0
+  ; CHECK: store float [[V0]], ptr [[A0]]
+  ; CHECK: [[V1:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 1
+  ; CHECK: [[A1:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 1
+  ; CHECK: store float [[V1]], ptr [[A1]]
+  ; CHECK: [[V2:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 2
+  ; CHECK: [[A2:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 2
+  ; CHECK: store float [[V2]], ptr [[A2]]
+  ; CHECK: [[V3:%.*]] = extractvalue %dx.types.ResRet.f32 [[LOAD]], 3
+  ; CHECK: [[A3:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 3
+  ; CHECK: store float [[V3]], ptr [[A3]]
+  ;
+  ; CHECK: [[PTR:%.*]] = getelementptr inbounds [4 x float], ptr [[ALLOCA]], i32 0, i32 %elemindex
+  ; CHECK: [[X:%.*]] = load float, ptr [[PTR]]
+  %data = extractelement <4 x float> %load, i32 %elemindex
+
+  ; CHECK: call void @scalar_user(float [[X]])
+  call void @scalar_user(float %data)
+
+  ret void
+}
+
+define void @loadf32() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", float, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call float @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 0)
+
+  ; CHECK: [[VAL0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+  ; CHECK: call void @scalar_user(float [[VAL0]])
+  call void @scalar_user(float %data0)
+
+  ret void
+}
+
+define void @loadv2f32() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <2 x float>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v2f32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <2 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <2 x float>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
+
+define void @loadv4i32() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
+
+define void @loadv4f16() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}
+
+define void @loadv4i16() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
+
+  ret void
+}

diff  --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 39b4a3ac375ed6..20164e1368ee9c 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
       .Case("Int8Ty", "OverloadKind::I8")
       .Case("Int16Ty", "OverloadKind::I16")
       .Case("Int32Ty", "OverloadKind::I32")
-      .Case("Int64Ty", "OverloadKind::I64");
+      .Case("Int64Ty", "OverloadKind::I64")
+      .Case("ResRetHalfTy", "OverloadKind::HALF")
+      .Case("ResRetFloatTy", "OverloadKind::FLOAT")
+      .Case("ResRetInt16Ty", "OverloadKind::I16")
+      .Case("ResRetInt32Ty", "OverloadKind::I32");
 }
 
 /// Return a string representation of valid overload information denoted


        

