[llvm] [NVPTX] Add idp2a, idp4a intrinsics (PR #102763)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 12 17:08:30 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/102763

>From b0451cd8907720ea4375c7afaf7a33136d9f02e8 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 5 Aug 2024 19:10:53 +0000
Subject: [PATCH 1/2] [NVPTX] Add idp{2,4}a intrinsics

---
 llvm/docs/NVPTXUsage.rst                |  56 ++++++
 llvm/include/llvm/IR/IntrinsicsNVVM.td  |  12 ++
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td |  26 +++
 llvm/lib/Target/NVPTX/NVPTXSubtarget.h  |   3 +
 llvm/test/CodeGen/NVPTX/dot-product.ll  | 216 ++++++++++++++++++++++++
 5 files changed, 313 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/dot-product.ll

diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index b2839b4348336a..536186db85e4ca 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -287,6 +287,62 @@ The ``@llvm.nvvm.fence.proxy.tensormap_generic.*`` is a uni-directional fence us
 
 The address operand ``addr`` and the operand ``size`` together specify the memory range ``[addr, addr+size)`` on which the ordering guarantees on the memory accesses across the proxies is to be provided. The only supported value for the ``size`` operand is ``128`` and must be an immediate. Generic Addressing is used unconditionally, and the address specified by the operand addr must fall within the ``.global`` state space. Otherwise, the behavior is undefined. For more information, see `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`_.
 
+Arithmetic Intrinsics
+---------------------
+
+'``llvm.nvvm.idp2a``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.idp2a(i32 %a, i1 immarg %a.unsigned, i32 %b, i1 immarg %b.unsigned, i1 immarg %is.hi, i32 %c)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.idp2a``' intrinsic performs a 2-element vector dot product
+followed by addition. It corresponds directly to the ``dp2a`` PTX instruction.
+
+Semantics:
+""""""""""
+
+The 32-bit value in ``%a`` is broken into 2 16-bit values which are either sign
+or zero extended, depending on the value of ``%a.unsigned``, to 32 bits. Two
+bytes are selected from ``%b``, if ``%is.hi`` is true, the most significant
+bytes are selected, otherwise the least significant bytes are selected. These
+bytes are each sign or zero extended, depending on ``%b.unsigned``. The dot
+product of these 2-element vectors is added to ``%c`` to produce the return.
+
+
+'``llvm.nvvm.idp4a``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+    declare i32 @llvm.nvvm.idp4a(i32 %a, i1 immarg %a.unsigned, i32 %b, i1 immarg %b.unsigned, i32 %c)
+
+Overview:
+"""""""""
+
+The '``llvm.nvvm.idp4a``' intrinsic performs a 4-element vector dot product
+followed by addition. It corresponds directly to the ``dp4a`` PTX instruction.
+
+Semantics:
+""""""""""
+
+The 4 bytes in both ``%a`` and ``%b`` are extended to 32-bit integers
+forming 2 ``<4 x i32>``. zero-extension is used if ``%a.unsigned`` or
+``%b.unsigned`` is true respectively. The dot product of these 4-element vectors
+is added to ``%c`` to produce the return.
+
+
+
 Other Intrinsics
 ----------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7caada24dad564..8ab7da31cf3962 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1052,6 +1052,18 @@ let TargetPrefix = "nvvm" in {
       DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty],
         [IntrNoMem, IntrSpeculatable, Commutative]>;
 
+//
+// Dot Product
+//
+  def int_nvvm_idp4a :
+      DefaultAttrsIntrinsic<[llvm_i32_ty],
+          [llvm_i32_ty, llvm_i1_ty, llvm_i32_ty, llvm_i1_ty, llvm_i32_ty],
+          [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<3>>]>;
+  def int_nvvm_idp2a :
+      DefaultAttrsIntrinsic<[llvm_i32_ty],
+        [llvm_i32_ty, llvm_i1_ty, llvm_i32_ty, llvm_i1_ty, llvm_i1_ty, llvm_i32_ty],
+        [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>]>;
+
 //
 // Convert
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index d75dc8781f7802..2c38bdf20af57e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -159,6 +159,7 @@ def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
 
 def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
 def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
+def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
 
 def True : Predicate<"true">;
 def False : Predicate<"false">;
@@ -3920,6 +3921,31 @@ let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in
 }
 
 
+foreach a_unsigned = [0, -1] in {
+  foreach b_unsigned = [0, -1] in {
+    defvar a_suffix = !if(a_unsigned, "u32", "s32");
+    defvar b_suffix = !if(b_unsigned, "u32", "s32");
+
+    def DOT4_ # a_suffix # _ # b_suffix :
+      NVPTXInst<(outs Int32Regs:$dst),
+                (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+                "dp4a." # a_suffix # "." # b_suffix # " \t$dst, $a, $b, $c;",
+                [(set Int32Regs:$dst, (int_nvvm_idp4a (i32 Int32Regs:$a), a_unsigned, (i32 Int32Regs:$b), b_unsigned, (i32 Int32Regs:$c)))]>,
+                Requires<[hasDotInstructions]>;
+
+    foreach is_hi = [0, -1] in {
+      defvar lohi_suffix = !if(is_hi, "hi", "lo");
+
+      def DOT2_ # lohi_suffix # _ # a_suffix # _ # b_suffix :
+        NVPTXInst<(outs Int32Regs:$dst),
+                  (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+                  "dp2a." # lohi_suffix # "." # a_suffix # "." # b_suffix # " \t$dst, $a, $b, $c;",
+                  [(set Int32Regs:$dst, (int_nvvm_idp2a (i32 Int32Regs:$a), a_unsigned, (i32 Int32Regs:$b), b_unsigned, is_hi, (i32 Int32Regs:$c)))]>,
+                  Requires<[hasDotInstructions]>;
+    }
+  }
+}
+
 include "NVPTXIntrinsics.td"
 
 //-----------------------------------
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 8df41913ff12ef..e47050734aae1e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -90,6 +90,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
   bool hasMemoryOrdering() const { return SmVersion >= 70 && PTXVersion >= 60; }
   // Does SM & PTX support atomic relaxed MMIO operations ?
   bool hasRelaxedMMIO() const { return SmVersion >= 70 && PTXVersion >= 82; }
+  bool hasDotInstructions() const {
+    return SmVersion >= 61 && PTXVersion >= 50;
+  }
   unsigned int getFullSmVersion() const { return FullSmVersion; }
   unsigned int getSmVersion() const { return getFullSmVersion() / 10; }
   // GPUs with "a" suffix have include architecture-accelerated features that
diff --git a/llvm/test/CodeGen/NVPTX/dot-product.ll b/llvm/test/CodeGen/NVPTX/dot-product.ll
new file mode 100644
index 00000000000000..d06d6a799c2c1b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/dot-product.ll
@@ -0,0 +1,216 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx -mcpu=sm_61 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_61 | FileCheck %s
+
+target triple = "nvptx-nvidia-cuda"
+
+declare i32 @llvm.nvvm.idp4a(i32, i1 immarg, i32, i1 immarg, i32)
+
+define i32 @test_dp4a_u32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp4a_u32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp4a_u32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp4a_u32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp4a_u32_u32_param_2];
+; CHECK-NEXT:    dp4a.u32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 1, i32 %b, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp4a_u32imm_u32imm(i32 %c) {
+; CHECK-LABEL: test_dp4a_u32imm_u32imm(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp4a_u32imm_u32imm_param_0];
+; CHECK-NEXT:    mov.b32 %r2, 0;
+; CHECK-NEXT:    dp4a.u32.u32 %r3, %r2, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp4a(i32 0, i1 1, i32 0, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp4a_u32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp4a_u32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp4a_u32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp4a_u32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp4a_u32_s32_param_2];
+; CHECK-NEXT:    dp4a.u32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 1, i32 %b, i1 0, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp4a_s32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp4a_s32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp4a_s32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp4a_s32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp4a_s32_u32_param_2];
+; CHECK-NEXT:    dp4a.s32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 0, i32 %b, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp4a_s32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp4a_s32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp4a_s32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp4a_s32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp4a_s32_s32_param_2];
+; CHECK-NEXT:    dp4a.s32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 0, i32 %b, i1 0, i32 %c)
+  ret i32 %call
+}
+
+declare i32 @llvm.nvvm.idp2a(i32, i1 immarg, i32, i1 immarg, i1 immarg, i32)
+
+define i32 @test_dp2a_lo_u32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_lo_u32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_lo_u32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_lo_u32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_lo_u32_u32_param_2];
+; CHECK-NEXT:    dp2a.lo.u32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 1, i1 0, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_lo_u32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_lo_u32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_lo_u32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_lo_u32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_lo_u32_s32_param_2];
+; CHECK-NEXT:    dp2a.lo.u32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 0, i1 0, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_lo_s32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_lo_s32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_lo_s32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_lo_s32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_lo_s32_u32_param_2];
+; CHECK-NEXT:    dp2a.lo.s32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 1, i1 0, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_lo_s32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_lo_s32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_lo_s32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_lo_s32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_lo_s32_s32_param_2];
+; CHECK-NEXT:    dp2a.lo.s32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 0, i1 0, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_hi_u32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_hi_u32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_hi_u32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_hi_u32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_hi_u32_u32_param_2];
+; CHECK-NEXT:    dp2a.hi.u32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 1, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_hi_u32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_hi_u32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_hi_u32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_hi_u32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_hi_u32_s32_param_2];
+; CHECK-NEXT:    dp2a.hi.u32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 0, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_hi_s32_u32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_hi_s32_u32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_hi_s32_u32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_hi_s32_u32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_hi_s32_u32_param_2];
+; CHECK-NEXT:    dp2a.hi.s32.u32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 1, i1 1, i32 %c)
+  ret i32 %call
+}
+
+define i32 @test_dp2a_hi_s32_s32(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: test_dp2a_hi_s32_s32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_dp2a_hi_s32_s32_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_dp2a_hi_s32_s32_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test_dp2a_hi_s32_s32_param_2];
+; CHECK-NEXT:    dp2a.hi.s32.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 0, i1 1, i32 %c)
+  ret i32 %call
+}

>From a802d935571120446ffe7a5c6268b0766245340f Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 12 Aug 2024 23:37:28 +0000
Subject: [PATCH 2/2] address comments

---
 llvm/docs/NVPTXUsage.rst                | 53 ++++++++++++++++---------
 llvm/include/llvm/IR/IntrinsicsNVVM.td  | 20 ++++++----
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 22 +++++-----
 llvm/test/CodeGen/NVPTX/dot-product.ll  | 36 ++++++++++-------
 4 files changed, 79 insertions(+), 52 deletions(-)

diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 536186db85e4ca..872dedf8a82def 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -290,56 +290,71 @@ The address operand ``addr`` and the operand ``size`` together specify the memor
 Arithmetic Intrinsics
 ---------------------
 
-'``llvm.nvvm.idp2a``' Intrinsic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Syntax:
 """""""
 
 .. code-block:: llvm
 
-    declare i32 @llvm.nvvm.idp2a(i32 %a, i1 immarg %a.unsigned, i32 %b, i1 immarg %b.unsigned, i1 immarg %is.hi, i32 %c)
+    declare i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
+    declare i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
+    declare i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
+    declare i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
+
 
 Overview:
 """""""""
 
-The '``llvm.nvvm.idp2a``' intrinsic performs a 2-element vector dot product
-followed by addition. It corresponds directly to the ``dp2a`` PTX instruction.
+The '``llvm.nvvm.idp2a.[us].[us]``' intrinsics perform a 2-element vector dot
+product followed by addition. They correspond directly to the ``dp2a`` PTX
+instruction.
 
 Semantics:
 """"""""""
 
-The 32-bit value in ``%a`` is broken into 2 16-bit values which are either sign
-or zero extended, depending on the value of ``%a.unsigned``, to 32 bits. Two
-bytes are selected from ``%b``, if ``%is.hi`` is true, the most significant
-bytes are selected, otherwise the least significant bytes are selected. These
-bytes are each sign or zero extended, depending on ``%b.unsigned``. The dot
-product of these 2-element vectors is added to ``%c`` to produce the return.
+The 32-bit value in ``%a`` is broken into 2 16-bit values, which are extended
+to 32 bits. For the '``llvm.nvvm.idp2a.u.[us]``' variants zero-extension is
+used, while for the '``llvm.nvvm.idp2a.s.[us]``' variants sign-extension is
+used. Two bytes are selected from ``%b``: if ``%is.hi`` is true, the most
+significant bytes are selected; otherwise, the least significant bytes are
+selected. These bytes are then extended to 32 bits. For the
+'``llvm.nvvm.idp2a.[us].u``' variants zero-extension is used, while for the
+'``llvm.nvvm.idp2a.[us].s``' variants sign-extension is used. The dot product
+of these 2-element vectors is added to ``%c`` to produce the return value.
 
 
-'``llvm.nvvm.idp4a``' Intrinsic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+'``llvm.nvvm.idp4a.[us].[us]``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Syntax:
 """""""
 
 .. code-block:: llvm
 
-    declare i32 @llvm.nvvm.idp4a(i32 %a, i1 immarg %a.unsigned, i32 %b, i1 immarg %b.unsigned, i32 %c)
+    declare i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
+    declare i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
+    declare i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
+    declare i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)
 
 Overview:
 """""""""
 
-The '``llvm.nvvm.idp4a``' intrinsic performs a 4-element vector dot product
-followed by addition. It corresponds directly to the ``dp4a`` PTX instruction.
+The '``llvm.nvvm.idp4a.[us].[us]``' intrinsics perform a 4-element vector dot
+product followed by addition. They correspond directly to the ``dp4a`` PTX
+instruction.
 
 Semantics:
 """"""""""
 
 The 4 bytes in both ``%a`` and ``%b`` are extended to 32-bit integers
-forming 2 ``<4 x i32>``. zero-extension is used if ``%a.unsigned`` or
-``%b.unsigned`` is true respectively. The dot product of these 4-element vectors
-is added to ``%c`` to produce the return.
+forming 2 ``<4 x i32>``. For ``%a``, zero-extension is used in the
+'``llvm.nvvm.idp4a.u.[us]``' variants, while sign-extension is used with
+'``llvm.nvvm.idp4a.s.[us]``' variants. Similarly, for ``%b``, zero-extension is
+used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used
+with '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these 4-element
+vectors is added to ``%c`` to produce the return value.
 
 
 
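As a rough reference for the semantics documented above, the two functions below sketch in plain LLVM IR what ``@llvm.nvvm.idp4a.u.s`` and ``@llvm.nvvm.idp2a.u.s`` (with ``%is.hi`` false) are expected to compute. This is illustrative only and not part of the patch: the reference function names are made up, little-endian byte/lane packing is assumed, and the generic ``llvm.vector.reduce.add`` intrinsic stands in for the horizontal add.

    define i32 @idp4a_u_s_reference(i32 %a, i32 %b, i32 %c) {
      ; Split both operands into 4 bytes each.
      %a.bytes = bitcast i32 %a to <4 x i8>
      %b.bytes = bitcast i32 %b to <4 x i8>
      ; %a is unsigned (.u), so zero-extend; %b is signed (.s), so sign-extend.
      %a.ext = zext <4 x i8> %a.bytes to <4 x i32>
      %b.ext = sext <4 x i8> %b.bytes to <4 x i32>
      ; Element-wise multiply, horizontal add, then accumulate into %c.
      %mul = mul <4 x i32> %a.ext, %b.ext
      %dot = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mul)
      %res = add i32 %dot, %c
      ret i32 %res
    }

    define i32 @idp2a_lo_u_s_reference(i32 %a, i32 %b, i32 %c) {
      ; %a is split into 2 16-bit halves and zero-extended (.u).
      %a.halves = bitcast i32 %a to <2 x i16>
      %a.ext = zext <2 x i16> %a.halves to <2 x i32>
      ; %is.hi = false selects the two least-significant bytes of %b,
      ; which are then sign-extended (.s).
      %b.bytes = bitcast i32 %b to <4 x i8>
      %b.lo = shufflevector <4 x i8> %b.bytes, <4 x i8> poison, <2 x i32> <i32 0, i32 1>
      %b.ext = sext <2 x i8> %b.lo to <2 x i32>
      ; Dot product plus accumulator.
      %mul = mul <2 x i32> %a.ext, %b.ext
      %dot = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %mul)
      %res = add i32 %dot, %c
      ret i32 %res
    }

    declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
    declare i32 @llvm.vector.reduce.add.v2i32(<2 x i32>)
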
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 8ab7da31cf3962..65a3d2d0f943a7 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1055,14 +1055,18 @@ let TargetPrefix = "nvvm" in {
 //
 // Dot Product
 //
-  def int_nvvm_idp4a :
-      DefaultAttrsIntrinsic<[llvm_i32_ty],
-          [llvm_i32_ty, llvm_i1_ty, llvm_i32_ty, llvm_i1_ty, llvm_i32_ty],
-          [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<3>>]>;
-  def int_nvvm_idp2a :
-      DefaultAttrsIntrinsic<[llvm_i32_ty],
-        [llvm_i32_ty, llvm_i1_ty, llvm_i32_ty, llvm_i1_ty, llvm_i1_ty, llvm_i32_ty],
-        [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>]>;
+  foreach a_type = ["s", "u"] in {
+    foreach b_type = ["s", "u"] in {
+      def int_nvvm_idp4a_ # a_type # _ # b_type :
+          DefaultAttrsIntrinsic<[llvm_i32_ty],
+              [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
+              [IntrNoMem, IntrSpeculatable]>;
+      def int_nvvm_idp2a_ # a_type # _ # b_type :
+          DefaultAttrsIntrinsic<[llvm_i32_ty],
+            [llvm_i32_ty, llvm_i32_ty, llvm_i1_ty, llvm_i32_ty],
+            [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>]>;
+    }
+  }
 
 //
 // Convert
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 2c38bdf20af57e..b57c86fcf697cd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3921,26 +3921,28 @@ let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in
 }
 
 
-foreach a_unsigned = [0, -1] in {
-  foreach b_unsigned = [0, -1] in {
-    defvar a_suffix = !if(a_unsigned, "u32", "s32");
-    defvar b_suffix = !if(b_unsigned, "u32", "s32");
+foreach a_type = ["s", "u"] in {
+  foreach b_type = ["s", "u"] in {
 
-    def DOT4_ # a_suffix # _ # b_suffix :
+    def DOT4_ # a_type # b_type :
       NVPTXInst<(outs Int32Regs:$dst),
                 (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
-                "dp4a." # a_suffix # "." # b_suffix # " \t$dst, $a, $b, $c;",
-                [(set Int32Regs:$dst, (int_nvvm_idp4a (i32 Int32Regs:$a), a_unsigned, (i32 Int32Regs:$b), b_unsigned, (i32 Int32Regs:$c)))]>,
+                "dp4a." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
+                [(set Int32Regs:$dst,
+                    (!cast<Intrinsic>("int_nvvm_idp4a_" # a_type # "_" # b_type)
+                     (i32 Int32Regs:$a), (i32 Int32Regs:$b), (i32 Int32Regs:$c)))]>,
                 Requires<[hasDotInstructions]>;
 
     foreach is_hi = [0, -1] in {
       defvar lohi_suffix = !if(is_hi, "hi", "lo");
 
-      def DOT2_ # lohi_suffix # _ # a_suffix # _ # b_suffix :
+      def DOT2_ # lohi_suffix # _ # a_type # b_type :
         NVPTXInst<(outs Int32Regs:$dst),
                   (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
-                  "dp2a." # lohi_suffix # "." # a_suffix # "." # b_suffix # " \t$dst, $a, $b, $c;",
-                  [(set Int32Regs:$dst, (int_nvvm_idp2a (i32 Int32Regs:$a), a_unsigned, (i32 Int32Regs:$b), b_unsigned, is_hi, (i32 Int32Regs:$c)))]>,
+                  "dp2a." # lohi_suffix # "." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
+                  [(set Int32Regs:$dst,
+                      (!cast<Intrinsic>("int_nvvm_idp2a_" # a_type # "_" # b_type)
+                       (i32 Int32Regs:$a), (i32 Int32Regs:$b), is_hi, (i32 Int32Regs:$c)))]>,
                   Requires<[hasDotInstructions]>;
     }
   }
diff --git a/llvm/test/CodeGen/NVPTX/dot-product.ll b/llvm/test/CodeGen/NVPTX/dot-product.ll
index d06d6a799c2c1b..36529bbef90332 100644
--- a/llvm/test/CodeGen/NVPTX/dot-product.ll
+++ b/llvm/test/CodeGen/NVPTX/dot-product.ll
@@ -4,7 +4,10 @@
 
 target triple = "nvptx-nvidia-cuda"
 
-declare i32 @llvm.nvvm.idp4a(i32, i1 immarg, i32, i1 immarg, i32)
+declare i32 @llvm.nvvm.idp4a.s.s(i32, i32, i32)
+declare i32 @llvm.nvvm.idp4a.s.u(i32, i32, i32)
+declare i32 @llvm.nvvm.idp4a.u.s(i32, i32, i32)
+declare i32 @llvm.nvvm.idp4a.u.u(i32, i32, i32)
 
 define i32 @test_dp4a_u32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: test_dp4a_u32_u32(
@@ -18,7 +21,7 @@ define i32 @test_dp4a_u32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp4a.u32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 1, i32 %b, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)
   ret i32 %call
 }
 
@@ -33,7 +36,7 @@ define i32 @test_dp4a_u32imm_u32imm(i32 %c) {
 ; CHECK-NEXT:    dp4a.u32.u32 %r3, %r2, %r2, %r1;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp4a(i32 0, i1 1, i32 0, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp4a.u.u(i32 0, i32 0, i32 %c)
   ret i32 %call
 }
 
@@ -49,7 +52,7 @@ define i32 @test_dp4a_u32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp4a.u32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 1, i32 %b, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
   ret i32 %call
 }
 
@@ -65,7 +68,7 @@ define i32 @test_dp4a_s32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp4a.s32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 0, i32 %b, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
   ret i32 %call
 }
 
@@ -81,11 +84,14 @@ define i32 @test_dp4a_s32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp4a.s32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp4a(i32 %a, i1 0, i32 %b, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
   ret i32 %call
 }
 
-declare i32 @llvm.nvvm.idp2a(i32, i1 immarg, i32, i1 immarg, i1 immarg, i32)
+declare i32 @llvm.nvvm.idp2a.s.s(i32, i32, i1 immarg, i32)
+declare i32 @llvm.nvvm.idp2a.s.u(i32, i32, i1 immarg, i32)
+declare i32 @llvm.nvvm.idp2a.u.s(i32, i32, i1 immarg, i32)
+declare i32 @llvm.nvvm.idp2a.u.u(i32, i32, i1 immarg, i32)
 
 define i32 @test_dp2a_lo_u32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: test_dp2a_lo_u32_u32(
@@ -99,7 +105,7 @@ define i32 @test_dp2a_lo_u32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.lo.u32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 1, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 0, i32 %c)
   ret i32 %call
 }
 
@@ -115,7 +121,7 @@ define i32 @test_dp2a_lo_u32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.lo.u32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 0, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 0, i32 %c)
   ret i32 %call
 }
 
@@ -131,7 +137,7 @@ define i32 @test_dp2a_lo_s32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.lo.s32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 1, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 0, i32 %c)
   ret i32 %call
 }
 
@@ -147,7 +153,7 @@ define i32 @test_dp2a_lo_s32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.lo.s32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 0, i1 0, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 0, i32 %c)
   ret i32 %call
 }
 
@@ -163,7 +169,7 @@ define i32 @test_dp2a_hi_u32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.hi.u32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 1, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 1, i32 %c)
   ret i32 %call
 }
 
@@ -179,7 +185,7 @@ define i32 @test_dp2a_hi_u32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.hi.u32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 1, i32 %b, i1 0, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 1, i32 %c)
   ret i32 %call
 }
 
@@ -195,7 +201,7 @@ define i32 @test_dp2a_hi_s32_u32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.hi.s32.u32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 1, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 1, i32 %c)
   ret i32 %call
 }
 
@@ -211,6 +217,6 @@ define i32 @test_dp2a_hi_s32_s32(i32 %a, i32 %b, i32 %c) {
 ; CHECK-NEXT:    dp2a.hi.s32.s32 %r4, %r1, %r2, %r3;
 ; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
 ; CHECK-NEXT:    ret;
-  %call = call i32 @llvm.nvvm.idp2a(i32 %a, i1 0, i32 %b, i1 0, i1 1, i32 %c)
+  %call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 1, i32 %c)
   ret i32 %call
 }


