[llvm] [AMDGPU] Fix bf16 inv2pi inline constant handling (PR #82283)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 19 14:34:05 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mc

Author: Stanislav Mekhanoshin (rampitec)

<details>
<summary>Changes</summary>

The inline constant 1/(2*pi) has the truncated bf16 value 0x3e22; according to the spec, it is not rounded. A bf16 value is, in a nutshell, an fp32 value with the low 16 bits of the mantissa cleared. The value 0x3e22 converted to fp32 is 0.158203125, and the next representable value, 0x3e23, is 0.1591796875. The fp32 value of 1/(2*pi), 0.15915494, cannot be represented in bf16. However, since bf16 values are essentially truncated fp32 values, we can use 0.15915494 as an idiomatic representation of the 1/(2*pi) inline constant. This is also consistent with sp3 behaviour. The patch fixes the problem that the value we print for the inv2pi inline constant is not parsed back as inv2pi by the asm parser and gets rounded instead.
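The arithmetic above can be checked with a small standalone sketch. It is not part of the patch: the helper names (`fp32Bits`, `bf16ToFloat`, `truncToBF16`, `roundToBF16`) are hypothetical, and the rounding helper assumes round-to-nearest-even and ignores NaN handling.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: the raw 32-bit pattern of an fp32 value.
inline uint32_t fp32Bits(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return Bits;
}

// Hypothetical helper: widen a bf16 pattern to fp32 by placing it in the
// high 16 bits and leaving the low 16 mantissa bits cleared.
inline float bf16ToFloat(uint16_t B) {
  uint32_t Bits = static_cast<uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

// Truncation: keep the high 16 bits of the fp32 pattern.
inline uint16_t truncToBF16(float F) {
  return static_cast<uint16_t>(fp32Bits(F) >> 16);
}

// Round to nearest, ties to even (assumed rounding mode; NaNs not handled).
inline uint16_t roundToBF16(float F) {
  uint32_t Bits = fp32Bits(F);
  uint32_t Bias = 0x7fffu + ((Bits >> 16) & 1u);
  return static_cast<uint16_t>((Bits + Bias) >> 16);
}
```

With these helpers, `truncToBF16(0.15915494f)` yields 0x3e22 (the inline constant), while `roundToBF16(0.15915494f)` yields 0x3e23, which is the rounding the patch wants to avoid for this one spelling.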

---
Full diff: https://github.com/llvm/llvm-project/pull/82283.diff


2 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+18-7) 
- (modified) llvm/test/MC/AMDGPU/bf16_imm.s (+5-1) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 883b30562e911b..85bd33e4efbd0f 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -2230,6 +2230,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
       // in predicate methods (isLiteralImm())
       llvm_unreachable("fp literal in 64-bit integer instruction.");
 
+    case AMDGPU::OPERAND_REG_IMM_BF16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
+      if (AsmParser->hasInv2PiInlineImm() && Literal == 0x3fc45f306725feed) {
+        // This is the 1/(2*pi) which is going to be truncated to bf16 with the
+        // loss of precision. The constant represents ideomatic fp32 value of
+        // 1/(2*pi) = 0.15915494 since bf16 is in fact fp32 with cleared low 16
+        // bits. Prevent rounding below.
+        Inst.addOperand(MCOperand::createImm(0x3e22));
+        setImmKindLiteral();
+        return;
+      }
+      [[fallthrough]];
+
     case AMDGPU::OPERAND_REG_IMM_INT32:
     case AMDGPU::OPERAND_REG_IMM_FP32:
     case AMDGPU::OPERAND_REG_IMM_FP32_DEFERRED:
@@ -2238,24 +2256,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
     case AMDGPU::OPERAND_REG_IMM_INT16:
-    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
-    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
     case AMDGPU::OPERAND_REG_INLINE_C_INT16:
-    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
-    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
-    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
-    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
-    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
     case AMDGPU::OPERAND_REG_IMM_V2FP32:
diff --git a/llvm/test/MC/AMDGPU/bf16_imm.s b/llvm/test/MC/AMDGPU/bf16_imm.s
index c27cbcc4f294a2..f04e347e5afb3b 100644
--- a/llvm/test/MC/AMDGPU/bf16_imm.s
+++ b/llvm/test/MC/AMDGPU/bf16_imm.s
@@ -34,10 +34,14 @@ v_dot2_bf16_bf16 v2, v0, 4.0, v2
 v_dot2_bf16_bf16 v2, v0, -4.0, v2
 // CHECK: v_dot2_bf16_bf16 v2, v0, -4.0, v2       ; encoding: [0x02,0x00,0x67,0xd6,0x00,0xef,0x09,0x04]
 
-// FIXME: pi/2 rounded value is incorrect in the inst printer.
+// Check 1/(2*pi) rounded value and ideomatic fp32 0.15915494 value
+// which cannot be accurately represented in bf16.
 
 v_dot2_bf16_bf16 v2, v0, 0.158203125, v2
 // CHECK: v_dot2_bf16_bf16 v2, v0, 0.15915494, v2 ; encoding: [0x02,0x00,0x67,0xd6,0x00,0xf1,0x09,0x04]
 
+v_dot2_bf16_bf16 v2, v0, v2, 0.15915494
+// CHECK: v_dot2_bf16_bf16 v2, v0, v2, 0.15915494 ; encoding: [0x02,0x00,0x67,0xd6,0x00,0x05,0xe2,0x03]
+
 v_dot2_bf16_bf16 v2, v0, 0x3e22, v2
 // CHECK: v_dot2_bf16_bf16 v2, v0, 0.15915494, v2 ; encoding: [0x02,0x00,0x67,0xd6,0x00,0xf1,0x09,0x04]

``````````
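
As a side note on the magic number in the parser change: 0x3fc45f306725feed appears to be the bit pattern of the decimal literal 0.15915494 when it is read as a double, which is how the comparison recognizes that exact spelling. A minimal sketch of that check, outside of LLVM, follows; `doubleBits` and `isInv2PiBF16Spelling` are hypothetical names used only for illustration, and the sketch assumes the host compiler rounds decimal literals correctly to IEEE-754 double.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: the raw 64-bit pattern of a double.
inline uint64_t doubleBits(double D) {
  uint64_t Bits;
  std::memcpy(&Bits, &D, sizeof(Bits));
  return Bits;
}

// The constant compared against in the diff above.
constexpr uint64_t Inv2PiSpelling = 0x3fc45f306725feedULL;

// Hypothetical stand-in for the special case: true only when the parsed
// literal is exactly the double nearest to decimal 0.15915494.
inline bool isInv2PiBF16Spelling(uint64_t Literal) {
  return Literal == Inv2PiSpelling;
}
```

Under that assumption, `isInv2PiBF16Spelling(doubleBits(0.15915494))` is true, while other spellings such as 0.158203125 do not match and take the regular conversion path.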

</details>


https://github.com/llvm/llvm-project/pull/82283

