[llvm] [SandboxIR] Implement FPMathOperator (PR #112921)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 18 08:49:49 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/112921

>From 05c6eac0f88102cad534a76e22077e22bc9280fe Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 17 Oct 2024 16:03:29 -0700
Subject: [PATCH] [SandboxIR] Implement FPMathOperator

---
 llvm/include/llvm/SandboxIR/Operator.h    | 39 +++++++++++++++++
 llvm/include/llvm/SandboxIR/Type.h        |  6 ++-
 llvm/include/llvm/SandboxIR/Value.h       |  2 +
 llvm/unittests/SandboxIR/OperatorTest.cpp | 53 +++++++++++++++++++++++
 4 files changed, 98 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Operator.h b/llvm/include/llvm/SandboxIR/Operator.h
index 95c450807191b4..f19c54c75e424f 100644
--- a/llvm/include/llvm/SandboxIR/Operator.h
+++ b/llvm/include/llvm/SandboxIR/Operator.h
@@ -55,6 +55,45 @@ class OverflowingBinaryOperator : public Operator {
     return llvm::OverflowingBinaryOperator::classof(From->Val);
   }
 };
+
+class FPMathOperator : public Operator { // sandboxir view of llvm::FPMathOperator
+public: // Each accessor below casts Val to llvm::FPMathOperator and forwards.
+  bool isFast() const { return cast<llvm::FPMathOperator>(Val)->isFast(); }
+  bool hasAllowReassoc() const {
+    return cast<llvm::FPMathOperator>(Val)->hasAllowReassoc();
+  }
+  bool hasNoNaNs() const {
+    return cast<llvm::FPMathOperator>(Val)->hasNoNaNs();
+  }
+  bool hasNoInfs() const {
+    return cast<llvm::FPMathOperator>(Val)->hasNoInfs();
+  }
+  bool hasNoSignedZeros() const {
+    return cast<llvm::FPMathOperator>(Val)->hasNoSignedZeros();
+  }
+  bool hasAllowReciprocal() const {
+    return cast<llvm::FPMathOperator>(Val)->hasAllowReciprocal();
+  }
+  bool hasAllowContract() const {
+    return cast<llvm::FPMathOperator>(Val)->hasAllowContract();
+  }
+  bool hasApproxFunc() const {
+    return cast<llvm::FPMathOperator>(Val)->hasApproxFunc();
+  }
+  FastMathFlags getFastMathFlags() const { // Returns the LLVM flags by value.
+    return cast<llvm::FPMathOperator>(Val)->getFastMathFlags();
+  }
+  float getFPAccuracy() const {
+    return cast<llvm::FPMathOperator>(Val)->getFPAccuracy();
+  }
+  static bool isSupportedFloatingPointType(Type *Ty) { // Queries Ty->LLVMTy.
+    return llvm::FPMathOperator::isSupportedFloatingPointType(Ty->LLVMTy);
+  }
+  static bool classof(const Value *V) { // Defers to LLVM's classof on V->Val.
+    return llvm::FPMathOperator::classof(V->Val);
+  }
+};
+
 } // namespace llvm::sandboxir
 
 #endif // LLVM_SANDBOXIR_OPERATOR_H
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 8094f66567fb8c..9d1db11edb05ae 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -33,12 +33,13 @@ class ArrayType;
 class StructType;
 class TargetExtType;
 class Module;
+class FPMathOperator;
 #define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
 #define DEF_CONST(ID, CLASS) class CLASS;
 #include "llvm/SandboxIR/Values.def"
 
-/// Just like llvm::Type these are immutable, unique, never get freed and can
-/// only be created via static factory methods.
+/// Just like llvm::Type these are immutable, unique, never get freed and
+/// can only be created via static factory methods.
 class Type {
 protected:
   llvm::Type *LLVMTy;
@@ -61,6 +62,7 @@ class Type {
   friend class Utils;              // for LLVMTy
   friend class TargetExtType;      // For LLVMTy.
   friend class Module;             // For LLVMTy.
+  friend class FPMathOperator;     // For LLVMTy.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
diff --git a/llvm/include/llvm/SandboxIR/Value.h b/llvm/include/llvm/SandboxIR/Value.h
index 58088684bf18ec..243195f4c1c4bd 100644
--- a/llvm/include/llvm/SandboxIR/Value.h
+++ b/llvm/include/llvm/SandboxIR/Value.h
@@ -30,6 +30,7 @@ class CmpInst;
 class IntrinsicInst;
 class Operator;
 class OverflowingBinaryOperator;
+class FPMathOperator;
 
 /// Iterator for the `Use` edges of a Value's users.
 /// \Returns a `Use` when dereferenced.
@@ -162,6 +163,7 @@ class Value {
   friend class IntrinsicInst;         // For `Val`.
   friend class Operator;              // For `Val`.
   friend class OverflowingBinaryOperator; // For `Val`.
+  friend class FPMathOperator;            // For `Val`.
   // Region needs to manipulate metadata in the underlying LLVM Value, we don't
   // expose metadata in sandboxir.
   friend class Region;
diff --git a/llvm/unittests/SandboxIR/OperatorTest.cpp b/llvm/unittests/SandboxIR/OperatorTest.cpp
index 031e2adf406927..b1e324417da41d 100644
--- a/llvm/unittests/SandboxIR/OperatorTest.cpp
+++ b/llvm/unittests/SandboxIR/OperatorTest.cpp
@@ -86,3 +86,56 @@ define void @foo(i8 %v1) {
   EXPECT_EQ(AddNUW->getNoWrapKind(),
             llvm::OverflowingBinaryOperator::NoUnsignedWrap);
 }
+
+TEST_F(OperatorTest, FPMathOperator) { // sandboxir results must match LLVM IR's.
+  parseIR(C, R"IR(
+define void @foo(float %v1, double %v2) {
+  %fadd = fadd float %v1, 42.0
+  %Fast = fadd fast float %v1, 42.0
+  %Reassoc = fmul reassoc float %v1, 42.0
+  %NNAN = fmul nnan float %v1, 42.0
+  %NINF = fmul ninf float %v1, 42.0
+  %NSZ = fmul nsz float %v1, 42.0
+  %ARCP = fmul arcp float %v1, 42.0
+  %CONTRACT = fmul contract float %v1, 42.0
+  %AFN = fmul afn double %v2, 42.0
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto *LLVMBB = &*LLVMF->begin();
+  auto LLVMIt = LLVMBB->begin(); // Cursor over the LLVM IR instructions.
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin(); // Cursor over sandboxir; advanced together with LLVMIt.
+  auto TermIt = BB->getTerminator()->getIterator();
+  while (It != TermIt) { // Stop at the terminator (`ret void` in the IR above).
+    auto *FPM = cast<sandboxir::FPMathOperator>(&*It++);
+    auto *LLVMFPM = cast<llvm::FPMathOperator>(&*LLVMIt++);
+    EXPECT_EQ(FPM->isFast(), LLVMFPM->isFast());
+    EXPECT_EQ(FPM->hasAllowReassoc(), LLVMFPM->hasAllowReassoc());
+    EXPECT_EQ(FPM->hasNoNaNs(), LLVMFPM->hasNoNaNs());
+    EXPECT_EQ(FPM->hasNoInfs(), LLVMFPM->hasNoInfs());
+    EXPECT_EQ(FPM->hasNoSignedZeros(), LLVMFPM->hasNoSignedZeros());
+    EXPECT_EQ(FPM->hasAllowReciprocal(), LLVMFPM->hasAllowReciprocal());
+    EXPECT_EQ(FPM->hasAllowContract(), LLVMFPM->hasAllowContract());
+    EXPECT_EQ(FPM->hasApproxFunc(), LLVMFPM->hasApproxFunc());
+
+    // There doesn't seem to be an operator== for FastMathFlags so let's do a
+    // string comparison instead.
+    std::string Str1; // Receives the sandboxir flags printout.
+    raw_string_ostream SS1(Str1);
+    std::string Str2; // Receives the LLVM IR flags printout.
+    raw_string_ostream SS2(Str2);
+    FPM->getFastMathFlags().print(SS1);
+    LLVMFPM->getFastMathFlags().print(SS2);
+    EXPECT_EQ(Str1, Str2);
+
+    EXPECT_EQ(FPM->getFPAccuracy(), LLVMFPM->getFPAccuracy());
+    EXPECT_EQ(
+        sandboxir::FPMathOperator::isSupportedFloatingPointType(FPM->getType()),
+        llvm::FPMathOperator::isSupportedFloatingPointType(LLVMFPM->getType()));
+  }
+}



More information about the llvm-commits mailing list