[llvm] [NVPTX] Customize getScalarizationOverhead (PR #128077)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 21 12:17:52 PST 2025
peterbell10 wrote:
> Fair enough. It's still not free as it needs prmt/xmad. I think that cost adjustment here is a wash.
The current heuristic results in a cost of 2 to build `<2 x half>` and on SM90 the test case I added is not currently vectorized, so this is a meaningful change.
> Perhaps I'm missing something? Can you compile the test cases you've added to v2f16.ll to PTX, and see what we get for both variants in both PTX and SASS?
Here are the results for `sm_90`:
<details>
<summary>PTX for `sm_90`</summary>
//
// Generated by LLVM NVPTX Back-End
//
.version 7.8
.target sm_90
.address_size 64
// .globl fusion // -- Begin function fusion
// @fusion
.visible .func fusion(
.param .b64 fusion_param_0,
.param .b64 fusion_param_1,
.param .b32 fusion_param_2,
.param .b32 fusion_param_3
)
{
.reg .b32 %r<10>;
.reg .b64 %rd<6>;
// %bb.0:
ld.param.u64 %rd1, [fusion_param_0];
ld.param.u64 %rd2, [fusion_param_1];
ld.param.u32 %r1, [fusion_param_2];
ld.param.u32 %r2, [fusion_param_3];
shl.b32 %r3, %r2, 2;
shl.b32 %r4, %r1, 8;
or.b32 %r5, %r4, %r3;
mul.wide.u32 %rd3, %r5, 2;
add.s64 %rd4, %rd2, %rd3;
add.s64 %rd5, %rd1, %rd3;
ld.b32 %r6, [%rd4];
mov.b32 %r7, 1475368944;
mov.b32 %r8, 1400918912;
fma.rn.f16x2 %r9, %r6, %r8, %r7;
st.b32 [%rd5], %r9;
ret;
// -- End function
}
// .globl add_f16 // -- Begin function add_f16
.visible .entry add_f16(
.param .u64 .ptr .global .align 1 add_f16_param_0,
.param .align 2 .b8 add_f16_param_1[4],
.param .align 2 .b8 add_f16_param_2[4]
) // @add_f16
{
.reg .b16 %rs<5>;
.reg .b32 %r<7>;
.reg .b64 %rd<4>;
// %bb.0:
ld.param.u64 %rd1, [add_f16_param_0];
ld.param.b16 %rs1, [add_f16_param_1+2];
ld.param.b16 %rs2, [add_f16_param_1];
ld.param.b16 %rs3, [add_f16_param_2+2];
ld.param.b16 %rs4, [add_f16_param_2];
mov.b32 %r1, {%rs2, %rs1};
mov.b32 %r2, {%rs4, %rs3};
add.rn.f16x2 %r3, %r1, %r2;
mov.u32 %r4, %tid.x;
shl.b32 %r5, %r4, 1;
and.b32 %r6, %r5, 62;
mul.wide.u32 %rd2, %r6, 2;
add.s64 %rd3, %rd1, %rd2;
st.global.b32 [%rd3], %r3;
ret;
// -- End function
}
</details>
<details>
<summary>SASS for sm_90</summary>
code for sm_90
Function : add_f16
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
/*0000*/ LDC R1, c[0x0][0x28] ; /* 0x00000a00ff017b82 */
/* 0x000fe20000000800 */
/*0010*/ S2R R0, SR_TID.X ; /* 0x0000000000007919 */
/* 0x000e6e0000002100 */
/*0020*/ LDC.64 R4, c[0x0][0x218] ; /* 0x00008600ff047b82 */
/* 0x000ea20000000a00 */
/*0030*/ ULDC.64 UR4, c[0x0][0x208] ; /* 0x0000820000047ab9 */
/* 0x000fce0000000a00 */
/*0040*/ LDC.64 R2, c[0x0][0x210] ; /* 0x00008400ff027b82 */
/* 0x000ee20000000a00 */
/*0050*/ HFMA2.MMA R5, R4, 1, 1, R5 ; /* 0x3c003c0004057835 */
/* 0x004fe20000000005 */
/*0060*/ SHF.L.U32 R0, R0, 0x1, RZ ; /* 0x0000000100007819 */
/* 0x002fc800000006ff */
/*0070*/ LOP3.LUT R7, R0, 0x3e, RZ, 0xc0, !PT ; /* 0x0000003e00077812 */
/* 0x000fca00078ec0ff */
/*0080*/ IMAD.WIDE.U32 R2, R7, 0x2, R2 ; /* 0x0000000207027825 */
/* 0x008fca00078e0002 */
/*0090*/ STG.E desc[UR4][R2.64], R5 ; /* 0x0000000502007986 */
/* 0x000fe2000c101904 */
/*00a0*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
/*00b0*/ BRA 0xb0; /* 0xfffffffc00fc7947 */
/* 0x000fc0000383ffff */
/*00c0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00d0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00e0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00f0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0100*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0110*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0120*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0130*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0140*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0150*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0160*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0170*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
..........
Function : fusion
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
/*0000*/ SHF.L.U32 R9, R9, 0x2, RZ ; /* 0x0000000209097819 */
/* 0x000fe200000006ff */
/*0010*/ ULDC.64 UR4, c[0x0][0x208] ; /* 0x0000820000047ab9 */
/* 0x000fe20000000a00 */
/*0020*/ SHF.L.U32 R8, R8, 0x8, RZ ; /* 0x0000000808087819 */
/* 0x000fc800000006ff */
/*0030*/ LOP3.LUT R9, R8, R9, RZ, 0xfc, !PT ; /* 0x0000000908097212 */
/* 0x000fca00078efcff */
/*0040*/ IMAD.WIDE.U32 R6, R9, 0x2, R6 ; /* 0x0000000209067825 */
/* 0x000fcc00078e0006 */
/*0050*/ LD.E R6, desc[UR4][R6.64] ; /* 0x0000000406067980 */
/* 0x000ea2000c101900 */
/*0060*/ HFMA2.MMA R3, -RZ, RZ, 0, 60 ; /* 0x00005380ff037435 */
/* 0x000fe200000001ff */
/*0070*/ IMAD.WIDE.U32 R4, R9, 0x2, R4 ; /* 0x0000000209047825 */
/* 0x000fd200078e0004 */
/*0080*/ HFMA2 R3, R6, R3.H0_H0, 127, 127 ; /* 0x57f057f006037431 */
/* 0x004fca0000040003 */
/*0090*/ ST.E desc[UR4][R4.64], R3 ; /* 0x0000000304007985 */
/* 0x0001e4000c101904 */
/*00a0*/ RET.ABS.NODEC R20 0x0 ; /* 0x0000000014007950 */
/* 0x001fea0003e00000 */
/*00b0*/ BRA 0xb0; /* 0xfffffffc00fc7947 */
/* 0x000fc0000383ffff */
/*00c0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00d0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00e0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00f0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0100*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0110*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0120*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0130*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0140*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0150*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0160*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*0170*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
</details>
The SASS is a bit confusing, but I think `LDC.64 R4, c[0x0][0x218]` is loading `R4` and `R5` directly as f16x2 vector registers in a single instruction, which is quite cool.
https://github.com/llvm/llvm-project/pull/128077
More information about the llvm-commits
mailing list