[llvm] [SandboxVec][Legality] Check Fastmath flags (PR #113967)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 14:55:55 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/113967

None

From e8526d2c1875e63bcf7900a71d03a68fa824b243 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 17 Oct 2024 09:58:00 -0700
Subject: [PATCH] [SandboxVec][Legality] Check Fastmath flags

---
 .../Vectorize/SandboxVectorizer/Legality.h       |  3 +++
 .../Vectorize/SandboxVectorizer/Legality.cpp     | 12 ++++++++++++
 .../Vectorize/SandboxVectorizer/LegalityTest.cpp | 16 +++++++++++++++-
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index d4b0b54375b026..49dcec26dbc559 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -31,6 +31,7 @@ enum class ResultReason {
   NotInstructions,
   DiffOpcodes,
   DiffTypes,
+  DiffMathFlags,
 };
 
 #ifndef NDEBUG
@@ -53,6 +54,8 @@ struct ToStr {
       return "DiffOpcodes";
     case ResultReason::DiffTypes:
       return "DiffTypes";
+    case ResultReason::DiffMathFlags:
+      return "DiffMathFlags";
     }
     llvm_unreachable("Unknown ResultReason enum");
   }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index fcfb11c669fa10..346d8a90589f55 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
 #include "llvm/SandboxIR/Instruction.h"
+#include "llvm/SandboxIR/Operator.h"
 #include "llvm/SandboxIR/Utils.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Debug.h"
@@ -43,6 +44,17 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
       }))
     return ResultReason::DiffTypes;
 
+  // TODO: Allow vectorization of instrs with different flags as long as we
+  // change them to the least common one.
+  // For now, pack if the FastMathFlags differ.
+  if (isa<FPMathOperator>(I0)) {
+    FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
+    if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
+          return cast<Instruction>(V)->getFastMathFlags() != FMF0;
+        }))
+      return ResultReason::DiffMathFlags;
+  }
+
   // TODO: Missing checks
 
   return std::nullopt;
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 51f445c8d1d010..aaa8e96de6d171 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -29,7 +29,7 @@ struct LegalityTest : public testing::Test {
 
 TEST_F(LegalityTest, Legality) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1) {
   %gep0 = getelementptr float, ptr %ptr, i32 0
   %gep1 = getelementptr float, ptr %ptr, i32 1
   %gep3 = getelementptr float, ptr %ptr, i32 3
@@ -40,6 +40,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
   store <2 x float> %vec2, ptr %gep1
   store <3 x float> %vec3, ptr %gep3
   store i8 %arg, ptr %gep1
+  %fadd0 = fadd float %farg0, %farg0
+  %fadd1 = fadd fast float %farg1, %farg1
   ret void
 }
 )IR");
@@ -58,6 +60,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
   auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
   auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
   auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
+  auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
 
   sandboxir::LegalityAnalysis Legality;
   const auto &Result = Legality.canVectorize({St0, St1});
@@ -87,6 +91,13 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffTypes);
   }
+  {
+    // Check DiffMathFlags
+    const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffMathFlags);
+  }
 }
 
 #ifndef NDEBUG
@@ -110,5 +121,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::DiffTypes),
                       "Pack Reason: DiffTypes"));
+  EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
+                          sandboxir::ResultReason::DiffMathFlags),
+                      "Pack Reason: DiffMathFlags"));
 }
 #endif // NDEBUG



More information about the llvm-commits mailing list