[llvm] a03e16a - [Hexagon] Improve idioms for fixed-point vector multiplication

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 12 09:11:49 PST 2022


Author: Krzysztof Parzyszek
Date: 2022-11-12T08:46:27-08:00
New Revision: a03e16a16e4db9d297d86102ca49e4a7506cb803

URL: https://github.com/llvm/llvm-project/commit/a03e16a16e4db9d297d86102ca49e4a7506cb803
DIFF: https://github.com/llvm/llvm-project/commit/a03e16a16e4db9d297d86102ca49e4a7506cb803.diff

LOG: [Hexagon] Improve idioms for fixed-point vector multiplication

Make the code more general, restore the ability to generate V6_vmpyhvsrs.

Added: 
    

Modified: 
    llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
    llvm/test/CodeGen/Hexagon/autohvx/qmul-chop.ll
    llvm/test/CodeGen/Hexagon/autohvx/qmul.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
index 279ce3483ee68..59076ca004f9d 100644
--- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp
@@ -78,7 +78,7 @@ class HexagonVectorCombine {
   // element count.
   Type *getBoolTy(int ElemCount = 0) const;
   // Create a ConstantInt of type returned by getIntTy with the value Val.
-  ConstantInt *getConstInt(int Val) const;
+  ConstantInt *getConstInt(int Val, unsigned Width = 32) const;
   // Get the integer value of V, if it exists.
   std::optional<APInt> getIntValue(const Value *Val) const;
   // Is V a constant 0, or a vector of 0s?
@@ -103,6 +103,8 @@ class HexagonVectorCombine {
   Constant *getFullValue(Type *Ty) const;
   Constant *getConstSplat(Type *Ty, int Val) const;
 
+  Value *simplify(Value *Val) const;
+
   Value *insertb(IRBuilderBase &Builder, Value *Dest, Value *Src, int Start,
                  int Length, int Where) const;
   Value *vlalignb(IRBuilderBase &Builder, Value *Lo, Value *Hi,
@@ -341,32 +343,49 @@ class HvxIdioms {
   bool run();
 
 private:
-  struct FxpOp {
-    unsigned Opcode;
-    unsigned Frac; // Number of fraction bits
-    Value *X, *Y;
-    // If present, add 1 << RoundAt before shift:
-    std::optional<unsigned> RoundAt;
-  };
+  enum Signedness { Positive, Signed, Unsigned };
 
   // Value + sign
-  // This is to distinguish multiplications: s*s, s*u, u*s, u*u.
+  // This is to keep track of whether the value should be treated as signed
+  // or unsigned, or is known to be positive.
   struct SValue {
     Value *Val;
-    bool Signed;
+    Signedness Sgn;
   };
 
-  std::optional<FxpOp> matchFxpMul(Instruction &In) const;
-  Value *processFxpMul(Instruction &In, const FxpOp &Op) const;
+  struct FxpOp {
+    unsigned Opcode;
+    unsigned Frac; // Number of fraction bits
+    SValue X, Y;
+    // If present, add 1 << RoundAt before shift:
+    std::optional<unsigned> RoundAt;
+  };
 
-  Value *processFxpMulChopped(IRBuilderBase &Builder, Instruction &In,
-                              const FxpOp &Op) const;
-  Value *createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y,
-                      bool Rounding) const;
-  Value *createMulQ31(IRBuilderBase &Builder, Value *X, Value *Y,
-                      bool Rounding) const;
-  std::pair<Value *, Value *> createMul32(IRBuilderBase &Builder, SValue X,
-                                          SValue Y) const;
+  auto getNumSignificantBits(Value *V, Instruction *In) const
+      -> std::pair<unsigned, Signedness>;
+  auto canonSgn(SValue X, SValue Y) const -> std::pair<SValue, SValue>;
+
+  auto matchFxpMul(Instruction &In) const -> std::optional<FxpOp>;
+  auto processFxpMul(Instruction &In, const FxpOp &Op) const -> Value *;
+
+  auto processFxpMulChopped(IRBuilderBase &Builder, Instruction &In,
+                            const FxpOp &Op) const -> Value *;
+  auto createMulQ15(IRBuilderBase &Builder, SValue X, SValue Y,
+                    bool Rounding) const -> Value *;
+  auto createMulQ31(IRBuilderBase &Builder, SValue X, SValue Y,
+                    bool Rounding) const -> Value *;
+  // Return {Result, Carry}, where Carry is a vector predicate.
+  auto createAddCarry(IRBuilderBase &Builder, Value *X, Value *Y,
+                      Value *CarryIn = nullptr) const
+      -> std::pair<Value *, Value *>;
+  auto createMul16(IRBuilderBase &Builder, SValue X, SValue Y) const -> Value *;
+  auto createMul32(IRBuilderBase &Builder, SValue X, SValue Y) const
+      -> std::pair<Value *, Value *>;
+  auto createAddLong(IRBuilderBase &Builder, ArrayRef<Value *> WordX,
+                     ArrayRef<Value *> WordY) const -> SmallVector<Value *>;
+  auto createMulLong(IRBuilderBase &Builder, ArrayRef<Value *> WordX,
+                     Signedness SgnX, ArrayRef<Value *> WordY,
+                     Signedness SgnY) const -> SmallVector<Value *>;
 
   VectorType *HvxI32Ty;
   VectorType *HvxP32Ty;
@@ -377,6 +396,7 @@ class HvxIdioms {
 
 [[maybe_unused]] raw_ostream &operator<<(raw_ostream &OS,
                                          const HvxIdioms::FxpOp &Op) {
+  static const char *SgnNames[] = {"Positive", "Signed", "Unsigned"};
   OS << Instruction::getOpcodeName(Op.Opcode) << '.' << Op.Frac;
   if (Op.RoundAt.has_value()) {
     if (Op.Frac != 0 && Op.RoundAt.value() == Op.Frac - 1) {
@@ -385,7 +405,8 @@ class HvxIdioms {
       OS << " + 1<<" << Op.RoundAt.value();
     }
   }
-  OS << "\n  X:" << *Op.X << "\n  Y:" << *Op.Y;
+  OS << "\n  X:(" << SgnNames[Op.X.Sgn] << ") " << *Op.X.Val << "\n"
+     << "  Y:(" << SgnNames[Op.Y.Sgn] << ") " << *Op.Y.Val;
   return OS;
 }
 
@@ -1181,6 +1202,47 @@ auto AlignVectors::run() -> bool {
 
 // --- Begin HvxIdioms
 
+auto HvxIdioms::getNumSignificantBits(Value *V, Instruction *In) const
+    -> std::pair<unsigned, Signedness> {
+  unsigned Bits = HVC.getNumSignificantBits(V, In);
+  // The significant bits are calculated including the sign bit. This may
+  // add an extra bit for zero-extended values, e.g. (zext i32 to i64) may
+  // result in 33 significant bits. To avoid extra words, skip the extra
+  // sign bit, but keep information that the value is to be treated as
+  // unsigned.
+  KnownBits Known = HVC.getKnownBits(V, In);
+  Signedness Sign = Signed;
+  unsigned NumToTest = 0; // Number of bits used in test for unsignedness.
+  if (isPowerOf2_32(Bits))
+    NumToTest = Bits;
+  else if (Bits > 1 && isPowerOf2_32(Bits - 1))
+    NumToTest = Bits - 1;
+
+  if (NumToTest != 0 && Known.Zero.ashr(NumToTest).isAllOnes()) {
+    Sign = Unsigned;
+    Bits = NumToTest;
+  }
+
+  // If the top bit of the nearest power-of-2 is zero, this value is
+  // positive. It could be treated as either signed or unsigned.
+  if (unsigned Pow2 = PowerOf2Ceil(Bits); Pow2 != Bits) {
+    if (Known.Zero.ashr(Pow2 - 1).isAllOnes())
+      Sign = Positive;
+  }
+  return {Bits, Sign};
+}
+
+auto HvxIdioms::canonSgn(SValue X, SValue Y) const
+    -> std::pair<SValue, SValue> {
+  // Canonicalize the signedness of X and Y, so that the result is one of:
+  //   S, S
+  //   U/P, S
+  //   U/P, U/P
+  if (X.Sgn == Signed && Y.Sgn != Signed)
+    std::swap(X, Y);
+  return {X, Y};
+}
+
 // Match
 //   (X * Y) [>> N], or
 //   ((X * Y) + (1 << N-1)) >> N
@@ -1225,8 +1287,11 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
   }
 
   // Check if the rest is a multiplication.
-  if (match(Exp, m_Mul(m_Value(Op.X), m_Value(Op.Y)))) {
+  if (match(Exp, m_Mul(m_Value(Op.X.Val), m_Value(Op.Y.Val)))) {
     Op.Opcode = Instruction::Mul;
+    // FIXME: The information below is recomputed.
+    Op.X.Sgn = getNumSignificantBits(Op.X.Val, &In).second;
+    Op.Y.Sgn = getNumSignificantBits(Op.Y.Val, &In).second;
     return Op;
   }
 
@@ -1235,218 +1300,160 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
 
 auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const
     -> Value * {
-  assert(Op.X->getType() == Op.Y->getType());
+  assert(Op.X.Val->getType() == Op.Y.Val->getType());
 
-  auto *VecTy = cast<VectorType>(Op.X->getType());
+  auto *VecTy = dyn_cast<VectorType>(Op.X.Val->getType());
+  if (VecTy == nullptr)
+    return nullptr;
   auto *ElemTy = cast<IntegerType>(VecTy->getElementType());
   unsigned ElemWidth = ElemTy->getBitWidth();
-  if (ElemWidth < 8 || !isPowerOf2_32(ElemWidth))
-    return nullptr;
 
-  unsigned VecLen = HVC.length(VecTy);
-  unsigned HvxLen = (8 * HVC.HST.getVectorLength()) / std::min(ElemWidth, 32u);
-  if (VecLen % HvxLen != 0)
+  // TODO: This can be relaxed after legalization is done pre-isel.
+  if ((HVC.length(VecTy) * ElemWidth) % (8 * HVC.HST.getVectorLength()) != 0)
     return nullptr;
 
-  // FIXME: handle 8-bit multiplications
-  if (ElemWidth < 16)
+  // There are no special intrinsics that should be used for multiplying
+  // signed 8-bit values, so just skip them. Normal codegen should handle
+  // this just fine.
+  if (ElemWidth <= 8)
+    return nullptr;
+  // Similarly, if this is just a multiplication that can be handled without
+  // intervention, then leave it alone.
+  if (ElemWidth <= 32 && Op.Frac == 0)
     return nullptr;
 
-  SmallVector<Value *> Results;
-  FxpOp ChopOp;
-  ChopOp.Opcode = Op.Opcode;
-  ChopOp.Frac = Op.Frac;
-  ChopOp.RoundAt = Op.RoundAt;
+  auto [BitsX, SignX] = getNumSignificantBits(Op.X.Val, &In);
+  auto [BitsY, SignY] = getNumSignificantBits(Op.Y.Val, &In);
+
+  // TODO: Add multiplication of vectors by scalar registers (up to 4 bytes).
+
+  Value *X = Op.X.Val, *Y = Op.Y.Val;
+  IRBuilder Builder(In.getParent(), In.getIterator(),
+                    InstSimplifyFolder(HVC.DL));
 
-  IRBuilder<InstSimplifyFolder> Builder(In.getParent(), In.getIterator(),
-                                        InstSimplifyFolder(HVC.DL));
+  auto roundUpWidth = [](unsigned Width) -> unsigned {
+    if (Width <= 32 && !isPowerOf2_32(Width)) {
+      // If the element width is not a power of 2, round it up
+      // to the next one. Do this for widths not exceeding 32.
+      return PowerOf2Ceil(Width);
+    }
+    if (Width > 32 && Width % 32 != 0) {
+      // For wider elements, round it up to the multiple of 32.
+      return alignTo(Width, 32u);
+    }
+    return Width;
+  };
+
+  BitsX = roundUpWidth(BitsX);
+  BitsY = roundUpWidth(BitsY);
+
+  // For elementwise multiplication vectors must have the same lengths, so
+  // resize the elements of both inputs to the same width, the max of the
+  // calculated significant bits.
+  unsigned Width = std::max(BitsX, BitsY);
+
+  auto *ResizeTy = VectorType::get(HVC.getIntTy(Width), VecTy);
+  if (Width < ElemWidth) {
+    X = Builder.CreateTrunc(X, ResizeTy);
+    Y = Builder.CreateTrunc(Y, ResizeTy);
+  } else if (Width > ElemWidth) {
+    X = SignX == Signed ? Builder.CreateSExt(X, ResizeTy)
+                        : Builder.CreateZExt(X, ResizeTy);
+    Y = SignY == Signed ? Builder.CreateSExt(Y, ResizeTy)
+                        : Builder.CreateZExt(Y, ResizeTy);
+  };
 
-  for (unsigned V = 0; V != VecLen / HvxLen; ++V) {
-    ChopOp.X = HVC.subvector(Builder, Op.X, V * HvxLen, HvxLen);
-    ChopOp.Y = HVC.subvector(Builder, Op.Y, V * HvxLen, HvxLen);
+  assert(X->getType() == Y->getType() && X->getType() == ResizeTy);
+
+  unsigned VecLen = HVC.length(ResizeTy);
+  unsigned ChopLen = (8 * HVC.HST.getVectorLength()) / std::min(Width, 32u);
+
+  SmallVector<Value *> Results;
+  FxpOp ChopOp = Op;
+
+  for (unsigned V = 0; V != VecLen / ChopLen; ++V) {
+    ChopOp.X.Val = HVC.subvector(Builder, X, V * ChopLen, ChopLen);
+    ChopOp.Y.Val = HVC.subvector(Builder, Y, V * ChopLen, ChopLen);
     Results.push_back(processFxpMulChopped(Builder, In, ChopOp));
     if (Results.back() == nullptr)
       break;
   }
 
-  if (Results.back() == nullptr) {
-    // FIXME: clean up leftover instructions
+  if (Results.back() == nullptr)
     return nullptr;
-  }
 
-  return HVC.concat(Builder, Results);
+  Value *Cat = HVC.concat(Builder, Results);
+  Value *Ext = SignX == Signed || SignY == Signed
+                   ? Builder.CreateSExt(Cat, VecTy)
+                   : Builder.CreateZExt(Cat, VecTy);
+  return Ext;
 }
 
 auto HvxIdioms::processFxpMulChopped(IRBuilderBase &Builder, Instruction &In,
                                      const FxpOp &Op) const -> Value * {
-  // FIXME: make this more elegant
-  struct TempValues {
-    void insert(Value *V) { //
-      Values.push_back(V);
-    }
-    void insert(ArrayRef<Value *> Vs) {
-      Values.insert(Values.end(), Vs.begin(), Vs.end());
-    }
-    void clear() { //
-      Values.clear();
-    }
-    ~TempValues() {
-      for (Value *V : llvm::reverse(Values)) {
-        if (auto *In = dyn_cast<Instruction>(V))
-          In->eraseFromParent();
-      }
-    }
-    SmallVector<Value *> Values;
-  };
-  TempValues DeleteOnFailure;
-
-  // TODO: Make it general.
-  // if (Op.Frac != 15 && Op.Frac != 31)
-  //  return nullptr;
-
-  enum Signedness { Positive, Signed, Unsigned };
-  auto getNumSignificantBits =
-      [this, &In](Value *V) -> std::pair<unsigned, Signedness> {
-    unsigned Bits = HVC.getNumSignificantBits(V, &In);
-    // The significant bits are calculated including the sign bit. This may
-    // add an extra bit for zero-extended values, e.g. (zext i32 to i64) may
-    // result in 33 significant bits. To avoid extra words, skip the extra
-    // sign bit, but keep information that the value is to be treated as
-    // unsigned.
-    KnownBits Known = HVC.getKnownBits(V, &In);
-    Signedness Sign = Signed;
-    if (Bits > 1 && isPowerOf2_32(Bits - 1)) {
-      if (Known.Zero.ashr(Bits - 1).isAllOnes()) {
-        Sign = Unsigned;
-        Bits--;
-      }
-    }
-    // If the top bit of the nearest power-of-2 is zero, this value is
-    // positive. It could be treated as either signed or unsigned.
-    if (unsigned Pow2 = PowerOf2Ceil(Bits); Pow2 != Bits) {
-      if (Known.Zero.ashr(Pow2 - 1).isAllOnes())
-        Sign = Positive;
-    }
-    return {Bits, Sign};
-  };
-
-  auto *OrigTy = dyn_cast<VectorType>(Op.X->getType());
-  if (OrigTy == nullptr)
-    return nullptr;
-
-  auto [BitsX, SignX] = getNumSignificantBits(Op.X);
-  auto [BitsY, SignY] = getNumSignificantBits(Op.Y);
-  unsigned Width = PowerOf2Ceil(std::max(BitsX, BitsY));
+  assert(Op.X.Val->getType() == Op.Y.Val->getType());
+  auto *InpTy = cast<VectorType>(Op.X.Val->getType());
+  unsigned Width = InpTy->getScalarSizeInBits();
+  bool Rounding = Op.RoundAt.has_value();
 
   if (!Op.RoundAt || *Op.RoundAt == Op.Frac - 1) {
-    bool Rounding = Op.RoundAt.has_value();
     // The fixed-point intrinsics do signed multiplication.
-    if (Width == Op.Frac + 1 && SignX != Unsigned && SignY != Unsigned) {
-      auto *TruncTy = VectorType::get(HVC.getIntTy(Width), OrigTy);
-      Value *TruncX = Builder.CreateTrunc(Op.X, TruncTy);
-      Value *TruncY = Builder.CreateTrunc(Op.Y, TruncTy);
+    if (Width == Op.Frac + 1 && Op.X.Sgn != Unsigned && Op.Y.Sgn != Unsigned) {
       Value *QMul = nullptr;
       if (Width == 16) {
-        QMul = createMulQ15(Builder, TruncX, TruncY, Rounding);
+        QMul = createMulQ15(Builder, Op.X, Op.Y, Rounding);
       } else if (Width == 32) {
-        QMul = createMulQ31(Builder, TruncX, TruncY, Rounding);
+        QMul = createMulQ31(Builder, Op.X, Op.Y, Rounding);
       }
       if (QMul != nullptr)
-        return Builder.CreateSExt(QMul, OrigTy);
-
-      if (TruncX != Op.X && isa<Instruction>(TruncX))
-        cast<Instruction>(TruncX)->eraseFromParent();
-      if (TruncY != Op.Y && isa<Instruction>(TruncY))
-        cast<Instruction>(TruncY)->eraseFromParent();
+        return QMul;
     }
   }
 
-  // FIXME: make it general, _64, addcarry
-  if (!HVC.HST.useHVXV62Ops())
-    return nullptr;
-
-  // FIXME: make it general
-  if (OrigTy->getScalarSizeInBits() < 32)
-    return nullptr;
+  assert(Width >= 32 || isPowerOf2_32(Width)); // Width <= 32 => Width is 2^n
+  assert(Width < 32 || Width % 32 == 0);       // Width > 32 => Width is 32*k
+
+  // If Width < 32, then it should really be 16.
+  if (Width < 32) {
+    if (Width < 16)
+      return nullptr;
+    // Getting here with Op.Frac == 0 isn't wrong, but suboptimal: here we
+    // generate a full precision products, which is unnecessary if there is
+    // no shift.
+    assert(Op.Frac != 0 && "Unshifted mul should have been skipped");
+    Value *Prod32 = createMul16(Builder, Op.X, Op.Y);
+    if (Rounding) {
+      Value *RoundVal = HVC.getConstSplat(Prod32->getType(), 1 << *Op.RoundAt);
+      Prod32 = Builder.CreateAdd(Prod32, RoundVal);
+    }
 
-  if (Width > 64)
-    return nullptr;
+    Value *ShiftAmt = HVC.getConstSplat(Prod32->getType(), Op.Frac);
+    Value *Shifted = Op.X.Sgn == Signed || Op.Y.Sgn == Signed
+               ? Builder.CreateAShr(Prod32, ShiftAmt)
+               : Builder.CreateLShr(Prod32, ShiftAmt);
+    return Builder.CreateTrunc(Shifted, InpTy);
+  }
 
-  // At this point, NewX and NewY may be truncated to different element
-  // widths to save on the number of multiplications to perform.
-  unsigned WidthX =
-      PowerOf2Ceil(std::max(BitsX, 32u)); // FIXME: handle shorter ones
-  unsigned WidthY = PowerOf2Ceil(std::max(BitsY, 32u));
-  Value *NewX = Builder.CreateTrunc(
-      Op.X, VectorType::get(HVC.getIntTy(WidthX), HVC.length(Op.X), false));
-  Value *NewY = Builder.CreateTrunc(
-      Op.Y, VectorType::get(HVC.getIntTy(WidthY), HVC.length(Op.Y), false));
-  if (NewX != Op.X)
-    DeleteOnFailure.insert(NewX);
-  if (NewY != Op.Y)
-    DeleteOnFailure.insert(NewY);
-
-  // Break up the arguments NewX and NewY into vectors of smaller widths
-  // in preparation of doing the multiplication via HVX intrinsics.
-  // TODO:
-  // Make sure that the number of elements in NewX/NewY is 32. In the future
-  // add generic code that will break up a (presumable long) vector into
-  // shorter pieces, pad the last one, then concatenate all the pieces back.
-  if (HVC.length(NewX) != 32)
-    return nullptr;
-  auto WordX = HVC.splitVectorElements(Builder, NewX, /*ToWidth=*/32);
-  auto WordY = HVC.splitVectorElements(Builder, NewY, /*ToWidth=*/32);
-  auto HvxWordTy = WordX[0]->getType();
+  // Width >= 32
 
-  SmallVector<SmallVector<Value *>> Products(WordX.size() + WordY.size());
+  // Break up the arguments Op.X and Op.Y into vectors of smaller widths
+  // in preparation of doing the multiplication by 32-bit parts.
+  auto WordX = HVC.splitVectorElements(Builder, Op.X.Val, /*ToWidth=*/32);
+  auto WordY = HVC.splitVectorElements(Builder, Op.Y.Val, /*ToWidth=*/32);
+  auto WordP = createMulLong(Builder, WordX, Op.X.Sgn, WordY, Op.Y.Sgn);
 
-  // WordX[i] * WordY[j] produces words i+j and i+j+1 of the results,
-  // that is halves 2(i+j), 2(i+j)+1, 2(i+j)+2, 2(i+j)+3.
-  for (int i = 0, e = WordX.size(); i != e; ++i) {
-    for (int j = 0, f = WordY.size(); j != f; ++j) {
-      bool SgnX = (i + 1 == e) && SignX != Unsigned;
-      bool SgnY = (j + 1 == f) && SignY != Unsigned;
-      auto [Lo, Hi] = createMul32(Builder, {WordX[i], SgnX}, {WordY[j], SgnY});
-      Products[i + j + 0].push_back(Lo);
-      Products[i + j + 1].push_back(Hi);
-    }
-  }
+  auto *HvxWordTy = cast<VectorType>(WordP.front()->getType());
 
   // Add the optional rounding to the proper word.
   if (Op.RoundAt.has_value()) {
-    Products[*Op.RoundAt / 32].push_back(
-        HVC.getConstSplat(HvxWordTy, 1 << (*Op.RoundAt % 32)));
+    Value *Zero = HVC.getNullValue(WordX[0]->getType());
+    SmallVector<Value *> RoundV(WordP.size(), Zero);
+    RoundV[*Op.RoundAt / 32] =
+        HVC.getConstSplat(HvxWordTy, 1 << (*Op.RoundAt % 32));
+    WordP = createAddLong(Builder, WordP, RoundV);
   }
 
-  auto V6_vaddcarry = HVC.HST.getIntrinsicId(Hexagon::V6_vaddcarry);
-  Value *NoCarry = HVC.getNullValue(HVC.getBoolTy(HVC.length(HvxWordTy)));
-  auto pop_back_or_zero = [this, HvxWordTy](auto &Vector) -> Value * {
-    if (Vector.empty())
-      return HVC.getNullValue(HvxWordTy);
-    auto Last = Vector.back();
-    Vector.pop_back();
-    return Last;
-  };
-
-  for (int i = 0, e = Products.size(); i != e; ++i) {
-    while (Products[i].size() > 1) {
-      Value *Carry = NoCarry;
-      for (int j = i; j != e; ++j) {
-        auto &ProdJ = Products[j];
-        Value *Ret = HVC.createHvxIntrinsic(
-            Builder, V6_vaddcarry, nullptr,
-            {pop_back_or_zero(ProdJ), pop_back_or_zero(ProdJ), Carry});
-        ProdJ.insert(ProdJ.begin(), Builder.CreateExtractValue(Ret, {0}));
-        Carry = Builder.CreateExtractValue(Ret, {1});
-      }
-    }
-  }
-
-  SmallVector<Value *> WordP;
-  for (auto &P : Products) {
-    assert(P.size() == 1 && "Should have been added together");
-    WordP.push_back(P.front());
-  }
+  // createRightShiftLong?
 
   // Shift all products right by Op.Frac.
   unsigned SkipWords = Op.Frac / 32;
@@ -1467,79 +1474,125 @@ auto HvxIdioms::processFxpMulChopped(IRBuilderBase &Builder, Instruction &In,
   if (SkipWords != 0)
     WordP.resize(WordP.size() - SkipWords);
 
-  DeleteOnFailure.clear();
-  Value *Ret = HVC.joinVectorElements(Builder, WordP, OrigTy);
-  return Ret;
+  return HVC.joinVectorElements(Builder, WordP, InpTy);
 }
 
-auto HvxIdioms::createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y,
+auto HvxIdioms::createMulQ15(IRBuilderBase &Builder, SValue X, SValue Y,
                              bool Rounding) const -> Value * {
-  assert(X->getType() == Y->getType());
-  assert(X->getType()->getScalarType() == HVC.getIntTy(16));
-  if (!HVC.HST.isHVXVectorType(EVT::getEVT(X->getType(), false)))
+  assert(X.Val->getType() == Y.Val->getType());
+  assert(X.Val->getType()->getScalarType() == HVC.getIntTy(16));
+  assert(HVC.HST.isHVXVectorType(EVT::getEVT(X.Val->getType(), false)));
+
+  // There is no non-rounding intrinsic for i16.
+  if (!Rounding || X.Sgn == Unsigned || Y.Sgn == Unsigned)
     return nullptr;
 
-  unsigned HwLen = HVC.HST.getVectorLength();
+  auto V6_vmpyhvsrs = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhvsrs);
+  return HVC.createHvxIntrinsic(Builder, V6_vmpyhvsrs, X.Val->getType(),
+                                {X.Val, Y.Val});
+}
 
-  if (Rounding) {
-    auto V6_vmpyhvsrs = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhvsrs);
-    return HVC.createHvxIntrinsic(Builder, V6_vmpyhvsrs, X->getType(), {X, Y});
-  }
-  // No rounding, do i16*i16 -> i32, << 1, take upper half.
-  auto V6_vmpyhv = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhv);
+auto HvxIdioms::createMulQ31(IRBuilderBase &Builder, SValue X, SValue Y,
+                             bool Rounding) const -> Value * {
+  Type *InpTy = X.Val->getType();
+  assert(InpTy == Y.Val->getType());
+  assert(InpTy->getScalarType() == HVC.getIntTy(32));
+  assert(HVC.HST.isHVXVectorType(EVT::getEVT(InpTy, false)));
 
-  // i16*i16 -> i32 / interleaved
-  Value *V1 = HVC.createHvxIntrinsic(Builder, V6_vmpyhv, HvxP32Ty, {X, Y});
-  // <<1
-  Value *V2 = Builder.CreateAdd(V1, V1);
-  // i32 -> i32 deinterleave
-  SmallVector<int, 64> DeintMask;
-  for (int i = 0; i != static_cast<int>(HwLen) / 4; ++i) {
-    DeintMask.push_back(i);
-    DeintMask.push_back(i + HwLen / 4);
+  if (X.Sgn == Unsigned || Y.Sgn == Unsigned)
+    return nullptr;
+
+  auto V6_vmpyewuh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyewuh);
+  auto V6_vmpyo_acc = Rounding
+                          ? HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_rnd_sacc)
+                          : HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_sacc);
+  Value *V1 =
+      HVC.createHvxIntrinsic(Builder, V6_vmpyewuh, InpTy, {X.Val, Y.Val});
+  return HVC.createHvxIntrinsic(Builder, V6_vmpyo_acc, InpTy,
+                                {V1, X.Val, Y.Val});
+}
+
+auto HvxIdioms::createAddCarry(IRBuilderBase &Builder, Value *X, Value *Y,
+                               Value *CarryIn) const
+    -> std::pair<Value *, Value *> {
+  assert(X->getType() == Y->getType());
+  auto VecTy = cast<VectorType>(X->getType());
+  if (VecTy == HvxI32Ty && HVC.HST.useHVXV62Ops()) {
+    SmallVector<Value *> Args = {X, Y};
+    Intrinsic::ID AddCarry;
+    if (CarryIn == nullptr && HVC.HST.useHVXV66Ops()) {
+      AddCarry = HVC.HST.getIntrinsicId(Hexagon::V6_vaddcarryo);
+    } else {
+      AddCarry = HVC.HST.getIntrinsicId(Hexagon::V6_vaddcarry);
+      if (CarryIn == nullptr)
+        CarryIn = HVC.getNullValue(HVC.getBoolTy(HVC.length(VecTy)));
+      Args.push_back(CarryIn);
+    }
+    Value *Ret = HVC.createHvxIntrinsic(Builder, AddCarry,
+                                        /*RetTy=*/nullptr, Args);
+    Value *Result = Builder.CreateExtractValue(Ret, {0});
+    Value *CarryOut = Builder.CreateExtractValue(Ret, {1});
+    return {Result, CarryOut};
   }
 
-  Value *V3 =
-      HVC.vdeal(Builder, HVC.sublo(Builder, V2), HVC.subhi(Builder, V2));
-  // High halves: i32 -> i16
-  SmallVector<int, 64> HighMask;
-  for (int i = 0; i != static_cast<int>(HwLen) / 2; ++i) {
-    HighMask.push_back(2 * i + 1);
+  // In other cases, do a regular add, and unsigned compare-less-than.
+  // The carry-out can originate in two places: adding the carry-in or adding
+  // the two input values.
+  Value *Result1 = X; // Result1 = X + CarryIn
+  if (CarryIn != nullptr) {
+    unsigned Width = VecTy->getScalarSizeInBits();
+    uint32_t Mask = 1;
+    if (Width < 32) {
+      for (unsigned i = 0, e = 32 / Width; i != e; ++i)
+        Mask = (Mask << Width) | 1;
+    }
+    auto V6_vandqrt = HVC.HST.getIntrinsicId(Hexagon::V6_vandqrt);
+    Value *ValueIn =
+        HVC.createHvxIntrinsic(Builder, V6_vandqrt, /*RetTy=*/nullptr,
+                               {CarryIn, HVC.getConstInt(Mask)});
+    Result1 = Builder.CreateAdd(X, ValueIn);
   }
-  auto *HvxP16Ty = HVC.getHvxTy(HVC.getIntTy(16), /*Pair=*/true);
-  Value *V4 = Builder.CreateBitCast(V3, HvxP16Ty);
-  return Builder.CreateShuffleVector(V4, HighMask);
+
+  Value *CarryOut1 = Builder.CreateCmp(CmpInst::ICMP_ULT, Result1, X);
+  Value *Result2 = Builder.CreateAdd(Result1, Y);
+  Value *CarryOut2 = Builder.CreateCmp(CmpInst::ICMP_ULT, Result2, Y);
+  return {Result2, Builder.CreateOr(CarryOut1, CarryOut2)};
 }
 
-auto HvxIdioms::createMulQ31(IRBuilderBase &Builder, Value *X, Value *Y,
-                             bool Rounding) const -> Value * {
-  assert(X->getType() == Y->getType());
-  assert(X->getType()->getScalarType() == HVC.getIntTy(32));
-  if (!HVC.HST.isHVXVectorType(EVT::getEVT(X->getType(), false)))
-    return nullptr;
+auto HvxIdioms::createMul16(IRBuilderBase &Builder, SValue X, SValue Y) const
+    -> Value * {
+  Intrinsic::ID V6_vmpyh = 0;
+  std::tie(X, Y) = canonSgn(X, Y);
 
-  auto V6_vmpyewuh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyewuh);
-  auto MpyOddAcc = Rounding
-                       ? HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_rnd_sacc)
-                       : HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_sacc);
-  Value *V1 =
-      HVC.createHvxIntrinsic(Builder, V6_vmpyewuh, X->getType(), {X, Y});
-  return HVC.createHvxIntrinsic(Builder, MpyOddAcc, X->getType(), {V1, X, Y});
+  if (X.Sgn == Signed) {
+    V6_vmpyh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhv);
+  } else if (Y.Sgn == Signed) {
+    V6_vmpyh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyhus);
+  } else {
+    V6_vmpyh = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyuhv);
+  }
+
+  // i16*i16 -> i32 / interleaved
+  Value *P =
+      HVC.createHvxIntrinsic(Builder, V6_vmpyh, HvxP32Ty, {X.Val, Y.Val});
+  // Deinterleave
+  return HVC.vdeal(Builder, HVC.sublo(Builder, P), HVC.subhi(Builder, P));
 }
 
 auto HvxIdioms::createMul32(IRBuilderBase &Builder, SValue X, SValue Y) const
     -> std::pair<Value *, Value *> {
   assert(X.Val->getType() == Y.Val->getType());
-  assert(X.Val->getType() == HVC.getHvxTy(HVC.getIntTy(32), /*Pair=*/false));
+  assert(X.Val->getType() == HvxI32Ty);
 
   Intrinsic::ID V6_vmpy_parts;
-  if (X.Signed == Y.Signed) {
-    V6_vmpy_parts = X.Signed ? Intrinsic::hexagon_V6_vmpyss_parts
-                             : Intrinsic::hexagon_V6_vmpyuu_parts;
-  } else {
-    if (X.Signed)
-      std::swap(X, Y);
+  std::tie(X, Y) = canonSgn(X, Y);
+
+  if (X.Sgn == Signed) {
+    V6_vmpy_parts = Intrinsic::hexagon_V6_vmpyss_parts;
+  } else if (Y.Sgn == Signed) {
     V6_vmpy_parts = Intrinsic::hexagon_V6_vmpyus_parts;
+  } else {
+    V6_vmpy_parts = Intrinsic::hexagon_V6_vmpyuu_parts;
   }
 
   Value *Parts = HVC.createHvxIntrinsic(Builder, V6_vmpy_parts, nullptr,
@@ -1549,6 +1602,83 @@ auto HvxIdioms::createMul32(IRBuilderBase &Builder, SValue X, SValue Y) const
   return {Lo, Hi};
 }
 
+auto HvxIdioms::createAddLong(IRBuilderBase &Builder, ArrayRef<Value *> WordX,
+                              ArrayRef<Value *> WordY) const
+    -> SmallVector<Value *> {
+  assert(WordX.size() == WordY.size());
+  unsigned Idx = 0, Length = WordX.size();
+  SmallVector<Value *> Sum(Length);
+
+  while (Idx != Length) {
+    if (HVC.isZero(WordX[Idx]))
+      Sum[Idx] = WordY[Idx];
+    else if (HVC.isZero(WordY[Idx]))
+      Sum[Idx] = WordX[Idx];
+    else
+      break;
+    ++Idx;
+  }
+
+  Value *Carry = nullptr;
+  for (; Idx != Length; ++Idx) {
+    std::tie(Sum[Idx], Carry) =
+        createAddCarry(Builder, WordX[Idx], WordY[Idx], Carry);
+  }
+
+  // This drops the final carry beyond the highest word.
+  return Sum;
+}
+
+auto HvxIdioms::createMulLong(IRBuilderBase &Builder, ArrayRef<Value *> WordX,
+                              Signedness SgnX, ArrayRef<Value *> WordY,
+                              Signedness SgnY) const -> SmallVector<Value *> {
+  SmallVector<SmallVector<Value *>> Products(WordX.size() + WordY.size());
+
+  // WordX[i] * WordY[j] produces words i+j and i+j+1 of the results,
+  // that is halves 2(i+j), 2(i+j)+1, 2(i+j)+2, 2(i+j)+3.
+  for (int i = 0, e = WordX.size(); i != e; ++i) {
+    for (int j = 0, f = WordY.size(); j != f; ++j) {
+      // Check the 4 halves that this multiplication can generate.
+      Signedness SX = (i + 1 == e) ? SgnX : Unsigned;
+      Signedness SY = (j + 1 == f) ? SgnY : Unsigned;
+      auto [Lo, Hi] = createMul32(Builder, {WordX[i], SX}, {WordY[j], SY});
+      Products[i + j + 0].push_back(Lo);
+      Products[i + j + 1].push_back(Hi);
+    }
+  }
+
+  Value *Zero = HVC.getNullValue(WordX[0]->getType());
+
+  auto pop_back_or_zero = [Zero](auto &Vector) -> Value * {
+    if (Vector.empty())
+      return Zero;
+    auto Last = Vector.back();
+    Vector.pop_back();
+    return Last;
+  };
+
+  for (int i = 0, e = Products.size(); i != e; ++i) {
+    while (Products[i].size() > 1) {
+      Value *Carry = nullptr; // no carry-in
+      for (int j = i; j != e; ++j) {
+        auto &ProdJ = Products[j];
+        auto [Sum, CarryOut] = createAddCarry(Builder, pop_back_or_zero(ProdJ),
+                                              pop_back_or_zero(ProdJ), Carry);
+        ProdJ.insert(ProdJ.begin(), Sum);
+        Carry = CarryOut;
+      }
+    }
+  }
+
+  SmallVector<Value *> WordP;
+  for (auto &P : Products) {
+    assert(P.size() == 1 && "Should have been added together");
+    WordP.push_back(P.front());
+  }
+
+  return WordP;
+}
+
 auto HvxIdioms::run() -> bool {
   bool Changed = false;
 
@@ -1606,8 +1736,9 @@ auto HexagonVectorCombine::getBoolTy(int ElemCount) const -> Type * {
   return VectorType::get(BoolTy, ElemCount, /*Scalable=*/false);
 }
 
-auto HexagonVectorCombine::getConstInt(int Val) const -> ConstantInt * {
-  return ConstantInt::getSigned(getIntTy(), Val);
+auto HexagonVectorCombine::getConstInt(int Val, unsigned Width) const
+    -> ConstantInt * {
+  return ConstantInt::getSigned(getIntTy(Width), Val);
 }
 
 auto HexagonVectorCombine::isZero(const Value *Val) const -> bool {
@@ -1702,6 +1833,14 @@ auto HexagonVectorCombine::getConstSplat(Type *Ty, int Val) const
   return Splat;
 }
 
+auto HexagonVectorCombine::simplify(Value *Val) const -> Value * {
+  // Run InstSimplify on Val using this object's DL/TLI/DT/AC. Returns the
+  // simpler value, or nullptr if Val is not an instruction or nothing folds.
+  if (auto *In = dyn_cast<Instruction>(Val))
+    return simplifyInstruction(In, SimplifyQuery(DL, &TLI, &DT, &AC, In));
+  return nullptr;
+}
+
 // Insert bytes [Start..Start+Length) of Src into Dst at byte Where.
 auto HexagonVectorCombine::insertb(IRBuilderBase &Builder, Value *Dst,
                                    Value *Src, int Start, int Length,
@@ -2128,12 +2267,9 @@ auto HexagonVectorCombine::calculatePointerDifference(Value *Ptr0,
     return V;                                                                  \
   }(B)
 
-  auto Simplify = [&](Value *V) {
-    if (auto *I = dyn_cast<Instruction>(V)) {
-      SimplifyQuery Q(DL, &TLI, &DT, &AC, I);
-      if (Value *S = simplifyInstruction(I, Q))
-        return S;
-    }
+  auto Simplify = [this](Value *V) {
+    if (Value *S = simplify(V))
+      return S;
     return V;
   };
 

diff  --git a/llvm/test/CodeGen/Hexagon/autohvx/qmul-chop.ll b/llvm/test/CodeGen/Hexagon/autohvx/qmul-chop.ll
index 25318a11bef11..4ee22ae907936 100644
--- a/llvm/test/CodeGen/Hexagon/autohvx/qmul-chop.ll
+++ b/llvm/test/CodeGen/Hexagon/autohvx/qmul-chop.ll
@@ -39,5 +39,5 @@ declare <128 x i32> @llvm.smin.v128i32(<128 x i32>, <128 x i32>) #1
 ; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
 declare <128 x i32> @llvm.smax.v128i32(<128 x i32>, <128 x i32>) #1
 
-attributes #0 = { "target-features"="+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp" }
+attributes #0 = { "target-features"="+v68,+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp" }
 attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }

diff  --git a/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll b/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll
index 98f0f0eae622b..d3a4e88ab9de3 100644
--- a/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll
+++ b/llvm/test/CodeGen/Hexagon/autohvx/qmul.ll
@@ -72,37 +72,28 @@ define void @f2(ptr %a0, ptr %a1, ptr %a2) #0 {
 ; CHECK-LABEL: f2:
 ; CHECK:       // %bb.0: // %b0
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v0 = vmem(r1+#0)
+; CHECK-NEXT:     v0 = vmem(r0+#0)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v1:0.w = vunpack(v0.h)
+; CHECK-NEXT:     r7 = #-4
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
 ; CHECK-NEXT:     r3 = #15
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v2 = vmem(r0+#0)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v3:2.w = vunpack(v2.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v4.w = vmpyieo(v2.h,v0.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v5.w = vmpyieo(v3.h,v1.h)
+; CHECK-NEXT:     v1 = vmem(r1+#0)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v4.w += vmpyie(v2.w,v0.uh)
+; CHECK-NEXT:     v1:0.w = vmpy(v0.h,v1.h)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v5.w += vmpyie(v3.w,v1.uh)
+; CHECK-NEXT:     v1:0 = vshuff(v1,v0,r7)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v0.uw = vlsr(v4.uw,r3)
+; CHECK-NEXT:     v0.uw = vlsr(v0.uw,r3)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v1.uw = vlsr(v5.uw,r3)
+; CHECK-NEXT:     v1.uw = vlsr(v1.uw,r3)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
 ; CHECK-NEXT:     v0.h = vpacke(v1.w,v0.w)
@@ -129,58 +120,13 @@ define void @f3(ptr %a0, ptr %a1, ptr %a2) #0 {
 ; CHECK-LABEL: f3:
 ; CHECK:       // %bb.0: // %b0
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v0 = vmem(r1+#0)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v1:0.w = vunpack(v0.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     r4 = #16384
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     r3 = #15
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v2 = vmem(r0+#0)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v3:2.w = vunpack(v2.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     q0 = vcmp.gt(v0.w,v0.w)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     q1 = and(q0,q0)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v4 = vsplat(r4)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v5.w = vmpyieo(v2.h,v0.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v6.w = vmpyieo(v3.h,v1.h)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v5.w += vmpyie(v2.w,v0.uh)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v6.w += vmpyie(v3.w,v1.uh)
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v0.w = vadd(v4.w,v5.w,q1):carry
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v1.w = vadd(v4.w,v6.w,q0):carry
-; CHECK-NEXT:    }
-; CHECK-NEXT:    {
-; CHECK-NEXT:     v0.uw = vlsr(v0.uw,r3)
+; CHECK-NEXT:     v0 = vmem(r0+#0)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v1.uw = vlsr(v1.uw,r3)
+; CHECK-NEXT:     v1 = vmem(r1+#0)
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
-; CHECK-NEXT:     v0.h = vpacke(v1.w,v0.w)
+; CHECK-NEXT:     v0.h = vmpy(v0.h,v1.h):<<1:rnd:sat
 ; CHECK-NEXT:    }
 ; CHECK-NEXT:    {
 ; CHECK-NEXT:     vmem(r2+#0) = v0


        


More information about the llvm-commits mailing list