[llvm] [C API] Add getters and setters for fast-math flags on relevant instructions (PR #75123)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 17:05:55 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Benji Smith (Benjins)

<details>
<summary>Changes</summary>

These flags are usable on floating-point arithmetic instructions, as well as on call, select, and phi instructions whose result type is a floating-point type or a vector or array of such types. Whether the flags are valid for a given instruction can be checked with the new LLVMGetCanUseFastMathFlags function.

These are exposed using a new LLVMFastMathFlags type, which is an alias for unsigned. An anonymous enum defines its bit values.

Tests are added in echo.ll for select/phi/call, and for floating-point arithmetic and comparison instructions on float and double in the new float_ops.ll bindings test.

Select and the floating-point arithmetic instructions were not handled by llvm-c-test/echo.cpp, so support for them was added as well.

---
Full diff: https://github.com/llvm/llvm-project/pull/75123.diff


6 Files Affected:

- (modified) llvm/docs/ReleaseNotes.rst (+4) 
- (modified) llvm/include/llvm-c/Core.h (+46) 
- (modified) llvm/lib/IR/Core.cpp (+62) 
- (modified) llvm/test/Bindings/llvm-c/echo.ll (+35) 
- (added) llvm/test/Bindings/llvm-c/float_ops.ll (+155) 
- (modified) llvm/tools/llvm-c-test/echo.cpp (+63) 


``````````diff
diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst
index d5c634d2f29af..757c66b1d3df4 100644
--- a/llvm/docs/ReleaseNotes.rst
+++ b/llvm/docs/ReleaseNotes.rst
@@ -226,6 +226,10 @@ Changes to the C API
   * ``LLVMGetOperandBundleArgAtIndex``
   * ``LLVMGetOperandBundleTag``
 
+* Added ``LLVMGetFastMathFlags`` and ``LLVMSetFastMathFlags`` for getting/setting
+  the fast-math flags of an instruction, as well as ``LLVMGetCanUseFastMathFlags``
+  for checking whether an instruction can use such flags.
+
 Changes to the CodeGen infrastructure
 -------------------------------------
 
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 7cb809d378c95..43a64378587df 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -483,6 +483,30 @@ typedef enum {
 
 typedef unsigned LLVMAttributeIndex;
 
+enum {
+  LLVMFastMathAllowReassoc = (1 << 0),
+  LLVMFastMathNoNaNs = (1 << 1),
+  LLVMFastMathNoInfs = (1 << 2),
+  LLVMFastMathNoSignedZeros = (1 << 3),
+  LLVMFastMathAllowReciprocal = (1 << 4),
+  LLVMFastMathAllowContract = (1 << 5),
+  LLVMFastMathApproxFunc = (1 << 6),
+  LLVMFastMathNone = 0,
+  LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs |
+                    LLVMFastMathNoInfs | LLVMFastMathNoSignedZeros |
+                    LLVMFastMathAllowReciprocal | LLVMFastMathAllowContract |
+                    LLVMFastMathApproxFunc,
+};
+
+/**
+ * Flags to indicate what fast-math-style optimizations are allowed
+ * on operations. A value is a bitwise OR of the LLVMFastMath* flags above;
+ * LLVMFastMathAll is the union of every individual flag.
+ *
+ * See https://llvm.org/docs/LangRef.html#fast-math-flags
+ */
+typedef unsigned LLVMFastMathFlags;
+
 /**
  * @}
  */
@@ -4075,6 +4095,32 @@ LLVMBool LLVMGetNNeg(LLVMValueRef NonNegInst);
  */
 void LLVMSetNNeg(LLVMValueRef NonNegInst, LLVMBool IsNonNeg);
 
+/**
+ * Get the flags for which fast-math-style optimizations are allowed for this
+ * value.
+ *
+ * Only valid on values for which LLVMGetCanUseFastMathFlags returns true.
+ * @see LLVMGetCanUseFastMathFlags
+ */
+LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst);
+/**
+ * Sets the flags for which fast-math-style optimizations are allowed for this
+ * value.
+ *
+ * Only valid on values for which LLVMGetCanUseFastMathFlags returns true.
+ * @see LLVMGetCanUseFastMathFlags
+ */
+void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF);
+
+/**
+ * Check if a given value can potentially have fast math flags.
+ *
+ * Returns true for floating point arithmetic and comparison instructions, and
+ * for select, phi, and call instructions whose type is floating point, or a
+ * vector/array thereof. See https://llvm.org/docs/LangRef.html#fast-math-flags
+ */
+LLVMBool LLVMGetCanUseFastMathFlags(LLVMValueRef Inst);
+
 /**
  * Gets whether the instruction has the disjoint flag set.
  * Only valid for or instructions.
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index 96629de8a7534..dea6d551b19e8 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -3319,6 +3319,52 @@ void LLVMSetArgOperand(LLVMValueRef Funclet, unsigned i, LLVMValueRef value) {
 
 /*--.. Arithmetic ..........................................................--*/
 
+/// Translate a C API fast-math bitmask into FastMathFlags. Bits that do not
+/// correspond to one of the LLVMFastMath* flags are ignored.
+static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF) {
+  FastMathFlags NewFMF;
+  // LLVMFastMathAll requests every flag at once.
+  if (FMF == static_cast<LLVMFastMathFlags>(LLVMFastMathAll)) {
+    NewFMF.set();
+    return NewFMF;
+  }
+
+  // Otherwise, translate each flag bit individually.
+  NewFMF.setAllowReassoc((FMF & LLVMFastMathAllowReassoc) != 0);
+  NewFMF.setNoNaNs((FMF & LLVMFastMathNoNaNs) != 0);
+  NewFMF.setNoInfs((FMF & LLVMFastMathNoInfs) != 0);
+  NewFMF.setNoSignedZeros((FMF & LLVMFastMathNoSignedZeros) != 0);
+  NewFMF.setAllowReciprocal((FMF & LLVMFastMathAllowReciprocal) != 0);
+  NewFMF.setAllowContract((FMF & LLVMFastMathAllowContract) != 0);
+  NewFMF.setApproxFunc((FMF & LLVMFastMathApproxFunc) != 0);
+  return NewFMF;
+}
+
+/// Translate FastMathFlags into the C API fast-math bitmask.
+static LLVMFastMathFlags mapToLLVMFastMathFlags(FastMathFlags FMF) {
+  // A fully "fast" flag set is reported as LLVMFastMathAll.
+  if (FMF.isFast())
+    return LLVMFastMathAll;
+
+  // Otherwise, translate each flag individually.
+  LLVMFastMathFlags NewFMF = LLVMFastMathNone;
+  if (FMF.allowReassoc())
+    NewFMF |= LLVMFastMathAllowReassoc;
+  if (FMF.noNaNs())
+    NewFMF |= LLVMFastMathNoNaNs;
+  if (FMF.noInfs())
+    NewFMF |= LLVMFastMathNoInfs;
+  if (FMF.noSignedZeros())
+    NewFMF |= LLVMFastMathNoSignedZeros;
+  if (FMF.allowReciprocal())
+    NewFMF |= LLVMFastMathAllowReciprocal;
+  if (FMF.allowContract())
+    NewFMF |= LLVMFastMathAllowContract;
+  if (FMF.approxFunc())
+    NewFMF |= LLVMFastMathApproxFunc;
+  return NewFMF;
+}
+
 LLVMValueRef LLVMBuildAdd(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS,
                           const char *Name) {
   return wrap(unwrap(B)->CreateAdd(unwrap(LHS), unwrap(RHS), Name));
@@ -3518,6 +3564,23 @@ void LLVMSetNNeg(LLVMValueRef NonNegInst, LLVMBool IsNonNeg) {
   cast<Instruction>(P)->setNonNeg(IsNonNeg);
 }
 
+LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst) {
+  Value *P = unwrap<Value>(FPMathInst);
+  FastMathFlags FMF = cast<Instruction>(P)->getFastMathFlags();
+  return mapToLLVMFastMathFlags(FMF);
+}
+
+void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) {
+  Value *P = unwrap<Value>(FPMathInst);
+  cast<Instruction>(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
+}
+
+LLVMBool LLVMGetCanUseFastMathFlags(LLVMValueRef Inst) {
+  // FPMathOperator is LLVM's classification of values that may carry FMF.
+  Value *Val = unwrap<Value>(Inst);
+  return isa<FPMathOperator>(Val);
+}
+
 LLVMBool LLVMGetIsDisjoint(LLVMValueRef Inst) {
   Value *P = unwrap<Value>(Inst);
   return cast<PossiblyDisjointInst>(P)->isDisjoint();
diff --git a/llvm/test/Bindings/llvm-c/echo.ll b/llvm/test/Bindings/llvm-c/echo.ll
index 2e195beebd7bb..be0207599478b 100644
--- a/llvm/test/Bindings/llvm-c/echo.ll
+++ b/llvm/test/Bindings/llvm-c/echo.ll
@@ -299,6 +299,50 @@ entry:
   ret void
 }
 
+define void @test_fast_math_flags(i1 %c, float %a, float %b) {
+entry:
+  %select.f.1 = select i1 %c, float %a, float %b
+  %select.f.2 = select nsz i1 %c, float %a, float %b
+  %select.f.3 = select fast i1 %c, float %a, float %b
+  %select.f.4 = select nnan arcp afn i1 %c, float %a, float %b
+
+  br i1 %c, label %choose_a, label %choose_b
+
+choose_a:
+  br label %final
+
+choose_b:
+  br label %final
+
+final:
+  %phi.f.1 = phi float [ %a, %choose_a ], [ %b, %choose_b ]
+  %phi.f.2 = phi nsz float [ %a, %choose_a ], [ %b, %choose_b ]
+  %phi.f.3 = phi fast float [ %a, %choose_a ], [ %b, %choose_b ]
+  %phi.f.4 = phi nnan arcp afn float [ %a, %choose_a ], [ %b, %choose_b ]
+  ret void
+}
+
+define void @test_fast_math_flags_vector(<2 x i1> %c, <2 x float> %a, <2 x float> %b) {
+entry:
+  %select.v.1 = select <2 x i1> %c, <2 x float> %a, <2 x float> %b
+  %select.v.2 = select nsz <2 x i1> %c, <2 x float> %a, <2 x float> %b
+  %select.v.3 = select fast <2 x i1> %c, <2 x float> %a, <2 x float> %b
+  %select.v.4 = select nnan arcp afn <2 x i1> %c, <2 x float> %a, <2 x float> %b
+  ret void
+}
+
+define float @test_fast_math_flags_call_inner(float %a) {
+  ret float %a
+}
+
+define void @test_fast_math_flags_call_outer(float %a) {
+  %a.1 = call float @test_fast_math_flags_call_inner(float %a)
+  %a.2 = call nsz float @test_fast_math_flags_call_inner(float %a)
+  %a.3 = call fast float @test_fast_math_flags_call_inner(float %a)
+  %a.4 = call nnan arcp afn float @test_fast_math_flags_call_inner(float %a)
+  ret void
+}
+
 !llvm.dbg.cu = !{!0, !2}
 !llvm.module.flags = !{!3}
 
diff --git a/llvm/test/Bindings/llvm-c/float_ops.ll b/llvm/test/Bindings/llvm-c/float_ops.ll
new file mode 100644
index 0000000000000..e569ce26b8535
--- /dev/null
+++ b/llvm/test/Bindings/llvm-c/float_ops.ll
@@ -0,0 +1,167 @@
+; RUN: llvm-as < %s | llvm-dis > %t.orig
+; RUN: llvm-as < %s | llvm-c-test --echo > %t.echo
+; RUN: diff -w %t.orig %t.echo
+;
+source_filename = "/test/Bindings/float_ops.ll"
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-apple-macosx10.11.0"
+
+
+define float @float_ops_f32(float %a, float %b) {
+  %1 = fneg float %a
+
+  %2 = fadd float %a, %b
+  %3 = fsub float %a, %b
+  %4 = fmul float %a, %b
+  %5 = fdiv float %a, %b
+  %6 = frem float %a, %b
+
+  ret float %1
+}
+
+define double @float_ops_f64(double %a, double %b) {
+  %1 = fneg double %a
+
+  %2 = fadd double %a, %b
+  %3 = fsub double %a, %b
+  %4 = fmul double %a, %b
+  %5 = fdiv double %a, %b
+  %6 = frem double %a, %b
+
+  ret double %1
+}
+
+define void @float_cmp_f32(float %a, float %b) {
+  %1  = fcmp oeq float %a, %b
+  %2  = fcmp ogt float %a, %b
+  %3  = fcmp oge float %a, %b
+  %4  = fcmp olt float %a, %b
+  %5  = fcmp ole float %a, %b
+  %6  = fcmp one float %a, %b
+
+  %7  = fcmp ueq float %a, %b
+  %8  = fcmp ugt float %a, %b
+  %9  = fcmp uge float %a, %b
+  %10 = fcmp ult float %a, %b
+  %11 = fcmp ule float %a, %b
+  %12 = fcmp une float %a, %b
+
+  %13 = fcmp ord float %a, %b
+  %14 = fcmp uno float %a, %b
+  %15 = fcmp false float %a, %b
+  %16 = fcmp true float %a, %b
+
+  ret void
+}
+
+define void @float_cmp_f64(double %a, double %b) {
+  %1  = fcmp oeq double %a, %b
+  %2  = fcmp ogt double %a, %b
+  %3  = fcmp oge double %a, %b
+  %4  = fcmp olt double %a, %b
+  %5  = fcmp ole double %a, %b
+  %6  = fcmp one double %a, %b
+
+  %7  = fcmp ueq double %a, %b
+  %8  = fcmp ugt double %a, %b
+  %9  = fcmp uge double %a, %b
+  %10 = fcmp ult double %a, %b
+  %11 = fcmp ule double %a, %b
+  %12 = fcmp une double %a, %b
+
+  %13 = fcmp ord double %a, %b
+  %14 = fcmp uno double %a, %b
+  %15 = fcmp false double %a, %b
+  %16 = fcmp true double %a, %b
+
+  ret void
+}
+
+define void @float_cmp_fast_f32(float %a, float %b) {
+  %1  = fcmp fast oeq float %a, %b
+  %2  = fcmp nsz ogt float %a, %b
+  %3  = fcmp reassoc oge float %a, %b
+  %4  = fcmp nsz nnan olt float %a, %b
+  %5  = fcmp contract ole float %a, %b
+  %6  = fcmp nnan one float %a, %b
+
+  %7  = fcmp nnan ninf nsz ueq float %a, %b
+  %8  = fcmp arcp ugt float %a, %b
+  %9  = fcmp afn uge float %a, %b
+  %10 = fcmp fast ult float %a, %b
+  %11 = fcmp fast ule float %a, %b
+  %12 = fcmp fast une float %a, %b
+
+  %13 = fcmp fast ord float %a, %b
+  %14 = fcmp ninf uno float %a, %b
+  %15 = fcmp nnan ninf false float %a, %b
+  %16 = fcmp nnan ninf true float %a, %b
+
+  ret void
+}
+
+define void @float_cmp_fast_f64(double %a, double %b) {
+  %1  = fcmp fast oeq double %a, %b
+  %2  = fcmp nsz ogt double %a, %b
+  %3  = fcmp reassoc oge double %a, %b
+  %4  = fcmp nsz nnan olt double %a, %b
+  %5  = fcmp contract ole double %a, %b
+  %6  = fcmp nnan one double %a, %b
+
+  %7  = fcmp nnan ninf nsz ueq double %a, %b
+  %8  = fcmp arcp ugt double %a, %b
+  %9  = fcmp afn uge double %a, %b
+  %10 = fcmp fast ult double %a, %b
+  %11 = fcmp fast ule double %a, %b
+  %12 = fcmp fast une double %a, %b
+
+  %13 = fcmp fast ord double %a, %b
+  %14 = fcmp ninf uno double %a, %b
+  %15 = fcmp nnan ninf false double %a, %b
+  %16 = fcmp nnan ninf true double %a, %b
+
+  ret void
+}
+
+define float @float_ops_fast_f32(float %a, float %b) {
+  %1 = fneg nnan float %a
+
+  %2 = fadd ninf float %a, %b
+  %3 = fsub nsz float %a, %b
+  %4 = fmul arcp float %a, %b
+  %5 = fdiv contract float %a, %b
+  %6 = frem afn float %a, %b
+
+  %7 = fadd reassoc float %a, %b
+  %8 = fadd reassoc float %7, %b
+
+  %9  = fadd fast float %a, %b
+  %10 = fadd nnan nsz float %a, %b
+  %11 = frem nnan nsz float %a, %b
+  %12 = fdiv nnan nsz arcp float %a, %b
+  %13 = fmul nnan nsz ninf contract float %a, %b
+
+  ret float %1
+}
+
+define double @float_ops_fast_f64(double %a, double %b) {
+  %1 = fneg nnan double %a
+
+  %2 = fadd ninf double %a, %b
+  %3 = fsub nsz double %a, %b
+  %4 = fmul arcp double %a, %b
+  %5 = fdiv contract double %a, %b
+  %6 = frem afn double %a, %b
+
+  %7 = fadd reassoc double %a, %b
+  %8 = fadd reassoc double %7, %b
+
+  %9  = fadd fast double %a, %b
+  %10 = fadd nnan nsz double %a, %b
+  %11 = frem nnan nsz double %a, %b
+  %12 = fdiv nnan nsz arcp double %a, %b
+  %13 = fmul nnan nsz ninf contract double %a, %b
+
+  ret double %1
+}
+
diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp
index bfc14e85a12bf..84c39b06102f1 100644
--- a/llvm/tools/llvm-c-test/echo.cpp
+++ b/llvm/tools/llvm-c-test/echo.cpp
@@ -770,8 +770,19 @@ struct FunCloner {
         }
 
         LLVMAddIncoming(Dst, Values.data(), Blocks.data(), IncomingCount);
+        if (LLVMGetCanUseFastMathFlags(Src)) // FP-typed phis only
+          LLVMSetFastMathFlags(Dst, LLVMGetFastMathFlags(Src));
         return Dst;
       }
+      case LLVMSelect: {
+        LLVMValueRef Cond = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef TrueVal = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMValueRef FalseVal = CloneValue(LLVMGetOperand(Src, 2));
+        Dst = LLVMBuildSelect(Builder, Cond, TrueVal, FalseVal, Name);
+        if (LLVMGetCanUseFastMathFlags(Src)) // FP-typed selects only
+          LLVMSetFastMathFlags(Dst, LLVMGetFastMathFlags(Src));
+        break;
+      }
       case LLVMCall: {
         SmallVector<LLVMValueRef, 8> Args;
         SmallVector<LLVMOperandBundleRef, 8> Bundles;
@@ -790,6 +801,9 @@ struct FunCloner {
                                               ArgCount, Bundles.data(),
                                               Bundles.size(), Name);
         LLVMSetTailCallKind(Dst, LLVMGetTailCallKind(Src));
+        // Only FP-typed calls (see LLVMGetCanUseFastMathFlags) carry FMF.
+        if (LLVMGetCanUseFastMathFlags(Src))
+          LLVMSetFastMathFlags(Dst, LLVMGetFastMathFlags(Src));
         CloneAttrs(Src, Dst);
         for (auto Bundle : Bundles)
           LLVMDisposeOperandBundle(Bundle);
@@ -930,6 +944,63 @@ struct FunCloner {
         LLVMSetNNeg(Dst, NNeg);
         break;
       }
+      // Floating-point ops: read the source's fast-math flags, then reapply.
+      case LLVMFAdd: {
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFAdd(Builder, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFSub: {
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFSub(Builder, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFMul: {
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFMul(Builder, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFDiv: {
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFDiv(Builder, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFRem: {
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFRem(Builder, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFNeg: {
+        LLVMValueRef Val = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFNeg(Builder, Val, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
+      case LLVMFCmp: {
+        LLVMRealPredicate Pred = LLVMGetFCmpPredicate(Src);
+        LLVMValueRef LHS = CloneValue(LLVMGetOperand(Src, 0));
+        LLVMValueRef RHS = CloneValue(LLVMGetOperand(Src, 1));
+        LLVMFastMathFlags FMF = LLVMGetFastMathFlags(Src);
+        Dst = LLVMBuildFCmp(Builder, Pred, LHS, RHS, Name);
+        LLVMSetFastMathFlags(Dst, FMF);
+        break;
+      }
       default:
         break;
     }

``````````

</details>


https://github.com/llvm/llvm-project/pull/75123


More information about the llvm-commits mailing list