[llvm] [SandboxVec][Legality] Fix legality of SelectInst (PR #125005)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 16:29:52 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/125005

SelectInsts need special treatment because they are not always straightforward to vectorize. This patch disables vectorization unless they are trivially vectorizable.

>From 6a181808bb4b82f18093d52bceacdda63b64e7d3 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 16 Jan 2025 13:50:49 -0800
Subject: [PATCH] [SandboxVec][Legality] Fix legality of SelectInst

SelectInsts need special treatment because they are not always straightforward
to vectorize. This patch disables vectorization unless they are trivially
vectorizable.
---
 .../Vectorize/SandboxVectorizer/Legality.cpp  | 10 +-
 .../SandboxVectorizer/special_opcodes.ll      | 97 +++++++++++++++++++
 .../SandboxVectorizer/LegalityTest.cpp        | 15 ++-
 3 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll

diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 62be90aee4e0e0..c9329c24e1f4c7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -116,7 +116,15 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
       return std::nullopt;
     return ResultReason::DiffOpcodes;
   }
-  case Instruction::Opcode::Select:
+  case Instruction::Opcode::Select: {
+    auto *Sel0 = cast<SelectInst>(Bndl[0]);
+    auto *Cond0 = Sel0->getCondition();
+    if (VecUtils::getNumLanes(Cond0) != VecUtils::getNumLanes(Sel0))
+      // TODO: For now we don't vectorize if the lanes in the condition don't
+      // match those of the select instruction.
+      return ResultReason::Unimplemented;
+    return std::nullopt;
+  }
   case Instruction::Opcode::FNeg:
   case Instruction::Opcode::Add:
   case Instruction::Opcode::FAdd:
diff --git a/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll b/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
new file mode 100644
index 00000000000000..fe3a2067d481d1
--- /dev/null
+++ b/llvm/test/Transforms/SandboxVectorizer/special_opcodes.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
+
+; This file includes tests for opcodes that need special checks.
+
+; TODO: Selects whose condition has a different number of lanes than the instruction itself need special treatment.
+define void @selects_with_diff_cond_lanes(ptr %ptr, i1 %cond0, i1 %cond1, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_with_diff_cond_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND0:%.*]], i1 [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT:    [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[COND0]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[COND1]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
+; CHECK-NEXT:    [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
+; CHECK-NEXT:    [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
+; CHECK-NEXT:    [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
+; CHECK-NEXT:    [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
+; CHECK-NEXT:    [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
+; CHECK-NEXT:    [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
+; CHECK-NEXT:    [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
+; CHECK-NEXT:    [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+  %ld0 = load <2 x i8>, ptr %ptr0
+  %ld1 = load <2 x i8>, ptr %ptr1
+  %sel0 = select i1 %cond0, <2 x i8> %ld0, <2 x i8> %ld0
+  %sel1 = select i1 %cond1, <2 x i8> %ld1, <2 x i8> %ld1
+  store <2 x i8> %sel0, ptr %ptr0
+  store <2 x i8> %sel1, ptr %ptr1
+  ret void
+}
+
+; TODO: Selects that share the same condition need special treatment.
+define void @selects_with_common_condition_but_diff_lanes(ptr %ptr, i1 %cond, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_with_common_condition_but_diff_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT:    [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[COND]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[COND]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
+; CHECK-NEXT:    [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
+; CHECK-NEXT:    [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
+; CHECK-NEXT:    [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
+; CHECK-NEXT:    [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
+; CHECK-NEXT:    [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
+; CHECK-NEXT:    [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
+; CHECK-NEXT:    [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
+; CHECK-NEXT:    [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+  %ld0 = load <2 x i8>, ptr %ptr0
+  %ld1 = load <2 x i8>, ptr %ptr1
+  %sel0 = select i1 %cond, <2 x i8> %ld0, <2 x i8> %ld0
+  %sel1 = select i1 %cond, <2 x i8> %ld1, <2 x i8> %ld1
+  store <2 x i8> %sel0, ptr %ptr0
+  store <2 x i8> %sel1, ptr %ptr1
+  ret void
+}
+
+; Selects whose condition has the same number of lanes as the instruction itself can be vectorized as usual.
+define void @selects_same_cond_lanes(ptr %ptr, <2 x i1> %cond0, <2 x i1> %cond1, <2 x i8> %op0, <2 x i8> %op1) {
+; CHECK-LABEL: define void @selects_same_cond_lanes(
+; CHECK-SAME: ptr [[PTR:%.*]], <2 x i1> [[COND0:%.*]], <2 x i1> [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VPACK:%.*]] = extractelement <2 x i1> [[COND0]], i32 0
+; CHECK-NEXT:    [[VPACK1:%.*]] = insertelement <4 x i1> poison, i1 [[VPACK]], i32 0
+; CHECK-NEXT:    [[VPACK2:%.*]] = extractelement <2 x i1> [[COND0]], i32 1
+; CHECK-NEXT:    [[VPACK3:%.*]] = insertelement <4 x i1> [[VPACK1]], i1 [[VPACK2]], i32 1
+; CHECK-NEXT:    [[VPACK4:%.*]] = extractelement <2 x i1> [[COND1]], i32 0
+; CHECK-NEXT:    [[VPACK5:%.*]] = insertelement <4 x i1> [[VPACK3]], i1 [[VPACK4]], i32 2
+; CHECK-NEXT:    [[VPACK6:%.*]] = extractelement <2 x i1> [[COND1]], i32 1
+; CHECK-NEXT:    [[VPACK7:%.*]] = insertelement <4 x i1> [[VPACK5]], i1 [[VPACK6]], i32 3
+; CHECK-NEXT:    [[VECL:%.*]] = load <4 x i8>, ptr [[PTR0]], align 2
+; CHECK-NEXT:    [[VEC:%.*]] = select <4 x i1> [[VPACK7]], <4 x i8> [[VECL]], <4 x i8> [[VECL]]
+; CHECK-NEXT:    store <4 x i8> [[VEC]], ptr [[PTR0]], align 2
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
+  %ld0 = load <2 x i8>, ptr %ptr0
+  %ld1 = load <2 x i8>, ptr %ptr1
+  %sel0 = select <2 x i1> %cond0, <2 x i8> %ld0, <2 x i8> %ld0
+  %sel1 = select <2 x i1> %cond1, <2 x i8> %ld1, <2 x i8> %ld1
+  store <2 x i8> %sel0, ptr %ptr0
+  store <2 x i8> %sel1, ptr %ptr1
+  ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 3c24214f0d87f2..15f8166b705fc2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -67,7 +67,7 @@ static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
 
 TEST_F(LegalityTest, LegalitySkipSchedule) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2, i1 %c0, i1 %c1) {
 entry:
   %gep0 = getelementptr float, ptr %ptr, i32 0
   %gep1 = getelementptr float, ptr %ptr, i32 1
@@ -93,6 +93,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   %trunc32to8 = trunc i32 %v2 to i8
   %cmpSLT = icmp slt i64 %v0, %v1
   %cmpSGT = icmp sgt i64 %v0, %v1
+  %sel0 = select i1 %c0, <2 x float> %vec2, <2 x float> %vec2
+  %sel1 = select i1 %c1, <2 x float> %vec2, <2 x float> %vec2
   ret void
 }
 )IR");
@@ -128,6 +130,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
+  auto *Sel0 = cast<sandboxir::SelectInst>(&*It++);
+  auto *Sel1 = cast<sandboxir::SelectInst>(&*It++);
 
   llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
@@ -241,6 +245,15 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::RepeatedInstrs);
   }
+  {
+    // For now we don't vectorize Selects whose condition has a different
+    // number of lanes than the select instruction itself.
+    const auto &Result =
+        Legality.canVectorize({Sel0, Sel1}, /*SkipScheduling=*/true);
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::Unimplemented);
+  }
 }
 
 TEST_F(LegalityTest, LegalitySchedule) {



More information about the llvm-commits mailing list