[llvm] [DirectX] Add `WaveActiveOp` builtin (PR #112058)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 15:46:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-llvm-ir

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

    - create int_dx_wave_active_op in IntrinsicsDirectX.td
    - add mapping to dxil op in DXIL.td
    - add scalarization to DirectXTargetTransformInfo.cpp
    - add tests of lowerings to dxil ops for both scalar and vector values

This is a prerequisite for, and part 1 of, implementing the DirectX lowering of #<!-- -->99170 and #<!-- -->70106.

---
Full diff: https://github.com/llvm/llvm-project/pull/112058.diff


5 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+10) 
- (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+4) 
- (added) llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll (+34) 
- (added) llvm/test/CodeGen/DirectX/WaveActiveOp.ll (+53) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 45aea1ccdb6d4c..fa865718bc5528 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -84,6 +84,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_wave_active_op : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i8_ty, llvm_i8_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index e8f56b18730d71..df43cae5edaed5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -793,6 +793,16 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
   let stages = [Stages<DXIL1_6, [all_stages]>];
 }
 
+def WaveActiveOp : DXILOp<119, waveActiveOp> {
+  let Doc = "returns the result of the operation across waves";
+  let LLVMIntrinsic = int_dx_wave_active_op;
+  let arguments = [OverloadTy, Int8Ty, Int8Ty];
+  let result = OverloadTy;
+  let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
   let Doc = "returns 1 for the first lane in the wave";
   let LLVMIntrinsic = int_dx_wave_is_first_lane;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..b0f54a0679de25 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -18,6 +18,9 @@ using namespace llvm;
 bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                         unsigned ScalarOpdIdx) {
   switch (ID) {
+  case Intrinsic::dx_wave_active_op: {
+    return ScalarOpdIdx == 1 || ScalarOpdIdx == 2;
+  }
   default:
     return false;
   }
@@ -26,6 +29,7 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
+  case Intrinsic::dx_wave_active_op:
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
     return true;
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll
new file mode 100644
index 00000000000000..d5d1e615e99af1
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll
@@ -0,0 +1,34 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+; Test that for vector values, WaveActiveOp is scalarized and each element maps down to the DirectX op
+
+define noundef <2 x half> @wave_active_op_v2half(<2 x half> noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 0, i8 0)
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 0, i8 0)
+  %ret = call <2 x half> @llvm.dx.wave.active.op.f16(<2 x half> %expr, i8 0, i8 0)
+  ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_op_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 1, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 1, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 1, i8 1)
+  %ret = call <3 x i32> @llvm.dx.wave.active.op(<3 x i32> %expr, i8 1, i8 1)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_op_v4f64(<4 x double> noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 2, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 2, i8 0)
+  %ret = call <4 x double> @llvm.dx.wave.active.op(<4 x double> %expr, i8 2, i8 0)
+  ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.active.op.v2f16(<2 x half>, i8, i8)
+declare <3 x i32> @llvm.dx.wave.active.op.v3i32(<3 x i32>, i8, i8)
+declare <4 x double> @llvm.dx.wave.active.op.v4f64(<4 x double>, i8, i8)
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll
new file mode 100644
index 00000000000000..e6cafd696d25c6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll
@@ -0,0 +1,53 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+; Test that for scalar values, WaveActiveOp maps down to the DirectX op
+
+define noundef half @wave_active_op_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 0, i8 0)
+  %ret = call half @llvm.dx.wave.active.op.f16(half %expr, i8 0, i8 0)
+  ret half %ret
+}
+
+define noundef float @wave_active_op_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 1, i8 0)
+  %ret = call float @llvm.dx.wave.active.op(float %expr, i8 1, i8 0)
+  ret float %ret
+}
+
+define noundef double @wave_active_op_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 2, i8 0)
+  %ret = call double @llvm.dx.wave.active.op(double %expr, i8 2, i8 0)
+  ret double %ret
+}
+
+define noundef i16 @wave_active_op_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 1, i8 0)
+  %ret = call i16 @llvm.dx.wave.active.op.i16(i16 %expr, i8 1, i8 0)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_op_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 2, i8 1)
+  %ret = call i32 @llvm.dx.wave.active.op.i32(i32 %expr, i8 2, i8 1)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_op_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0)
+  %ret = call i64 @llvm.dx.wave.active.op.i64(i64 %expr, i8 3, i8 0)
+  ret i64 %ret
+}
+
+declare half @llvm.dx.wave.active.op.f16(half, i8, i8)
+declare float @llvm.dx.wave.active.op.f32(float, i8, i8)
+declare double @llvm.dx.wave.active.op.f64(double, i8, i8)
+
+declare i16 @llvm.dx.wave.active.op.i16(i16, i8, i8)
+declare i32 @llvm.dx.wave.active.op.i32(i32, i8, i8)
+declare i64 @llvm.dx.wave.active.op.i64(i64, i8, i8)

``````````

</details>


https://github.com/llvm/llvm-project/pull/112058


More information about the llvm-commits mailing list