[llvm] [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (PR #83859)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 04:59:32 PST 2024


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/83859

>From 11cbbc636fa3c574e64b9ebeb8e5ff52f583c1f0 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 1 Mar 2024 18:49:53 +0000
Subject: [PATCH 1/2] [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to
 a vector math call.

When a vector library is available, this removes a failure case for
scalable vectors: the generic fallback expands a vector FREM by
unrolling it into scalar operations, which is not possible for scalable
vectors. This means we can confidently cost vector FREM instructions
without assuming that later passes will transform the IR before it
reaches the code generator.

NOTE: Currently only FREM is implemented, but the same mechanism can be
used for the other libm-related ISD nodes.
---
 .../SelectionDAG/LegalizeVectorOps.cpp        | 127 ++++++++++++++++++
 llvm/lib/CodeGen/TargetPassConfig.cpp         |   6 +-
 .../CodeGen/AArch64/fp-veclib-expansion.ll    | 116 ++++++++++++++++
 3 files changed, 248 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144ff..ebd6f62a63ac4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,116 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into a call to a vector math. Callers provide the
+// LibFunc equivalent of the passed in Node, which is used to lookup mappings
+// within TargetLibraryInfo. Only simply mappings are considered whereby only
+// matching vector operands are allowed and masked functions are passed an all
+// true vector (i.e. Node cannot be a predicated operation).
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // Chain must be propagated but currently strict fp operations are down
+  // converted to their non-strict counterparts.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants, falling back to masked ones (given an all-true mask).
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Scalar signature for VFABI demangling; operands and result are ScalarTy.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Demangle the VFABI name to learn the vector routine's parameter kinds.
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected paramaters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect call operands in VFABI parameter order. A GlobalPredicate
+  // parameter is passed an all-true mask; other operands come from Node.
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit the call, chained to the entry node; its output chain is dropped.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Select the libcall for Node's element type and try a vector math call.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+  // Bail out when none of the given libcalls matches the element type.
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4cab..8832b51333d910 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden, // Used by llc tests.
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 00000000000000..67c056c780cc80
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+; %unused places %a/%b in v1/v2 (z1/z2), so the checks show them moved into the call's argument registers.
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f64
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <2 x double> %a, %b
+  ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f32
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <4 x float> %a, %b
+  ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.s
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f32_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.s
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.d
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f64_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.d
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }

>From ae0751103ee441c9e441e56984fb0d9b7c94e531 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Thu, 7 Mar 2024 12:57:01 +0000
Subject: [PATCH 2/2] Reword function comment and fix typos.

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ebd6f62a63ac4d..567ec7df0d2dbe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1859,11 +1859,12 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
-// Try to expand libm nodes into a call to a vector math. Callers provide the
+// Try to expand libm nodes into vector math routine calls. Callers provide the
 // LibFunc equivalent of the passed in Node, which is used to lookup mappings
-// within TargetLibraryInfo. Only simply mappings are considered whereby only
-// matching vector operands are allowed and masked functions are passed an all
-// true vector (i.e. Node cannot be a predicated operation).
+// within TargetLibraryInfo. The only mappings considered are those where the
+// result and all operands are the same vector type. While predicated nodes are
+// not supported, we will emit calls to masked routines by passing in an all
+// true mask.
 bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
                                            SmallVectorImpl<SDValue> &Results) {
   // Chain must be propagated but currently strict fp operations are down
@@ -1883,7 +1884,7 @@ bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
   const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
   const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
   if (!VD)
-    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/ true);
   if (!VD)
     return false;
 
@@ -1909,7 +1910,7 @@ bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
   LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
                     << "\n");
 
-  // Sanity check just in case OptVFInfo has unexpected paramaters.
+  // Sanity check just in case OptVFInfo has unexpected parameters.
   if (OptVFInfo->Shape.Parameters.size() !=
       Node->getNumOperands() + VD->isMasked())
     return false;



More information about the llvm-commits mailing list