[llvm] [SandboxVec][Legality] Check wrap flags (PR #113975)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 28 15:36:52 PDT 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/113975
(No PR description provided.)
>From 5892bf5b94eee9fef8c562449a216a6898578844 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 17 Oct 2024 11:09:00 -0700
Subject: [PATCH] [SandboxVec][Legality] Check wrap flags
---
.../Vectorize/SandboxVectorizer/Legality.h | 3 +++
.../Vectorize/SandboxVectorizer/Legality.cpp | 15 +++++++++++++++
.../Vectorize/SandboxVectorizer/LegalityTest.cpp | 16 +++++++++++++++-
3 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 49dcec26dbc559..77ba5cd7f002e9 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -32,6 +32,7 @@ enum class ResultReason {
DiffOpcodes,
DiffTypes,
DiffMathFlags,
+ DiffWrapFlags,
};
#ifndef NDEBUG
@@ -56,6 +57,8 @@ struct ToStr {
return "DiffTypes";
case ResultReason::DiffMathFlags:
return "DiffMathFlags";
+ case ResultReason::DiffWrapFlags:
+ return "DiffWrapFlags";
}
llvm_unreachable("Unknown ResultReason enum");
}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 346d8a90589f55..1cc6356300e492 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -55,6 +55,21 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
return ResultReason::DiffMathFlags;
}
+ // TODO: Allow vectorization by using common flags.
+ // For now Pack if they don't have the same wrap flags.
+ bool CanHaveWrapFlags =
+ isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0);
+ if (CanHaveWrapFlags) {
+ bool NUW0 = I0->hasNoUnsignedWrap();
+ bool NSW0 = I0->hasNoSignedWrap();
+ if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) {
+ return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
+ cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
+ })) {
+ return ResultReason::DiffWrapFlags;
+ }
+ }
+
// TODO: Missing checks
return std::nullopt;
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index aaa8e96de6d171..50b78f6f48afdf 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -29,7 +29,7 @@ struct LegalityTest : public testing::Test {
TEST_F(LegalityTest, Legality) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) {
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
%gep3 = getelementptr float, ptr %ptr, i32 3
@@ -42,6 +42,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
store i8 %arg, ptr %gep1
%fadd0 = fadd float %farg0, %farg0
%fadd1 = fadd fast float %farg1, %farg1
+ %trunc0 = trunc nuw nsw i64 %v0 to i8
+ %trunc1 = trunc nsw i64 %v1 to i8
ret void
}
)IR");
@@ -62,6 +64,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
+ auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
sandboxir::LegalityAnalysis Legality;
const auto &Result = Legality.canVectorize({St0, St1});
@@ -98,6 +102,13 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffMathFlags);
}
+ {
+ // Check DiffWrapFlags
+ const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::DiffWrapFlags);
+ }
}
#ifndef NDEBUG
@@ -124,5 +135,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
sandboxir::ResultReason::DiffMathFlags),
"Pack Reason: DiffMathFlags"));
+ EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
+ sandboxir::ResultReason::DiffWrapFlags),
+ "Pack Reason: DiffWrapFlags"));
}
#endif // NDEBUG
More information about the llvm-commits mailing list