[llvm] [DirectX] Add `WaveActiveOp` builtin (PR #112058)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 15:40:12 PDT 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/112058

    - create int_dx_wave_active_op in IntrinsicsDirectX.td
    - add mapping to dxil op in DXIL.td
    - add scalarization to DirectXTargetTransformInfo.cpp
    - add tests of lowerings to dxil ops for both scalar and vector values

This is part 1 of implementing the DirectX lowering required by #99170 and #70106.

From 3a13167e6e55ba440f7c5839e6e97156bc1d9daf Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 11 Oct 2024 15:03:12 -0700
Subject: [PATCH] [DirectX] Add `WaveActiveOp` builtin

    - create int_dx_wave_active_op in IntrinsicsDirectX.td
    - add mapping to dxil op in DXIL.td
    - add scalarization to DirectXTargetTransformInfo.cpp
    - add tests of lowerings to dxil ops for both scalar and vector
      values
---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  1 +
 llvm/lib/Target/DirectX/DXIL.td               | 10 ++++
 .../DirectX/DirectXTargetTransformInfo.cpp    |  4 ++
 llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll | 34 ++++++++++++
 llvm/test/CodeGen/DirectX/WaveActiveOp.ll     | 53 +++++++++++++++++++
 5 files changed, 102 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll
 create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveOp.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f2b9e286ebb476..95c00e66a3d6fd 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -83,6 +83,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_wave_active_op : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i8_ty, llvm_i8_ty], [IntrConvergent, IntrNoMem, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>; // the two i8 operands are forwarded verbatim to the DXIL op, so require immediates
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9aa0af3e3a6b17..78c7055567dec1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -793,6 +793,16 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
   let stages = [Stages<DXIL1_6, [all_stages]>];
 }
 
+def WaveActiveOp : DXILOp<119, waveActiveOp> {
+  let Doc = "returns the result of the operation across all active lanes in the wave";
+  let LLVMIntrinsic = int_dx_wave_active_op;
+  let arguments = [OverloadTy, Int8Ty, Int8Ty]; // value, then the two i8 operands of int_dx_wave_active_op
+  let result = OverloadTy;
+  let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
   let Doc = "returns 1 for the first lane in the wave";
   let LLVMIntrinsic = int_dx_wave_is_first_lane;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..b0f54a0679de25 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -18,6 +18,9 @@ using namespace llvm;
 bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                         unsigned ScalarOpdIdx) {
   switch (ID) {
+  case Intrinsic::dx_wave_active_op:
+    // Only operand 0 (the value) is split per element; both i8 operands stay scalar.
+    return ScalarOpdIdx >= 1 && ScalarOpdIdx <= 2;
   default:
     return false;
   }
@@ -26,6 +29,7 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
+  case Intrinsic::dx_wave_active_op:
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
     return true;
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll
new file mode 100644
index 00000000000000..d5d1e615e99af1
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll
@@ -0,0 +1,34 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+; Test that for vector values, WaveActiveOp is scalarized and maps down to the DirectX op
+
+define noundef <2 x half> @wave_active_op_v2half(<2 x half> noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 0, i8 0)
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 0, i8 0)
+  %ret = call <2 x half> @llvm.dx.wave.active.op.v2f16(<2 x half> %expr, i8 0, i8 0)
+  ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_op_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 1, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 1, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 1, i8 1)
+  %ret = call <3 x i32> @llvm.dx.wave.active.op.v3i32(<3 x i32> %expr, i8 1, i8 1)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_op_v4f64(<4 x double> noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 2, i8 0)
+  %ret = call <4 x double> @llvm.dx.wave.active.op.v4f64(<4 x double> %expr, i8 2, i8 0)
+  ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.active.op.v2f16(<2 x half>, i8, i8)
+declare <3 x i32> @llvm.dx.wave.active.op.v3i32(<3 x i32>, i8, i8)
+declare <4 x double> @llvm.dx.wave.active.op.v4f64(<4 x double>, i8, i8)
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll
new file mode 100644
index 00000000000000..e6cafd696d25c6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll
@@ -0,0 +1,53 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+; Test that for scalar values, WaveActiveOp maps down to the DirectX op
+
+define noundef half @wave_active_op_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 0, i8 0)
+  %ret = call half @llvm.dx.wave.active.op.f16(half %expr, i8 0, i8 0)
+  ret half %ret
+}
+
+define noundef float @wave_active_op_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 1, i8 0)
+  %ret = call float @llvm.dx.wave.active.op.f32(float %expr, i8 1, i8 0)
+  ret float %ret
+}
+
+define noundef double @wave_active_op_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 2, i8 0)
+  %ret = call double @llvm.dx.wave.active.op.f64(double %expr, i8 2, i8 0)
+  ret double %ret
+}
+
+define noundef i16 @wave_active_op_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 1, i8 0)
+  %ret = call i16 @llvm.dx.wave.active.op.i16(i16 %expr, i8 1, i8 0)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_op_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 2, i8 1)
+  %ret = call i32 @llvm.dx.wave.active.op.i32(i32 %expr, i8 2, i8 1)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_op_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0)
+  %ret = call i64 @llvm.dx.wave.active.op.i64(i64 %expr, i8 3, i8 0)
+  ret i64 %ret
+}
+
+declare half @llvm.dx.wave.active.op.f16(half, i8, i8)
+declare float @llvm.dx.wave.active.op.f32(float, i8, i8)
+declare double @llvm.dx.wave.active.op.f64(double, i8, i8)
+
+declare i16 @llvm.dx.wave.active.op.i16(i16, i8, i8)
+declare i32 @llvm.dx.wave.active.op.i32(i32, i8, i8)
+declare i64 @llvm.dx.wave.active.op.i64(i64, i8, i8)



More information about the llvm-commits mailing list