[llvm] [NVPTX] Customize getScalarizationOverhead (PR #128077)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 12:17:52 PST 2025


peterbell10 wrote:

> Fair enough. It's still not free, as it needs prmt/xmad. I think that cost adjustment here is a wash.

The current heuristic assigns a cost of 2 to building a `<2 x half>`, and on SM90 the test case I added is not currently vectorized, so this is a meaningful change.

> Perhaps I'm missing something? Can you compile the test cases you've added to v2f16.ll to PTX, and see what we get for both variants in both PTX and SASS?

Here are the results for `sm_90`:
<details>
<summary>PTX for `sm_90`</summary>

//
// Generated by LLVM NVPTX Back-End
//

.version 7.8
.target sm_90
.address_size 64

	// .globl	fusion                  // -- Begin function fusion
                                        // @fusion
// fusion: non-kernel device function (.func) with no return value.
// Performs one packed-half (f16x2) fused multiply-add:
//   off = zext((param_2 << 8) | (param_3 << 2)) * 2      (byte offset)
//   src = param_1 + off,  dst = param_0 + off
//   *(b32*)dst = fma.rn.f16x2(*(b32*)src, {60.0, 60.0}, {127.0, 127.0})
// Both memory accesses use generic addressing (ld/st with no state space).
// NOTE(review): the "* 2" scale suggests off counts 2-byte (f16) elements —
// confirm against the source IR.
.visible .func fusion(
	.param .b64 fusion_param_0,
	.param .b64 fusion_param_1,
	.param .b32 fusion_param_2,
	.param .b32 fusion_param_3
)
{
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<6>;

// %bb.0:
	// %rd1 = destination base, %rd2 = source base, %r1/%r2 = index terms.
	ld.param.u64 	%rd1, [fusion_param_0];
	ld.param.u64 	%rd2, [fusion_param_1];
	ld.param.u32 	%r1, [fusion_param_2];
	ld.param.u32 	%r2, [fusion_param_3];
	// %r5 = (param_2 << 8) | (param_3 << 2)
	shl.b32 	%r3, %r2, 2;
	shl.b32 	%r4, %r1, 8;
	or.b32  	%r5, %r4, %r3;
	// Widen to 64 bits and scale by 2 to get a byte offset shared by src/dst.
	mul.wide.u32 	%rd3, %r5, 2;
	add.s64 	%rd4, %rd2, %rd3;
	add.s64 	%rd5, %rd1, %rd3;
	// One 32-bit load fetches both f16 lanes at once.
	ld.b32 	%r6, [%rd4];
	// 0x57F057F0: f16x2 {127.0, 127.0} (addend).
	mov.b32 	%r7, 1475368944;
	// 0x53805380: f16x2 {60.0, 60.0} (multiplier).
	mov.b32 	%r8, 1400918912;
	// Per lane: %r9 = %r6 * 60.0 + 127.0, single rounding to nearest-even.
	fma.rn.f16x2 	%r9, %r6, %r8, %r7;
	st.b32 	[%rd5], %r9;
	ret;
                                        // -- End function
}
	// .globl	add_f16                 // -- Begin function add_f16
// add_f16: kernel entry point (.entry).
// Adds two by-value parameters lane-wise as f16x2 and stores the packed
// 32-bit result to global memory.
//   param_0: 64-bit pointer into .global memory (declared align 1).
//   param_1, param_2: 4-byte parameters (align 2), each read as two b16
//     halves at byte offsets +0 (low lane) and +2 (high lane).
// Store address: param_0 + ((tid.x << 1) & 62) * 2 bytes, i.e. an even
// element index in [0, 62] that wraps every 32 threads.
.visible .entry add_f16(
	.param .u64 .ptr .global .align 1 add_f16_param_0,
	.param .align 2 .b8 add_f16_param_1[4],
	.param .align 2 .b8 add_f16_param_2[4]
)                                       // @add_f16
{
	.reg .b16 	%rs<5>;
	.reg .b32 	%r<7>;
	.reg .b64 	%rd<4>;

// %bb.0:
	ld.param.u64 	%rd1, [add_f16_param_0];
	// Each vector parameter is fetched as two separate 16-bit halves.
	ld.param.b16 	%rs1, [add_f16_param_1+2];
	ld.param.b16 	%rs2, [add_f16_param_1];
	ld.param.b16 	%rs3, [add_f16_param_2+2];
	ld.param.b16 	%rs4, [add_f16_param_2];
	// Repack into 32-bit registers: first vector element = low 16 bits,
	// so the +0 half becomes lane 0 and the +2 half becomes lane 1.
	mov.b32 	%r1, {%rs2, %rs1};
	mov.b32 	%r2, {%rs4, %rs3};
	// Lane-wise f16 add, round to nearest-even.
	add.rn.f16x2 	%r3, %r1, %r2;
	// Byte offset = ((tid.x << 1) & 62) * 2.
	mov.u32 	%r4, %tid.x;
	shl.b32 	%r5, %r4, 1;
	and.b32  	%r6, %r5, 62;
	mul.wide.u32 	%rd2, %r6, 2;
	add.s64 	%rd3, %rd1, %rd2;
	st.global.b32 	[%rd3], %r3;
	ret;
                                        // -- End function
}

</details>

<details>
<summary>SASS for sm_90</summary>

	code for sm_90
		Function : add_f16
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   LDC R1, c[0x0][0x28] ;                  /* 0x00000a00ff017b82 */
                                                                           /* 0x000fe20000000800 */
        /*0010*/                   S2R R0, SR_TID.X ;                      /* 0x0000000000007919 */
                                                                           /* 0x000e6e0000002100 */
        /*0020*/                   LDC.64 R4, c[0x0][0x218] ;              /* 0x00008600ff047b82 */
                                                                           /* 0x000ea20000000a00 */
        /*0030*/                   ULDC.64 UR4, c[0x0][0x208] ;            /* 0x0000820000047ab9 */
                                                                           /* 0x000fce0000000a00 */
        /*0040*/                   LDC.64 R2, c[0x0][0x210] ;              /* 0x00008400ff027b82 */
                                                                           /* 0x000ee20000000a00 */
        /*0050*/                   HFMA2.MMA R5, R4, 1, 1, R5 ;            /* 0x3c003c0004057835 */
                                                                           /* 0x004fe20000000005 */
        /*0060*/                   SHF.L.U32 R0, R0, 0x1, RZ ;             /* 0x0000000100007819 */
                                                                           /* 0x002fc800000006ff */
        /*0070*/                   LOP3.LUT R7, R0, 0x3e, RZ, 0xc0, !PT ;  /* 0x0000003e00077812 */
                                                                           /* 0x000fca00078ec0ff */
        /*0080*/                   IMAD.WIDE.U32 R2, R7, 0x2, R2 ;         /* 0x0000000207027825 */
                                                                           /* 0x008fca00078e0002 */
        /*0090*/                   STG.E desc[UR4][R2.64], R5 ;            /* 0x0000000502007986 */
                                                                           /* 0x000fe2000c101904 */
        /*00a0*/                   EXIT ;                                  /* 0x000000000000794d */
                                                                           /* 0x000fea0003800000 */
        /*00b0*/                   BRA 0xb0;                               /* 0xfffffffc00fc7947 */
                                                                           /* 0x000fc0000383ffff */
        /*00c0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00d0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00e0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00f0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0100*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0110*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0120*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0130*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0140*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0150*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0160*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0170*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
		..........


		Function : fusion
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   SHF.L.U32 R9, R9, 0x2, RZ ;             /* 0x0000000209097819 */
                                                                           /* 0x000fe200000006ff */
        /*0010*/                   ULDC.64 UR4, c[0x0][0x208] ;            /* 0x0000820000047ab9 */
                                                                           /* 0x000fe20000000a00 */
        /*0020*/                   SHF.L.U32 R8, R8, 0x8, RZ ;             /* 0x0000000808087819 */
                                                                           /* 0x000fc800000006ff */
        /*0030*/                   LOP3.LUT R9, R8, R9, RZ, 0xfc, !PT ;    /* 0x0000000908097212 */
                                                                           /* 0x000fca00078efcff */
        /*0040*/                   IMAD.WIDE.U32 R6, R9, 0x2, R6 ;         /* 0x0000000209067825 */
                                                                           /* 0x000fcc00078e0006 */
        /*0050*/                   LD.E R6, desc[UR4][R6.64] ;             /* 0x0000000406067980 */
                                                                           /* 0x000ea2000c101900 */
        /*0060*/                   HFMA2.MMA R3, -RZ, RZ, 0, 60 ;          /* 0x00005380ff037435 */
                                                                           /* 0x000fe200000001ff */
        /*0070*/                   IMAD.WIDE.U32 R4, R9, 0x2, R4 ;         /* 0x0000000209047825 */
                                                                           /* 0x000fd200078e0004 */
        /*0080*/                   HFMA2 R3, R6, R3.H0_H0, 127, 127 ;      /* 0x57f057f006037431 */
                                                                           /* 0x004fca0000040003 */
        /*0090*/                   ST.E desc[UR4][R4.64], R3 ;             /* 0x0000000304007985 */
                                                                           /* 0x0001e4000c101904 */
        /*00a0*/                   RET.ABS.NODEC R20 0x0 ;                 /* 0x0000000014007950 */
                                                                           /* 0x001fea0003e00000 */
        /*00b0*/                   BRA 0xb0;                               /* 0xfffffffc00fc7947 */
                                                                           /* 0x000fc0000383ffff */
        /*00c0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00d0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00e0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00f0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0100*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0110*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0120*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0130*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0140*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0150*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0160*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0170*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */

</details>


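The SASS is a bit confusing, but I think `LDC.64 R4, c[0x0][0x218]` loads both 4-byte `<2 x half>` parameters (at 0x218 and 0x21c) into `R4` and `R5` in a single instruction. Each register already holds a packed f16x2 value, so the `mov.b32` repacking in the PTX disappears entirely. The add itself is then emitted as `HFMA2.MMA R5, R4, 1, 1, R5` (i.e. `R4 * 1.0 + R5`), which is quite cool.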

https://github.com/llvm/llvm-project/pull/128077


More information about the llvm-commits mailing list