[llvm] [SandboxVec][Legality] Check opcodes and types (PR #113741)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Mon Oct 28 09:20:37 PDT 2024
    
    
  
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/113741
>From aa96d30dda107c3df7dea35d501d26136650563a Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 17 Oct 2024 09:39:07 -0700
Subject: [PATCH] [SandboxVec][Legality] Check opcodes and types
---
 .../Vectorize/SandboxVectorizer/VecUtils.h    | 12 ++++--
 .../Vectorize/SandboxVectorizer/Legality.cpp  | 21 ++++++++++-
 .../SandboxVectorizer/CMakeLists.txt          |  1 +
 .../SandboxVectorizer/LegalityTest.cpp        | 27 +++++++++++++-
 .../SandboxVectorizer/VecUtilsTest.cpp        | 37 +++++++++++++++++++
 5 files changed, 93 insertions(+), 5 deletions(-)
 create mode 100644 llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 64f57edb38484e..9577e8ef7b37cb 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -12,7 +12,11 @@
 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 
-class Utils {
+#include "llvm/SandboxIR/Type.h"
+
+namespace llvm::sandboxir {
+
+class VecUtils {
 public:
   /// \Returns the number of elements in \p Ty. That is the number of lanes if a
   /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
@@ -25,6 +29,8 @@ class Utils {
   static Type *getElementType(Type *Ty) {
     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
   }
-}
+};
+
+} // namespace llvm::sandboxir
 
-#endif LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index e4546c2f98113e..346a2eb2266f10 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -11,6 +11,7 @@
 #include "llvm/SandboxIR/Utils.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 
 namespace llvm::sandboxir {
 
@@ -26,7 +27,25 @@ void LegalityResult::dump() const {
 std::optional<ResultReason>
 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
     ArrayRef<Value *> Bndl) {
-  // TODO: Unimplemented.
+  auto *I0 = cast<Instruction>(Bndl[0]);
+  auto Opcode = I0->getOpcode();
+  // If they have different opcodes, then we cannot form a vector (for now).
+  if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
+        return cast<Instruction>(V)->getOpcode() != Opcode;
+      }))
+    return ResultReason::DiffOpcodes;
+
+  // If not the same scalar type, Pack. This will accept scalars and vectors as
+  // long as the element type is the same.
+  Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
+  for (auto *V : drop_begin(Bndl)) {
+    Type *ElmTy = VecUtils::getElementType(Utils::getExpectedType(V));
+    if (ElmTy != ElmTy0)
+      return ResultReason::DiffTypes;
+  }
+
+  // TODO: Missing checks
+
   return std::nullopt;
 }
 
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index 24512cb0225e8e..df689767b77245 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -13,4 +13,5 @@ add_llvm_unittest(SandboxVectorizerTests
   LegalityTest.cpp
   SchedulerTest.cpp
   SeedCollectorTest.cpp	
+  VecUtilsTest.cpp
 )
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 56c6bf5f1ef1f5..51f445c8d1d010 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -29,13 +29,17 @@ struct LegalityTest : public testing::Test {
 
 TEST_F(LegalityTest, Legality) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg) {
   %gep0 = getelementptr float, ptr %ptr, i32 0
   %gep1 = getelementptr float, ptr %ptr, i32 1
+  %gep3 = getelementptr float, ptr %ptr, i32 3
   %ld0 = load float, ptr %gep0
   %ld1 = load float, ptr %gep0
   store float %ld0, ptr %gep0
   store float %ld1, ptr %gep1
+  store <2 x float> %vec2, ptr %gep1
+  store <3 x float> %vec3, ptr %gep3
+  store i8 %arg, ptr %gep1
   ret void
 }
 )IR");
@@ -46,10 +50,14 @@ define void @foo(ptr %ptr) {
   auto It = BB->begin();
   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
   [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
 
   sandboxir::LegalityAnalysis Legality;
   const auto &Result = Legality.canVectorize({St0, St1});
@@ -62,6 +70,23 @@ define void @foo(ptr %ptr) {
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::NotInstructions);
   }
+  {
+    // Check DiffOpcodes
+    const auto &Result = Legality.canVectorize({St0, Ld0});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffOpcodes);
+  }
+  {
+    // Check DiffTypes
+    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
+    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));
+
+    const auto &Result = Legality.canVectorize({St0, StI8});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffTypes);
+  }
 }
 
 #ifndef NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
new file mode 100644
index 00000000000000..e0b08284964392
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -0,0 +1,37 @@
+//===- VecUtilsTest.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Type.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct VecUtilsTest : public testing::Test {
+  LLVMContext C;
+};
+
+TEST_F(VecUtilsTest, GetNumElements) {
+  sandboxir::Context Ctx(C);
+  auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
+  EXPECT_EQ(sandboxir::VecUtils::getNumElements(ElemTy), 1);
+  auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
+  EXPECT_EQ(sandboxir::VecUtils::getNumElements(VTy), 2);
+  auto *VTy1 = sandboxir::FixedVectorType::get(ElemTy, 1);
+  EXPECT_EQ(sandboxir::VecUtils::getNumElements(VTy1), 1);
+}
+
+TEST_F(VecUtilsTest, GetElementType) {
+  sandboxir::Context Ctx(C);
+  auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
+  EXPECT_EQ(sandboxir::VecUtils::getElementType(ElemTy), ElemTy);
+  auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
+  EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy);
+}
    
    
More information about the llvm-commits
mailing list