[llvm] [AMDGPU] Filter candidates of LiveRegOptimizer for profitable cases (PR #124624)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 14 15:57:33 PST 2025


================
@@ -125,8 +131,210 @@ class LiveRegOptimizer {
     return LK.first != TargetLoweringBase::TypeLegal;
   }
 
-  LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST)
-      : Mod(Mod), DL(Mod.getDataLayout()), ST(ST),
+  /// Filter a candidate by its operation or its cost.
+  ///
+  /// If an operation incurs a high enough cost, or natively works on a vector
+  /// of illegal type (e.g. v2i8), it makes sense to try to coerce the value
+  /// into packed VGPRs across basic blocks.
+  ///
+  /// \param II instruction to evaluate.
+  /// \returns false for debug/pseudo instructions; otherwise true when the
+  /// estimated cost reaches LRO_COST_THRES or when isOpLegal(II) holds.
+  bool shouldReplaceByOp(Instruction *II) {
+    static constexpr int SCALARIZE_INST_COST = 2;
+    static constexpr int LRO_COST_THRES = 12;
+
+    // Ignore pseudos.
+    if (II->isDebugOrPseudoInst())
+      return false;
+
+    // Start from the target's size-and-latency estimate.
+    auto Cost = TTI.getInstructionCost(
+        II, TargetTransformInfo::TargetCostKind::TCK_SizeAndLatency);
+
+    // Only inspect operand 0 when it exists: instructions without operands
+    // (e.g. `ret void`) would otherwise index out of range.
+    const auto *FirstOp = II->getNumOperands() ? II->getOperand(0) : nullptr;
+    if (FirstOp) {
+      if (const auto *VecTy = dyn_cast<FixedVectorType>(FirstOp->getType())) {
+        const auto *ElTy = dyn_cast<IntegerType>(VecTy->getElementType());
+        // Assume vNi8 and vNi16 will be scalarized, and charge a fixed
+        // per-element penalty for that.
+        if (ElTy && ElTy->getBitWidth() <= 16) {
+          const auto ElCount = VecTy->getElementCount().getFixedValue();
+          Cost += SCALARIZE_INST_COST * ElCount;
+        }
+      }
+    }
+    LLVM_DEBUG(dbgs() << "shouldReplaceByOp: " << *II << " Cost=" << Cost
+                      << '\n');
+
+    // Expensive enough on its own, or an operation accepted by isOpLegal.
+    return Cost >= LRO_COST_THRES || isOpLegal(II);
+  }
+
+  /// Return true if \p ID is one of the AMDGPU dot-product, MFMA, SMFMAC, WMMA
+  /// or SWMMAC intrinsics enumerated below; any other ID yields false.
+  ///
+  /// The list is meant to capture intrinsics that natively operate on 8-bit or
+  /// 16-bit elements.
+  /// NOTE(review): several entries are named for wider element types (e.g.
+  /// amdgcn_mfma_f32_4x4x1f32, amdgcn_mfma_f64_16x16x4f64, the *_xf32 forms)
+  /// or narrower ones (the *_iu4 forms) -- confirm they belong here.
+  bool isNativeIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::amdgcn_dot4_f32_fp8_bf8:
+    case Intrinsic::amdgcn_dot4_f32_bf8_fp8:
+    case Intrinsic::amdgcn_dot4_f32_fp8_fp8:
+    case Intrinsic::amdgcn_dot4_f32_bf8_bf8:
+    case Intrinsic::amdgcn_fdot2_f16_f16:
+    case Intrinsic::amdgcn_fdot2:
+    case Intrinsic::amdgcn_sdot4:
+    case Intrinsic::amdgcn_sdot2:
+    case Intrinsic::amdgcn_sdot8:
+    case Intrinsic::amdgcn_udot2:
+    case Intrinsic::amdgcn_udot4:
+    case Intrinsic::amdgcn_udot8:
+    case Intrinsic::amdgcn_sudot4:
+    case Intrinsic::amdgcn_sudot8:
+    case Intrinsic::amdgcn_mfma_f32_4x4x1f32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x1f32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4f32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x1f32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x2f32:
+    case Intrinsic::amdgcn_mfma_f32_4x4x4f16:
+    case Intrinsic::amdgcn_mfma_i32_4x4x4i8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4f16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x16f16:
+    case Intrinsic::amdgcn_mfma_i32_16x16x4i8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4f16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x8f16:
+    case Intrinsic::amdgcn_mfma_i32_32x32x4i8:
+    case Intrinsic::amdgcn_mfma_i32_16x16x16i8:
+    case Intrinsic::amdgcn_mfma_i32_32x32x8i8:
+    case Intrinsic::amdgcn_mfma_f32_4x4x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x8bf16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x2bf16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4bf16:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_f16:
+    case Intrinsic::amdgcn_mfma_i32_16x16x64_i8:
+    case Intrinsic::amdgcn_mfma_i32_32x32x32_i8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_16x16x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_4x4x4bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_32x32x8bf16_1k:
+    case Intrinsic::amdgcn_mfma_f32_16x16x16bf16_1k:
+    case Intrinsic::amdgcn_mfma_f64_16x16x4f64:
+    case Intrinsic::amdgcn_mfma_f64_4x4x4f64:
+    case Intrinsic::amdgcn_mfma_i32_32x32x16_i8:
+    case Intrinsic::amdgcn_mfma_i32_16x16x32_i8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x8_xf32:
+    case Intrinsic::amdgcn_mfma_f32_32x32x4_xf32:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_fp8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_bf8:
+    case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x16_bf16:
+    case Intrinsic::amdgcn_smfmac_i32_16x16x64_i8:
+    case Intrinsic::amdgcn_smfmac_i32_32x32x32_i8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x64_f16:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x32_f16:
+    case Intrinsic::amdgcn_smfmac_i32_16x16x128_i8:
+    case Intrinsic::amdgcn_smfmac_i32_32x32x64_i8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_fp8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_bf8:
+    case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_fp8:
+    case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+    case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_f16:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf16:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8:
+    case Intrinsic::amdgcn_wmma_f16_16x16x16_f16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_f16:
+    case Intrinsic::amdgcn_swmmac_f16_16x16x32_f16:
+    case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16:
+    case Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8:
+    case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8:
+    case Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied:
+    case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied:
+    case Intrinsic::amdgcn_wmma_i32_16x16x16_iu8:
+    case Intrinsic::amdgcn_wmma_i32_16x16x16_iu4:
+    case Intrinsic::amdgcn_wmma_i32_16x16x32_iu4:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4:
+    case Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  /// Return true if \p I is acceptable as an operation on its own type.
+  ///
+  /// That is the case when TTI reports the instruction's type as legal, or,
+  /// for an illegal type, when the instruction is an intrinsic accepted by
+  /// isNativeIntrinsic or is a store. Everything else yields false.
+  bool isOpLegal(Instruction *I) {
+    // A legal type needs no further justification.
+    if (TTI.isTypeLegal(I->getType()))
+      return true;
+
+    // Illegal type: intrinsics on the native list are still fine.
+    if (const auto *Call = dyn_cast<IntrinsicInst>(I))
+      if (isNativeIntrinsic(Call->getIntrinsicID()))
+        return true;
+
+    // Stores
+    return isa<StoreInst>(I);
+  }
+
+  bool isCoercionProfitable(Instruction *II) {
----------------
jrbyrnes wrote:

Something like this: 

```
  /// Return true if some transitive user of \p II satisfies isOpLegal.
  ///
  /// Users are explored with a worklist. The search continues through PHIs,
  /// shuffles, insert/extract-element and casts, and stops at anything else.
  /// Returns false once the worklist is exhausted without a match.
  bool isCoercionProfitable(Instruction *II) {
    SmallPtrSet<Instruction *, 4> Seen;
    SmallVector<Instruction *, 4> Worklist;

    // Queue every user of Src that is itself an instruction.
    auto PushInstUsers = [&Worklist](Instruction *Src) {
      for (User *U : Src->users())
        if (auto *UserInst = dyn_cast<Instruction>(U))
          Worklist.push_back(UserInst);
    };

    // Instructions the search is allowed to walk through.
    auto IsLookThru = [](Instruction *Inst) {
      return isa<PHINode, ShuffleVectorInst, InsertElementInst,
                 ExtractElementInst, CastInst>(Inst);
    };

    // Check users for profitable conditions (across block user which can
    // natively handle the illegal vector).
    PushInstUsers(II);

    // Both facts concern II only, so they are the same on every iteration.
    // NOTE(review): the look-through test below is applied to II (the root),
    // not to the user being visited -- confirm that this is intended.
    const auto *RootBB = II->getParent();
    const bool RootIsLookThru = IsLookThru(II);

    while (!Worklist.empty()) {
      Instruction *Cur = Worklist.pop_back_val();
      if (!Seen.insert(Cur).second)
        continue;

      // Skip users living in II's own block unless II is a look-through op.
      if (Cur->getParent() == RootBB && !RootIsLookThru)
        continue;

      if (isOpLegal(Cur))
        return true;

      if (IsLookThru(Cur))
        PushInstUsers(Cur);
    }
    return false;
  }
```

https://github.com/llvm/llvm-project/pull/124624


More information about the llvm-commits mailing list