[llvm] AMDGPU: Optimize set_rounding if input is known to fit in 2 bits (PR #88588)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 07:00:19 PDT 2024


================
@@ -4073,34 +4073,55 @@ SDValue SITargetLowering::lowerSET_ROUNDING(SDValue Op,
       NewMode = DAG.getConstant(
           AMDGPU::decodeFltRoundToHWConversionTable(ClampedVal), SL, MVT::i32);
   } else {
-    SDValue BitTable =
-      DAG.getConstant(AMDGPU::FltRoundToHWConversionTable, SL, MVT::i64);
-
+    // If we know the input can only be one of the supported standard modes in
+    // the range 0-3, we can use a simplified mapping to hardware values.
+    KnownBits KB = DAG.computeKnownBits(NewMode);
+    const bool UseReducedTable = KB.countMinLeadingZeros() >= 30;
     // The supported standard values are 0-3. The extended values start at 8. We
     // need to offset by 4 if the value is in the extended range.
 
-    // is_standard = value < 4;
-    // table_index = is_standard ? value : (value - 4)
-    // MODE.fp_round = (bit_table >> table_index) & 0xf
-
-    SDValue Four = DAG.getConstant(4, SL, MVT::i32);
-    SDValue IsStandardValue =
-      DAG.getSetCC(SL, MVT::i1, NewMode, Four, ISD::SETULT);
-    SDValue OffsetEnum = DAG.getNode(ISD::SUB, SL, MVT::i32, NewMode, Four);
-    SDValue IndexVal = DAG.getNode(ISD::SELECT, SL, MVT::i32, IsStandardValue,
-                                   NewMode, OffsetEnum);
+    if (UseReducedTable) {
+      // Truncate to the low 32-bits.
+      SDValue BitTable = DAG.getConstant(
+        AMDGPU::FltRoundToHWConversionTable & 0xffff, SL, MVT::i32);
 
-    SDValue Two = DAG.getConstant(2, SL, MVT::i32);
-    SDValue RoundModeTimesNumBits =
-      DAG.getNode(ISD::SHL, SL, MVT::i32, IndexVal, Two);
+      SDValue Two = DAG.getConstant(2, SL, MVT::i32);
+      SDValue RoundModeTimesNumBits =
+        DAG.getNode(ISD::SHL, SL, MVT::i32, NewMode, Two);
 
-    SDValue TableValue =
-      DAG.getNode(ISD::SRL, SL, MVT::i64, BitTable, RoundModeTimesNumBits);
-    SDValue TruncTable = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, TableValue);
+      SDValue TableValue =
+        DAG.getNode(ISD::SRL, SL, MVT::i32, BitTable, RoundModeTimesNumBits);
+      NewMode = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, TableValue);
----------------
jayfoad wrote:

Don't need this ISD::TRUNCATE? TableValue is already i32, so truncating it to i32 is a no-op.

https://github.com/llvm/llvm-project/pull/88588


More information about the llvm-commits mailing list