[llvm] [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (PR #83859)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 08:09:33 PST 2024


https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/83859

This removes, at least when a vector library is available, a failure case for scalable vectors. Doing so means we can confidently cost vector FREM instructions without making an assumption that later passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented but the same mechanism can be used for the other libm related ISD nodes.

>From 11cbbc636fa3c574e64b9ebeb8e5ff52f583c1f0 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 1 Mar 2024 18:49:53 +0000
Subject: [PATCH] [LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a
 vector math call.

This removes, at least when a vector library is available, a failure
case for scalable vectors. Doing so means we can confidently cost
vector FREM instructions without making an assumption that later
passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented but the same mechanism
can be used for the other libm related ISD nodes.
---
 .../SelectionDAG/LegalizeVectorOps.cpp        | 127 ++++++++++++++++++
 llvm/lib/CodeGen/TargetPassConfig.cpp         |   6 +-
 .../CodeGen/AArch64/fp-veclib-expansion.ll    | 116 ++++++++++++++++
 3 files changed, 248 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144ff..ebd6f62a63ac4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,116 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand a libm node into a call to a vector math function. Callers
+// provide the RTLIB::Libcall equivalent of Node, whose name is used to look up
+// mappings within TargetLibraryInfo. Only simple mappings are considered: all
+// operands must be matching vectors and masked functions are passed an
+// all-true mask (i.e. Node cannot be a predicated operation).
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // Chain must be propagated but currently strict fp operations are down
+  // converted to their non-strict counterpart.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants but we will generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function.
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand Node to a vector libcall selected by its result element type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4cab..8832b51333d910 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden,
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 00000000000000..67c056c780cc80
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f64
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <2 x double> %a, %b
+  ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f32
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <4 x float> %a, %b
+  ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.s
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f32_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.s
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.d
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f64_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.d
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }



More information about the llvm-commits mailing list