[llvm] [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (PR #83859)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 06:30:15 PST 2024


================
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into vector math routine calls. Callers provide the
+// LibFunc equivalent of the passed in Node, which is used to look up mappings
+// within TargetLibraryInfo. The only mappings considered are those where the
+// result and all operands are the same vector type. While predicated nodes are
+// not supported, we will emit calls to masked routines by passing in an all
+// true mask.
+//
+// Returns true and appends the call's result value to Results when a suitable
+// vector routine was found and a call to it was emitted. Returns false without
+// touching Results when LC has no libcall name, no vector mapping exists for
+// this element count, the mapping's VFABI string cannot be demangled, or the
+// demangled signature has a parameter shape this expansion does not handle.
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // Chain must be propagated but currently strict fp operations are down
+  // converted to their non-strict counterpart, so only chain-less nodes are
+  // expected here.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  // The scalar libcall name is the key used to look up vector mappings.
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants (the first query passes Masked=false) but we will
+  // generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands. Every operand
+  // must share the result's vector type, so the scalar return type and each
+  // scalar parameter type are all the element type.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  // Non-variadic scalar signature.
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function. Demangling the VFABI
+  // variant string against the scalar signature yields the kind of each
+  // parameter of the vector routine (e.g. vector, global predicate).
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters: expect
+  // exactly one parameter per operand, plus one for the mask when the mapping
+  // is masked.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  // A single entry is reused for every argument; only Node and Ty change.
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  // OpNum walks Node's operands in order; predicate parameters do not consume
+  // an operand.
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      // Supply an all-true mask using the target's setcc result type for VT.
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  // NOTE(review): getExternalSymbol takes a C string, so this relies on the
+  // StringRef returned by getVectorFnName() being null-terminated — confirm.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  // Only the call's result value is used. Its output chain (CallResult.second)
+  // is dropped, presumably because the replaced node carries no chain (see the
+  // strict-FP assert above).
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
----------------
paulwalker-arm wrote:

I'm expecting this function to be used by other ISD nodes (e.g. `ISD::FSIN`, `ISD::FCOS`, etc.), so I followed the idiom used by `ExpandFPLibCall`.

https://github.com/llvm/llvm-project/pull/83859


More information about the llvm-commits mailing list