[llvm] bd6eb54 - [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (#83859)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 04:09:10 PST 2024


Author: Paul Walker
Date: 2024-03-08T12:09:05Z
New Revision: bd6eb54886ad12fb523e924e7291abfa2b010e3c

URL: https://github.com/llvm/llvm-project/commit/bd6eb54886ad12fb523e924e7291abfa2b010e3c
DIFF: https://github.com/llvm/llvm-project/commit/bd6eb54886ad12fb523e924e7291abfa2b010e3c.diff

LOG: [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (#83859)

When a vector library is available, this removes a failure case for
scalable vectors. Doing so means we can confidently cost vector FREM
instructions without assuming that later passes will transform the IR
before it reaches the code generator.

NOTE: Whilst only FREM has been implemented, the same mechanism can be
used for the other libm-related ISD nodes.
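
As a sketch of that note: a later libm node would be wired up exactly like
the FREM case in this patch. For example, a hypothetical ISD::FSIN expansion
(not part of this commit) could reuse the existing RTLIB sine entries:

    case ISD::FSIN:
      if (tryExpandVecMathCall(Node, RTLIB::SIN_F32, RTLIB::SIN_F64,
                               RTLIB::SIN_F80, RTLIB::SIN_F128,
                               RTLIB::SIN_PPCF128, Results))
        return;

      break;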

Added: 
    llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/TargetPassConfig.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144ff..57a3f6a65e002c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into vector math routine calls. Callers provide the
+// LibFunc equivalent of the passed-in Node, which is used to look up mappings
+// within TargetLibraryInfo. The only mappings considered are those where the
+// result and all operands are the same vector type. While predicated nodes are
+// not supported, we will emit calls to masked routines by passing in an
+// all-true mask.
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // The chain must be propagated, but currently strict fp operations are
+  // lowered to their non-strict counterparts.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants, but we will generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function.
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
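
For context, here is a minimal standalone sketch of the TargetLibraryInfo
query that tryExpandVecMathCall performs, assuming a TargetLibraryInfo
configured with a vector library such as ArmPL; the helper name
lookupVectorFmod is illustrative and not part of the patch:

    #include "llvm/Analysis/TargetLibraryInfo.h"

    using namespace llvm;

    // Resolve the scalar libcall name "fmod" to a vector variant for the
    // requested element count, preferring an unmasked mapping (as the
    // patch does) before falling back to a masked one.
    const VecDesc *lookupVectorFmod(const TargetLibraryInfo &TLibInfo,
                                    ElementCount VL) {
      if (const VecDesc *VD =
              TLibInfo.getVectorMappingInfo("fmod", VL, /*Masked=*/false))
        return VD;
      return TLibInfo.getVectorMappingInfo("fmod", VL, /*Masked=*/true);
    }

With --vector-library=ArmPL this yields mappings such as armpl_vfmodq_f64
for <2 x double> and the masked armpl_svfmod_f64_x for
<vscale x 2 x double>, as exercised by the new test below.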

diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4cab..8832b51333d910 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden,
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
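
The new hidden flag keeps the IR-level ReplaceWithVeclib pass out of the
codegen pipeline so that the SelectionDAG expansion above is what actually
gets exercised. The RUN lines in the new test use it like so (verbatim,
minus the FileCheck pipe):

    llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o -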

diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 00000000000000..67c056c780cc80
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f64
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <2 x double> %a, %b
+  ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f32
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <4 x float> %a, %b
+  ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.s
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f32_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.s
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.d
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f64_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.d
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }


        

