[clang] [clang][bytecode] Implement comparison operators for vector type (PR #107258)

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 5 06:08:06 PDT 2024


https://github.com/yronglin updated https://github.com/llvm/llvm-project/pull/107258

>From ae9fd2653ef3f5d6323c8c8e857fbf45f394f0ce Mon Sep 17 00:00:00 2001
From: yronglin <yronglin777 at gmail.com>
Date: Wed, 4 Sep 2024 23:31:16 +0800
Subject: [PATCH 1/3] [clang][bytecode] Implement comparison operators for
 vector type

Signed-off-by: yronglin <yronglin777 at gmail.com>
---
 clang/lib/AST/ByteCode/Compiler.cpp           | 113 ++++++++++
 clang/lib/AST/ByteCode/Compiler.h             |   3 +-
 clang/test/AST/ByteCode/constexpr-vectors.cpp | 212 +++++++++++++++++-
 3 files changed, 323 insertions(+), 5 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index a831f196abdcb5..c3fb812971ac52 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -687,6 +687,8 @@ bool Compiler<Emitter>::VisitParenExpr(const ParenExpr *E) {
 template <class Emitter>
 bool Compiler<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
   // Need short-circuiting for these.
+  if (BO->getType()->isVectorType())
+    return this->VisitVectorBinOp(BO);
   if (BO->isLogicalOp())
     return this->VisitLogicalBinOp(BO);
 
@@ -1222,6 +1224,117 @@ bool Compiler<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   return true;
 }
 
+template <class Emitter>
+bool Compiler<Emitter>::VisitVectorBinOp(const BinaryOperator *E) {
+  assert(E->getType()->isVectorType());
+
+  // FIXME: Current only support comparison binary operator, add support for
+  // other binary operator.
+  if (!E->isComparisonOp())
+    return this->emitInvalid(E);
+  // Prepare storage for result.
+  if (!Initializing) {
+    unsigned LocalIndex = allocateTemporary(E);
+    if (!this->emitGetPtrLocal(LocalIndex, E))
+      return false;
+  }
+
+  const Expr *LHS = E->getLHS();
+  const Expr *RHS = E->getRHS();
+  const auto *VecTy = E->getType()->getAs<VectorType>();
+
+  // The LHS and RHS of a comparison operator must have the same type. So we
+  // just use LHS vector element type here.
+  PrimType ElemT = this->classifyVectorElementType(LHS->getType());
+  PrimType ResultElemT = this->classifyVectorElementType(E->getType());
+
+  // Evaluate LHS and save value to LHSOffset.
+  unsigned LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
+  if (!this->visit(LHS))
+    return false;
+  if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
+    return false;
+
+  // Evaluate RHS and save value to RHSOffset.
+  unsigned RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
+  if (!this->visit(RHS))
+    return false;
+  if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
+    return false;
+
+  auto getElem = [=](unsigned Offset, unsigned Index, PrimType ElemT) -> bool {
+    if (!this->emitGetLocal(PT_Ptr, Offset, E))
+      return false;
+    return this->emitArrayElemPop(ElemT, Index, E);
+  };
+
+  for (unsigned I = 0; I != VecTy->getNumElements(); ++I) {
+    if (!getElem(LHSOffset, I, ElemT))
+      return false;
+    if (!getElem(RHSOffset, I, ElemT))
+      return false;
+    if (E->isComparisonOp()) {
+      if (!this->emitVectorComparison(E))
+        return false;
+    } else {
+      llvm_unreachable("Unsupported binary operator");
+    }
+    if (!this->emitInitElem(ResultElemT, I, E))
+      return false;
+  }
+  return true;
+}
+
+template <class Emitter>
+bool Compiler<Emitter>::emitVectorComparison(const BinaryOperator *E) {
+  assert(E->isComparisonOp());
+  assert(!DiscardResult);
+  const auto *VecTy = E->getType()->getAs<VectorType>();
+
+  // The LHS and RHS of a comparison operator must have the same type. So we
+  // just use LHS vector element type here.
+  PrimType LHSElemT = this->classifyVectorElementType(E->getLHS()->getType());
+  PrimType ResultElemT = this->classifyVectorElementType(E->getType());
+  switch (E->getOpcode()) {
+  case BO_EQ:
+    if (!this->emitEQ(LHSElemT, E))
+      return false;
+    break;
+  case BO_NE:
+    if (!this->emitNE(LHSElemT, E))
+      return false;
+    break;
+  case BO_LE:
+    if (!this->emitLE(LHSElemT, E))
+      return false;
+    break;
+  case BO_LT:
+    if (!this->emitLT(LHSElemT, E))
+      return false;
+    break;
+  case BO_GE:
+    if (!this->emitGE(LHSElemT, E))
+      return false;
+    break;
+  case BO_GT:
+    if (!this->emitGT(LHSElemT, E))
+      return false;
+    break;
+  default:
+    llvm_unreachable("Unsupported binary operator");
+  }
+
+  // The result of the comparison is a vector of the same width and number of
+  // elements as the comparison operands with a signed integral element type.
+  //
+  // https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
+  if (!this->emitPrimCast(PT_Bool, ResultElemT, VecTy->getElementType(), E))
+    return false;
+  if (!this->emitNeg(ResultElemT, E))
+    return false;
+  return true;
+}
+
 template <class Emitter>
 bool Compiler<Emitter>::VisitImplicitValueInitExpr(
     const ImplicitValueInitExpr *E) {
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index b18afacdb2e491..c9dacbc76d7c4f 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -130,6 +130,7 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
   bool VisitLogicalBinOp(const BinaryOperator *E);
   bool VisitPointerArithBinOp(const BinaryOperator *E);
   bool VisitComplexBinOp(const BinaryOperator *E);
+  bool VisitVectorBinOp(const BinaryOperator *E);
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E);
   bool VisitCallExpr(const CallExpr *E);
   bool VisitBuiltinCallExpr(const CallExpr *E);
@@ -363,7 +364,7 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
   bool emitComplexBoolCast(const Expr *E);
   bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
                              const BinaryOperator *E);
-
+  bool emitVectorComparison(const BinaryOperator *E);
   bool emitRecordDestruction(const Record *R);
   bool emitDestruction(const Descriptor *Desc);
   unsigned collectBaseOffset(const QualType BaseType,
diff --git a/clang/test/AST/ByteCode/constexpr-vectors.cpp b/clang/test/AST/ByteCode/constexpr-vectors.cpp
index a738cfe617a0e0..684c5810702cc3 100644
--- a/clang/test/AST/ByteCode/constexpr-vectors.cpp
+++ b/clang/test/AST/ByteCode/constexpr-vectors.cpp
@@ -15,8 +15,50 @@ using FourI128ExtVec __attribute__((ext_vector_type(4))) = __int128;
 // Only int vs float makes a difference here, so we only need to test 1 of each.
 // Test Char to make sure the mixed-nature of shifts around char is evident.
 void CharUsage() {
-  constexpr auto H = FourCharsVecSize{-1, -1, 0, -1};
-  constexpr auto InvH = -H;
+  constexpr auto w = FourCharsVecSize{1, 2, 3, 4} <
+                     FourCharsVecSize{4, 3, 2, 1};
+  static_assert(w[0] == -1 && w[1] == -1 && w[2] == 0 && w[3] == 0, "");
+
+  constexpr auto x = FourCharsVecSize{1, 2, 3, 4} >
+                     FourCharsVecSize{4, 3, 2, 1};
+  static_assert(x[0] == 0 && x[1] == 0 && x[2] == -1 && x[3] == -1, "");
+
+  constexpr auto y = FourCharsVecSize{1, 2, 3, 4} <=
+                     FourCharsVecSize{4, 3, 3, 1};
+  static_assert(y[0] == -1 && y[1] == -1 && y[2] == -1 && y[3] == 0, "");
+
+  constexpr auto z = FourCharsVecSize{1, 2, 3, 4} >=
+                     FourCharsVecSize{4, 3, 3, 1};
+  static_assert(z[0] == 0 && z[1] == 0 && z[2] == -1 && z[3] == -1, "");
+
+  constexpr auto A = FourCharsVecSize{1, 2, 3, 4} ==
+                     FourCharsVecSize{4, 3, 3, 1};
+  static_assert(A[0] == 0 && A[1] == 0 && A[2] == -1 && A[3] == 0, "");
+
+  constexpr auto B = FourCharsVecSize{1, 2, 3, 4} !=
+                     FourCharsVecSize{4, 3, 3, 1};
+  static_assert(B[0] == -1 && B[1] == -1 && B[2] == 0 && B[3] == -1, "");
+
+  constexpr auto C = FourCharsVecSize{1, 2, 3, 4} < 3;
+  static_assert(C[0] == -1 && C[1] == -1 && C[2] == 0 && C[3] == 0, "");
+
+  constexpr auto D = FourCharsVecSize{1, 2, 3, 4} > 3;
+  static_assert(D[0] == 0 && D[1] == 0 && D[2] == 0 && D[3] == -1, "");
+
+  constexpr auto E = FourCharsVecSize{1, 2, 3, 4} <= 3;
+  static_assert(E[0] == -1 && E[1] == -1 && E[2] == -1 && E[3] == 0, "");
+
+  constexpr auto F = FourCharsVecSize{1, 2, 3, 4} >= 3;
+  static_assert(F[0] == 0 && F[1] == 0 && F[2] == -1 && F[3] == -1, "");
+
+  constexpr auto G = FourCharsVecSize{1, 2, 3, 4} == 3;
+  static_assert(G[0] == 0 && G[1] == 0 && G[2] == -1 && G[3] == 0, "");
+
+  constexpr auto H = FourCharsVecSize{1, 2, 3, 4} != 3;
+  static_assert(H[0] == -1 && H[1] == -1 && H[2] == 0 && H[3] == -1, "");
+
+  constexpr auto H1 = FourCharsVecSize{-1, -1, 0, -1};
+  constexpr auto InvH = -H1;
   static_assert(InvH[0] == 1 && InvH[1] == 1 && InvH[2] == 0 && InvH[3] == 1, "");
 
   constexpr auto ae = ~FourCharsVecSize{1, 2, 10, 20};
@@ -27,8 +69,50 @@ void CharUsage() {
 }
 
 void CharExtVecUsage() {
-  constexpr auto H = FourCharsExtVec{-1, -1, 0, -1};
-  constexpr auto InvH = -H;
+  constexpr auto w = FourCharsExtVec{1, 2, 3, 4} <
+                     FourCharsExtVec{4, 3, 2, 1};
+  static_assert(w[0] == -1 && w[1] == -1 && w[2] == 0 && w[3] == 0, "");
+
+  constexpr auto x = FourCharsExtVec{1, 2, 3, 4} >
+                     FourCharsExtVec{4, 3, 2, 1};
+  static_assert(x[0] == 0 && x[1] == 0 && x[2] == -1 && x[3] == -1, "");
+
+  constexpr auto y = FourCharsExtVec{1, 2, 3, 4} <=
+                     FourCharsExtVec{4, 3, 3, 1};
+  static_assert(y[0] == -1 && y[1] == -1 && y[2] == -1 && y[3] == 0, "");
+
+  constexpr auto z = FourCharsExtVec{1, 2, 3, 4} >=
+                     FourCharsExtVec{4, 3, 3, 1};
+  static_assert(z[0] == 0 && z[1] == 0 && z[2] == -1 && z[3] == -1, "");
+
+  constexpr auto A = FourCharsExtVec{1, 2, 3, 4} ==
+                     FourCharsExtVec{4, 3, 3, 1};
+  static_assert(A[0] == 0 && A[1] == 0 && A[2] == -1 && A[3] == 0, "");
+
+  constexpr auto B = FourCharsExtVec{1, 2, 3, 4} !=
+                     FourCharsExtVec{4, 3, 3, 1};
+  static_assert(B[0] == -1 && B[1] == -1 && B[2] == 0 && B[3] == -1, "");
+
+  constexpr auto C = FourCharsExtVec{1, 2, 3, 4} < 3;
+  static_assert(C[0] == -1 && C[1] == -1 && C[2] == 0 && C[3] == 0, "");
+
+  constexpr auto D = FourCharsExtVec{1, 2, 3, 4} > 3;
+  static_assert(D[0] == 0 && D[1] == 0 && D[2] == 0 && D[3] == -1, "");
+
+  constexpr auto E = FourCharsExtVec{1, 2, 3, 4} <= 3;
+  static_assert(E[0] == -1 && E[1] == -1 && E[2] == -1 && E[3] == 0, "");
+
+  constexpr auto F = FourCharsExtVec{1, 2, 3, 4} >= 3;
+  static_assert(F[0] == 0 && F[1] == 0 && F[2] == -1 && F[3] == -1, "");
+
+  constexpr auto G = FourCharsExtVec{1, 2, 3, 4} == 3;
+  static_assert(G[0] == 0 && G[1] == 0 && G[2] == -1 && G[3] == 0, "");
+
+  constexpr auto H = FourCharsExtVec{1, 2, 3, 4} != 3;
+  static_assert(H[0] == -1 && H[1] == -1 && H[2] == 0 && H[3] == -1, "");
+
+  constexpr auto H1 = FourCharsExtVec{-1, -1, 0, -1};
+  constexpr auto InvH = -H1;
   static_assert(InvH[0] == 1 && InvH[1] == 1 && InvH[2] == 0 && InvH[3] == 1, "");
 
   constexpr auto ae = ~FourCharsExtVec{1, 2, 10, 20};
@@ -39,6 +123,48 @@ void CharExtVecUsage() {
 }
 
 void FloatUsage() {
+  constexpr auto w = FourFloatsVecSize{1, 2, 3, 4} <
+                     FourFloatsVecSize{4, 3, 2, 1};
+  static_assert(w[0] == -1 && w[1] == -1 && w[2] == 0 && w[3] == 0, "");
+
+  constexpr auto x = FourFloatsVecSize{1, 2, 3, 4} >
+                     FourFloatsVecSize{4, 3, 2, 1};
+  static_assert(x[0] == 0 && x[1] == 0 && x[2] == -1 && x[3] == -1, "");
+
+  constexpr auto y = FourFloatsVecSize{1, 2, 3, 4} <=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(y[0] == -1 && y[1] == -1 && y[2] == -1 && y[3] == 0, "");
+
+  constexpr auto z = FourFloatsVecSize{1, 2, 3, 4} >=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(z[0] == 0 && z[1] == 0 && z[2] == -1 && z[3] == -1, "");
+
+  constexpr auto A = FourFloatsVecSize{1, 2, 3, 4} ==
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(A[0] == 0 && A[1] == 0 && A[2] == -1 && A[3] == 0, "");
+
+  constexpr auto B = FourFloatsVecSize{1, 2, 3, 4} !=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(B[0] == -1 && B[1] == -1 && B[2] == 0 && B[3] == -1, "");
+
+  constexpr auto C = FourFloatsVecSize{1, 2, 3, 4} < 3;
+  static_assert(C[0] == -1 && C[1] == -1 && C[2] == 0 && C[3] == 0, "");
+
+  constexpr auto D = FourFloatsVecSize{1, 2, 3, 4} > 3;
+  static_assert(D[0] == 0 && D[1] == 0 && D[2] == 0 && D[3] == -1, "");
+
+  constexpr auto E = FourFloatsVecSize{1, 2, 3, 4} <= 3;
+  static_assert(E[0] == -1 && E[1] == -1 && E[2] == -1 && E[3] == 0, "");
+
+  constexpr auto F = FourFloatsVecSize{1, 2, 3, 4} >= 3;
+  static_assert(F[0] == 0 && F[1] == 0 && F[2] == -1 && F[3] == -1, "");
+
+  constexpr auto G = FourFloatsVecSize{1, 2, 3, 4} == 3;
+  static_assert(G[0] == 0 && G[1] == 0 && G[2] == -1 && G[3] == 0, "");
+
+  constexpr auto H = FourFloatsVecSize{1, 2, 3, 4} != 3;
+  static_assert(H[0] == -1 && H[1] == -1 && H[2] == 0 && H[3] == -1, "");
+
   constexpr auto Y = FourFloatsVecSize{1.200000e+01, 1.700000e+01, -1.000000e+00, -1.000000e+00};
   constexpr auto Z = -Y;
   static_assert(Z[0] == -1.200000e+01 && Z[1] == -1.700000e+01 && Z[2] == 1.000000e+00 && Z[3] == 1.000000e+00, "");
@@ -51,6 +177,48 @@ void FloatUsage() {
 }
 
 void FloatVecUsage() {
+  constexpr auto w = FourFloatsVecSize{1, 2, 3, 4} <
+                     FourFloatsVecSize{4, 3, 2, 1};
+  static_assert(w[0] == -1 && w[1] == -1 && w[2] == 0 && w[3] == 0, "");
+
+  constexpr auto x = FourFloatsVecSize{1, 2, 3, 4} >
+                     FourFloatsVecSize{4, 3, 2, 1};
+  static_assert(x[0] == 0 && x[1] == 0 && x[2] == -1 && x[2] == -1, "");
+
+  constexpr auto y = FourFloatsVecSize{1, 2, 3, 4} <=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(y[0] == -1 && y[1] == -1 && y[2] == -1 && y[3] == 0, "");
+
+  constexpr auto z = FourFloatsVecSize{1, 2, 3, 4} >=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(z[0] == 0 && z[1] == 0 && z[2] == -1 && z[3] == -1, "");
+
+  constexpr auto A = FourFloatsVecSize{1, 2, 3, 4} ==
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(A[0] == 0 && A[1] == 0 && A[2] == -1 && A[3] == 0, "");
+
+  constexpr auto B = FourFloatsVecSize{1, 2, 3, 4} !=
+                     FourFloatsVecSize{4, 3, 3, 1};
+  static_assert(B[0] == -1 && B[1] == -1 && B[2] == 0 && B[3] == -1, "");
+
+  constexpr auto C = FourFloatsVecSize{1, 2, 3, 4} < 3;
+  static_assert(C[0] == -1 && C[1] == -1 && C[2] == 0 && C[3] == 0, "");
+
+  constexpr auto D = FourFloatsVecSize{1, 2, 3, 4} > 3;
+  static_assert(D[0] == 0 && D[1] == 0 && D[2] == 0 && D[3] == -1, "");
+
+  constexpr auto E = FourFloatsVecSize{1, 2, 3, 4} <= 3;
+  static_assert(E[0] == -1 && E[1] == -1 && E[2] == -1 && E[3] == 0, "");
+
+  constexpr auto F = FourFloatsVecSize{1, 2, 3, 4} >= 3;
+  static_assert(F[0] == 0 && F[1] == 0 && F[2] == -1 && F[3] == -1, "");
+
+  constexpr auto G = FourFloatsVecSize{1, 2, 3, 4} == 3;
+  static_assert(G[0] == 0 && G[1] == 0 && G[2] == -1 && G[3] == 0, "");
+
+  constexpr auto H = FourFloatsVecSize{1, 2, 3, 4} != 3;
+  static_assert(H[0] == -1 && H[1] == -1 && H[2] == 0 && H[3] == -1, "");
+
   constexpr auto Y = FourFloatsVecSize{1.200000e+01, 1.700000e+01, -1.000000e+00, -1.000000e+00};
   constexpr auto Z = -Y;
   static_assert(Z[0] == -1.200000e+01 && Z[1] == -1.700000e+01 && Z[2] == 1.000000e+00 && Z[3] == 1.000000e+00, "");
@@ -63,6 +231,12 @@ void FloatVecUsage() {
 }
 
 void I128Usage() {
+  constexpr auto a = FourI128VecSize{1, 2, 3, 4};
+  static_assert(a[0] == 1 && a[1] == 2 && a[2] == 3 && a[3] == 4, "");
+
+  constexpr auto b = a < 3;
+  static_assert(b[0] == -1 && b[1] == -1 && b[2] == 0 && b[3] == 0, "");
+
   // Operator ~ is illegal on floats, so no test for that.
   constexpr auto c = ~FourI128VecSize{1, 2, 10, 20};
    static_assert(c[0] == -2 && c[1] == -3 && c[2] == -11 && c[3] == -21, "");
@@ -72,6 +246,12 @@ void I128Usage() {
 }
 
 void I128VecUsage() {
+  constexpr auto a = FourI128ExtVec{1, 2, 3, 4};
+  static_assert(a[0] == 1 && a[1] == 2 && a[2] == 3 && a[3] == 4, "");
+
+  constexpr auto b = a < 3;
+  static_assert(b[0] == -1 && b[1] == -1 && b[2] == 0 && b[3] == 0, "");
+
   // Operator ~ is illegal on floats, so no test for that.
   constexpr auto c = ~FourI128ExtVec{1, 2, 10, 20};
   static_assert(c[0] == -2 && c[1] == -3 && c[2] == -11 && c[3] == -21, "");
@@ -82,6 +262,30 @@ void I128VecUsage() {
 
 using FourBoolsExtVec __attribute__((ext_vector_type(4))) = bool;
 void BoolVecUsage() {
+  constexpr auto a = FourBoolsExtVec{true, false, true, false} <
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(a[0] == false && a[1] == false && a[2] == false && a[3] == true, "");
+
+  constexpr auto b = FourBoolsExtVec{true, false, true, false} <=
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(b[0] == false && b[1] == true && b[2] == true && b[3] == true, "");
+
+  constexpr auto c = FourBoolsExtVec{true, false, true, false} ==
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(c[0] == false && c[1] == true && c[2] == true && c[3] == false, "");
+
+  constexpr auto d = FourBoolsExtVec{true, false, true, false} !=
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(d[0] == true && d[1] == false && d[2] == false && d[3] == true, "");
+
+  constexpr auto e = FourBoolsExtVec{true, false, true, false} >=
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(e[0] == true && e[1] == true && e[2] == true && e[3] == false, "");
+
+  constexpr auto f = FourBoolsExtVec{true, false, true, false} >
+                     FourBoolsExtVec{false, false, true, true};
+  static_assert(f[0] == true && f[1] == false && f[2] == false && f[3] == false, "");
+
   constexpr auto j = !FourBoolsExtVec{true, false, true, false};
   static_assert(j[0] == false && j[1] == true && j[2] == false && j[3] == true, "");
 

>From e0148f41fa34fc6f1f7b31af07060e1e7df6c1cc Mon Sep 17 00:00:00 2001
From: yronglin <yronglin777 at gmail.com>
Date: Thu, 5 Sep 2024 21:03:38 +0800
Subject: [PATCH 2/3] Remove unnecessary emitVectorComparison

Signed-off-by: yronglin <yronglin777 at gmail.com>
---
 clang/lib/AST/ByteCode/Compiler.cpp | 90 +++++++++++++----------------
 clang/lib/AST/ByteCode/Compiler.h   |  1 -
 2 files changed, 39 insertions(+), 52 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index c3fb812971ac52..22134725c4dfad 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -1227,6 +1227,8 @@ bool Compiler<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
 template <class Emitter>
 bool Compiler<Emitter>::VisitVectorBinOp(const BinaryOperator *E) {
   assert(E->getType()->isVectorType());
+  assert(E->getLHS()->getType()->isVectorType());
+  assert(E->getRHS()->getType()->isVectorType());
 
   // FIXME: Current only support comparison binary operator, add support for
   // other binary operator.
@@ -1274,7 +1276,43 @@ bool Compiler<Emitter>::VisitVectorBinOp(const BinaryOperator *E) {
     if (!getElem(RHSOffset, I, ElemT))
       return false;
     if (E->isComparisonOp()) {
-      if (!this->emitVectorComparison(E))
+      switch (E->getOpcode()) {
+      case BO_EQ:
+        if (!this->emitEQ(ElemT, E))
+          return false;
+        break;
+      case BO_NE:
+        if (!this->emitNE(ElemT, E))
+          return false;
+        break;
+      case BO_LE:
+        if (!this->emitLE(ElemT, E))
+          return false;
+        break;
+      case BO_LT:
+        if (!this->emitLT(ElemT, E))
+          return false;
+        break;
+      case BO_GE:
+        if (!this->emitGE(ElemT, E))
+          return false;
+        break;
+      case BO_GT:
+        if (!this->emitGT(ElemT, E))
+          return false;
+        break;
+      default:
+        llvm_unreachable("Unsupported binary operator");
+      }
+
+      // The result of the comparison is a vector of the same width and number
+      // of elements as the comparison operands with a signed integral element
+      // type.
+      //
+      // https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
+      if (!this->emitPrimCast(PT_Bool, ResultElemT, VecTy->getElementType(), E))
+        return false;
+      if (!this->emitNeg(ResultElemT, E))
         return false;
     } else {
       llvm_unreachable("Unsupported binary operator");
@@ -1285,56 +1323,6 @@ bool Compiler<Emitter>::VisitVectorBinOp(const BinaryOperator *E) {
   return true;
 }
 
-template <class Emitter>
-bool Compiler<Emitter>::emitVectorComparison(const BinaryOperator *E) {
-  assert(E->isComparisonOp());
-  assert(!DiscardResult);
-  const auto *VecTy = E->getType()->getAs<VectorType>();
-
-  // The LHS and RHS of a comparison operator must have the same type. So we
-  // just use LHS vector element type here.
-  PrimType LHSElemT = this->classifyVectorElementType(E->getLHS()->getType());
-  PrimType ResultElemT = this->classifyVectorElementType(E->getType());
-  switch (E->getOpcode()) {
-  case BO_EQ:
-    if (!this->emitEQ(LHSElemT, E))
-      return false;
-    break;
-  case BO_NE:
-    if (!this->emitNE(LHSElemT, E))
-      return false;
-    break;
-  case BO_LE:
-    if (!this->emitLE(LHSElemT, E))
-      return false;
-    break;
-  case BO_LT:
-    if (!this->emitLT(LHSElemT, E))
-      return false;
-    break;
-  case BO_GE:
-    if (!this->emitGE(LHSElemT, E))
-      return false;
-    break;
-  case BO_GT:
-    if (!this->emitGT(LHSElemT, E))
-      return false;
-    break;
-  default:
-    llvm_unreachable("Unsupported binary operator");
-  }
-
-  // The result of the comparison is a vector of the same width and number of
-  // elements as the comparison operands with a signed integral element type.
-  //
-  // https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
-  if (!this->emitPrimCast(PT_Bool, ResultElemT, VecTy->getElementType(), E))
-    return false;
-  if (!this->emitNeg(ResultElemT, E))
-    return false;
-  return true;
-}
-
 template <class Emitter>
 bool Compiler<Emitter>::VisitImplicitValueInitExpr(
     const ImplicitValueInitExpr *E) {
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index c9dacbc76d7c4f..e6f54fe05427b7 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -364,7 +364,6 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
   bool emitComplexBoolCast(const Expr *E);
   bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
                              const BinaryOperator *E);
-  bool emitVectorComparison(const BinaryOperator *E);
   bool emitRecordDestruction(const Record *R);
   bool emitDestruction(const Descriptor *Desc);
   unsigned collectBaseOffset(const QualType BaseType,

>From 69ac7b1e83f5964a8326350852a856bec7fc795a Mon Sep 17 00:00:00 2001
From: yronglin <yronglin777 at gmail.com>
Date: Thu, 5 Sep 2024 21:07:30 +0800
Subject: [PATCH 3/3] Refactor: flatten comparison switch in VisitVectorBinOp

Signed-off-by: yronglin <yronglin777 at gmail.com>
---
 clang/lib/AST/ByteCode/Compiler.cpp | 72 ++++++++++++++---------------
 1 file changed, 36 insertions(+), 36 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 22134725c4dfad..8218f2a11d48c7 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -1275,48 +1275,48 @@ bool Compiler<Emitter>::VisitVectorBinOp(const BinaryOperator *E) {
       return false;
     if (!getElem(RHSOffset, I, ElemT))
       return false;
-    if (E->isComparisonOp()) {
-      switch (E->getOpcode()) {
-      case BO_EQ:
-        if (!this->emitEQ(ElemT, E))
-          return false;
-        break;
-      case BO_NE:
-        if (!this->emitNE(ElemT, E))
-          return false;
-        break;
-      case BO_LE:
-        if (!this->emitLE(ElemT, E))
-          return false;
-        break;
-      case BO_LT:
-        if (!this->emitLT(ElemT, E))
-          return false;
-        break;
-      case BO_GE:
-        if (!this->emitGE(ElemT, E))
-          return false;
-        break;
-      case BO_GT:
-        if (!this->emitGT(ElemT, E))
-          return false;
-        break;
-      default:
-        llvm_unreachable("Unsupported binary operator");
-      }
+    switch (E->getOpcode()) {
+    case BO_EQ:
+      if (!this->emitEQ(ElemT, E))
+        return false;
+      break;
+    case BO_NE:
+      if (!this->emitNE(ElemT, E))
+        return false;
+      break;
+    case BO_LE:
+      if (!this->emitLE(ElemT, E))
+        return false;
+      break;
+    case BO_LT:
+      if (!this->emitLT(ElemT, E))
+        return false;
+      break;
+    case BO_GE:
+      if (!this->emitGE(ElemT, E))
+        return false;
+      break;
+    case BO_GT:
+      if (!this->emitGT(ElemT, E))
+        return false;
+      break;
+    default:
+      llvm_unreachable("Unsupported binary operator");
+    }
 
-      // The result of the comparison is a vector of the same width and number
-      // of elements as the comparison operands with a signed integral element
-      // type.
-      //
-      // https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
+    // The result of the comparison is a vector of the same width and number
+    // of elements as the comparison operands with a signed integral element
+    // type.
+    //
+    // https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
+    if (E->isComparisonOp()) {
       if (!this->emitPrimCast(PT_Bool, ResultElemT, VecTy->getElementType(), E))
         return false;
       if (!this->emitNeg(ResultElemT, E))
         return false;
-    } else {
-      llvm_unreachable("Unsupported binary operator");
     }
+
+    // Initialize array element with the value we just computed.
     if (!this->emitInitElem(ResultElemT, I, E))
       return false;
   }



More information about the cfe-commits mailing list