[llvm] [SandboxVec][Legality] Fix legality of SelectInst (PR #125005)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 16:29:52 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/125005
SelectInsts need special treatment because they are not always straightforward to vectorize. This patch disables vectorization unless the condition has the same number of lanes as the select itself, i.e., unless the selects are trivially vectorizable.
>From 6a181808bb4b82f18093d52bceacdda63b64e7d3 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 16 Jan 2025 13:50:49 -0800
Subject: [PATCH] [SandboxVec][Legality] Fix legality of SelectInst
SelectInsts need special treatment because they are not always straightforward
to vectorize. This patch disables vectorization unless the condition has the
same number of lanes as the select itself, i.e., unless the selects are
trivially vectorizable.
---
.../Vectorize/SandboxVectorizer/Legality.cpp | 10 +-
.../SandboxVectorizer/special_opcodes.ll | 97 +++++++++++++++++++
.../SandboxVectorizer/LegalityTest.cpp | 15 ++-
3 files changed, 120 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 62be90aee4e0e0..c9329c24e1f4c7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -116,7 +116,18 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
return std::nullopt;
return ResultReason::DiffOpcodes;
}
- case Instruction::Opcode::Select:
+ case Instruction::Opcode::Select: {
+ // TODO: For now we don't vectorize if the lanes in the condition of any
+ // select don't match those of that select. Check every bundle member, not
+ // just Bndl[0]: selects of the same result type may mix i1/vector conds.
+ for (auto *V : Bndl) {
+ auto *Sel = cast<SelectInst>(V);
+ if (VecUtils::getNumLanes(Sel->getCondition()) !=
+ VecUtils::getNumLanes(Sel))
+ return ResultReason::Unimplemented;
+ }
+ return std::nullopt;
+ }
case Instruction::Opcode::FNeg:
case Instruction::Opcode::Add:
case Instruction::Opcode::FAdd:
diff --git a/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll b/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
new file mode 100644
index 00000000000000..fe3a2067d481d1
--- /dev/null
+++ b/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
+
+; This file includes tests for opcodes that need special checks.
+
+; TODO: Selects whose condition has a different number of lanes than the select itself need special treatment.
+define void @selects_with_diff_cond_lanes(ptr %ptr, i1 %cond0, i1 %cond1, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_with_diff_cond_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND0:%.*]], i1 [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND0]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
+; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
+; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
+; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
+; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
+; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
+; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
+; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
+; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
+; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+ %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+ %ld0 = load <2 x i8>, ptr %ptr0
+ %ld1 = load <2 x i8>, ptr %ptr1
+ %sel0 = select i1 %cond0, <2 x i8> %ld0, <2 x i8> %ld0
+ %sel1 = select i1 %cond1, <2 x i8> %ld1, <2 x i8> %ld1
+ store <2 x i8> %sel0, ptr %ptr0
+ store <2 x i8> %sel1, ptr %ptr1
+ ret void
+}
+
+; TODO: Selects that share the same scalar condition (whose lanes differ from those of the selects) need special treatment.
+define void @selects_with_common_condition_but_diff_lanes(ptr %ptr, i1 %cond, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_with_common_condition_but_diff_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
+; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
+; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
+; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
+; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
+; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
+; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
+; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
+; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
+; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+ %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+ %ld0 = load <2 x i8>, ptr %ptr0
+ %ld1 = load <2 x i8>, ptr %ptr1
+ %sel0 = select i1 %cond, <2 x i8> %ld0, <2 x i8> %ld0
+ %sel1 = select i1 %cond, <2 x i8> %ld1, <2 x i8> %ld1
+ store <2 x i8> %sel0, ptr %ptr0
+ store <2 x i8> %sel1, ptr %ptr1
+ ret void
+}
+
+; Selects whose condition has the same number of lanes as the select itself should be vectorized as usual.
+define void @selects_same_cond_lanes(ptr %ptr, <2 x i1> %cond0, <2 x i1> %cond1, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_same_cond_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], <2 x i1> [[COND0:%.*]], <2 x i1> [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i1> [[COND0]], i32 0
+; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i1> poison, i1 [[VPACK]], i32 0
+; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i1> [[COND0]], i32 1
+; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i1> [[VPACK1]], i1 [[VPACK2]], i32 1
+; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i1> [[COND1]], i32 0
+; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i1> [[VPACK3]], i1 [[VPACK4]], i32 2
+; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i1> [[COND1]], i32 1
+; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i1> [[VPACK5]], i1 [[VPACK6]], i32 3
+; CHECK-NEXT: [[VECL:%.*]] = load <4 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT: [[VEC:%.*]] = select <4 x i1> [[VPACK7]], <4 x i8> [[VECL]], <4 x i8> [[VECL]]
+; CHECK-NEXT: store <4 x i8> [[VEC]], ptr [[PTR0]], align 2
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+ %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+ %ld0 = load <2 x i8>, ptr %ptr0
+ %ld1 = load <2 x i8>, ptr %ptr1
+ %sel0 = select <2 x i1> %cond0, <2 x i8> %ld0, <2 x i8> %ld0
+ %sel1 = select <2 x i1> %cond1, <2 x i8> %ld1, <2 x i8> %ld1
+ store <2 x i8> %sel0, ptr %ptr0
+ store <2 x i8> %sel1, ptr %ptr1
+ ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 3c24214f0d87f2..15f8166b705fc2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -67,7 +67,7 @@ static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
TEST_F(LegalityTest, LegalitySkipSchedule) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2, i1 %c0, i1 %c1) {
entry:
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
@@ -93,6 +93,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
%trunc32to8 = trunc i32 %v2 to i8
%cmpSLT = icmp slt i64 %v0, %v1
%cmpSGT = icmp sgt i64 %v0, %v1
+ %sel0 = select i1 %c0, <2 x float> %vec2, <2 x float> %vec2
+ %sel1 = select i1 %c1, <2 x float> %vec2, <2 x float> %vec2
ret void
}
)IR");
@@ -128,6 +130,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
+ auto *Sel0 = cast<sandboxir::SelectInst>(&*It++);
+ auto *Sel1 = cast<sandboxir::SelectInst>(&*It++);
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
@@ -241,6 +245,15 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::RepeatedInstrs);
}
+ {
+ // For now don't vectorize Selects when the number of elements of conditions
+ // doesn't match the operands.
+ const auto &Result =
+ Legality.canVectorize({Sel0, Sel1}, /*SkipScheduling=*/true);
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::Unimplemented);
+ }
}
TEST_F(LegalityTest, LegalitySchedule) {
More information about the llvm-commits
mailing list