[llvm] [SandboxIR] Implement FPMathOperator (PR #112921)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 18 08:49:49 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/112921
>From 05c6eac0f88102cad534a76e22077e22bc9280fe Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 17 Oct 2024 16:03:29 -0700
Subject: [PATCH] [SandboxIR] Implement FPMathOperator
---
llvm/include/llvm/SandboxIR/Operator.h | 39 +++++++++++++++++
llvm/include/llvm/SandboxIR/Type.h | 6 ++-
llvm/include/llvm/SandboxIR/Value.h | 2 +
llvm/unittests/SandboxIR/OperatorTest.cpp | 53 +++++++++++++++++++++++
4 files changed, 98 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Operator.h b/llvm/include/llvm/SandboxIR/Operator.h
index 95c450807191b4..f19c54c75e424f 100644
--- a/llvm/include/llvm/SandboxIR/Operator.h
+++ b/llvm/include/llvm/SandboxIR/Operator.h
@@ -55,6 +55,45 @@ class OverflowingBinaryOperator : public Operator {
return llvm::OverflowingBinaryOperator::classof(From->Val);
}
};
+
+/// SandboxIR counterpart of llvm::FPMathOperator. This class holds no state of
+/// its own: every query is answered by the wrapped LLVM value viewed as an
+/// llvm::FPMathOperator, so the SandboxIR and LLVM answers always agree.
+class FPMathOperator : public Operator {
+public:
+  /// Fast-math flag queries. Each one forwards to the llvm::FPMathOperator
+  /// method of the same name (IR flag keywords: fast, reassoc, nnan, ninf,
+  /// nsz, arcp, contract, afn).
+  bool isFast() const { return getLLVMFPMathOp()->isFast(); }
+  bool hasAllowReassoc() const { return getLLVMFPMathOp()->hasAllowReassoc(); }
+  bool hasNoNaNs() const { return getLLVMFPMathOp()->hasNoNaNs(); }
+  bool hasNoInfs() const { return getLLVMFPMathOp()->hasNoInfs(); }
+  bool hasNoSignedZeros() const {
+    return getLLVMFPMathOp()->hasNoSignedZeros();
+  }
+  bool hasAllowReciprocal() const {
+    return getLLVMFPMathOp()->hasAllowReciprocal();
+  }
+  bool hasAllowContract() const {
+    return getLLVMFPMathOp()->hasAllowContract();
+  }
+  bool hasApproxFunc() const { return getLLVMFPMathOp()->hasApproxFunc(); }
+
+  /// \Returns the complete set of fast-math flags of the wrapped operation.
+  FastMathFlags getFastMathFlags() const {
+    return getLLVMFPMathOp()->getFastMathFlags();
+  }
+
+  /// \Returns whatever llvm::FPMathOperator::getFPAccuracy() reports for the
+  /// wrapped operation.
+  float getFPAccuracy() const { return getLLVMFPMathOp()->getFPAccuracy(); }
+
+  /// Forwards \p Ty's underlying LLVM type to
+  /// llvm::FPMathOperator::isSupportedFloatingPointType().
+  static bool isSupportedFloatingPointType(Type *Ty) {
+    llvm::Type *LLVMTy = Ty->LLVMTy;
+    return llvm::FPMathOperator::isSupportedFloatingPointType(LLVMTy);
+  }
+
+  /// LLVM-style RTTI support (isa<>/cast<>/dyn_cast<>): \p V is an
+  /// FPMathOperator exactly when its wrapped LLVM value is an
+  /// llvm::FPMathOperator.
+  static bool classof(const Value *V) {
+    const llvm::Value *LLVMV = V->Val;
+    return llvm::FPMathOperator::classof(LLVMV);
+  }
+
+private:
+  /// \Returns the wrapped LLVM value viewed as an llvm::FPMathOperator.
+  const llvm::FPMathOperator *getLLVMFPMathOp() const {
+    return cast<llvm::FPMathOperator>(Val);
+  }
+};
+
} // namespace llvm::sandboxir
#endif // LLVM_SANDBOXIR_OPERATOR_H
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 8094f66567fb8c..9d1db11edb05ae 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -33,12 +33,13 @@ class ArrayType;
class StructType;
class TargetExtType;
class Module;
+class FPMathOperator;
#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
#define DEF_CONST(ID, CLASS) class CLASS;
#include "llvm/SandboxIR/Values.def"
-/// Just like llvm::Type these are immutable, unique, never get freed and can
-/// only be created via static factory methods.
+/// Just like llvm::Type these are immutable, unique, never get freed and
+/// can only be created via static factory methods.
class Type {
protected:
llvm::Type *LLVMTy;
@@ -61,6 +62,7 @@ class Type {
friend class Utils; // for LLVMTy
friend class TargetExtType; // For LLVMTy.
friend class Module; // For LLVMTy.
+ friend class FPMathOperator; // For LLVMTy.
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
diff --git a/llvm/include/llvm/SandboxIR/Value.h b/llvm/include/llvm/SandboxIR/Value.h
index 58088684bf18ec..243195f4c1c4bd 100644
--- a/llvm/include/llvm/SandboxIR/Value.h
+++ b/llvm/include/llvm/SandboxIR/Value.h
@@ -30,6 +30,7 @@ class CmpInst;
class IntrinsicInst;
class Operator;
class OverflowingBinaryOperator;
+class FPMathOperator;
/// Iterator for the `Use` edges of a Value's users.
/// \Returns a `Use` when dereferenced.
@@ -162,6 +163,7 @@ class Value {
friend class IntrinsicInst; // For `Val`.
friend class Operator; // For `Val`.
friend class OverflowingBinaryOperator; // For `Val`.
+ friend class FPMathOperator; // For `Val`.
// Region needs to manipulate metadata in the underlying LLVM Value, we don't
// expose metadata in sandboxir.
friend class Region;
diff --git a/llvm/unittests/SandboxIR/OperatorTest.cpp b/llvm/unittests/SandboxIR/OperatorTest.cpp
index 031e2adf406927..b1e324417da41d 100644
--- a/llvm/unittests/SandboxIR/OperatorTest.cpp
+++ b/llvm/unittests/SandboxIR/OperatorTest.cpp
@@ -86,3 +86,56 @@ define void @foo(i8 %v1) {
EXPECT_EQ(AddNUW->getNoWrapKind(),
llvm::OverflowingBinaryOperator::NoUnsignedWrap);
}
+
+TEST_F(OperatorTest, FPMathOperator) {
+  // One instruction per fast-math flag (plus one with no flags and one with
+  // `fast`), so each sandboxir query is compared against LLVM for both the
+  // set and the unset case.
+  parseIR(C, R"IR(
+define void @foo(float %v1, double %v2) {
+  %fadd = fadd float %v1, 42.0
+  %Fast = fadd fast float %v1, 42.0
+  %Reassoc = fmul reassoc float %v1, 42.0
+  %NNAN = fmul nnan float %v1, 42.0
+  %NINF = fmul ninf float %v1, 42.0
+  %NSZ = fmul nsz float %v1, 42.0
+  %ARCP = fmul arcp float %v1, 42.0
+  %CONTRACT = fmul contract float %v1, 42.0
+  %AFN = fmul afn double %v2, 42.0
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto *LLVMBB = &*LLVMF->begin();
+  auto LLVMIt = LLVMBB->begin();
+
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto TermIt = BB->getTerminator()->getIterator();
+  // Walk the SandboxIR and LLVM instruction lists in lock-step, stopping
+  // before the terminator: every instruction before it is an FP operation.
+  while (It != TermIt) {
+    auto *FPM = cast<sandboxir::FPMathOperator>(&*It++);
+    auto *LLVMFPM = cast<llvm::FPMathOperator>(&*LLVMIt++);
+    EXPECT_EQ(FPM->isFast(), LLVMFPM->isFast());
+    EXPECT_EQ(FPM->hasAllowReassoc(), LLVMFPM->hasAllowReassoc());
+    EXPECT_EQ(FPM->hasNoNaNs(), LLVMFPM->hasNoNaNs());
+    EXPECT_EQ(FPM->hasNoInfs(), LLVMFPM->hasNoInfs());
+    EXPECT_EQ(FPM->hasNoSignedZeros(), LLVMFPM->hasNoSignedZeros());
+    EXPECT_EQ(FPM->hasAllowReciprocal(), LLVMFPM->hasAllowReciprocal());
+    EXPECT_EQ(FPM->hasAllowContract(), LLVMFPM->hasAllowContract());
+    EXPECT_EQ(FPM->hasApproxFunc(), LLVMFPM->hasApproxFunc());
+
+    // There doesn't seem to be an operator== for FastMathFlags so let's do a
+    // string comparison instead.
+    std::string Str1;
+    raw_string_ostream SS1(Str1);
+    std::string Str2;
+    raw_string_ostream SS2(Str2);
+    FPM->getFastMathFlags().print(SS1);
+    LLVMFPM->getFastMathFlags().print(SS2);
+    EXPECT_EQ(Str1, Str2);
+
+    EXPECT_EQ(FPM->getFPAccuracy(), LLVMFPM->getFPAccuracy());
+    EXPECT_EQ(
+        sandboxir::FPMathOperator::isSupportedFloatingPointType(FPM->getType()),
+        llvm::FPMathOperator::isSupportedFloatingPointType(LLVMFPM->getType()));
+  }
+
+  // Negative case for classof(): the loop above only ever feeds it FP
+  // operations, so also check that a non-FP instruction (`ret void`) is
+  // rejected, in agreement with LLVM.
+  auto *Ret = BB->getTerminator();
+  auto *LLVMRet = LLVMBB->getTerminator();
+  EXPECT_FALSE(isa<sandboxir::FPMathOperator>(Ret));
+  EXPECT_FALSE(isa<llvm::FPMathOperator>(LLVMRet));
+}
More information about the llvm-commits mailing list