[llvm] Add ShuffleVectorInst. (PR #104891)

Jorge Gorbe Moya via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 18:52:24 PDT 2024


https://github.com/slackito created https://github.com/llvm/llvm-project/pull/104891

None

>From 5f8b0b56a303ecf31d5fc857e11cafe26875516c Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Mon, 12 Aug 2024 15:35:17 -0700
Subject: [PATCH] Add ShuffleVectorInst.

---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 284 ++++++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |  75 ++--
 llvm/lib/SandboxIR/SandboxIR.cpp              |  74 ++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 355 ++++++++++++++++++
 4 files changed, 751 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a881bdf28f22c2..4b2bc9e3961b1a 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -114,6 +114,7 @@ class Instruction;
 class SelectInst;
 class ExtractElementInst;
 class InsertElementInst;
+class ShuffleVectorInst;
 class BranchInst;
 class UnaryInstruction;
 class LoadInst;
@@ -245,6 +246,7 @@ class Value {
   friend class SelectInst;         // For getting `Val`.
   friend class ExtractElementInst; // For getting `Val`.
   friend class InsertElementInst;  // For getting `Val`.
+  friend class ShuffleVectorInst;  // For getting `Val`.
   friend class BranchInst;         // For getting `Val`.
   friend class LoadInst;           // For getting `Val`.
   friend class StoreInst;          // For getting `Val`.
@@ -666,6 +668,7 @@ class Instruction : public sandboxir::User {
   friend class SelectInst;         // For getTopmostLLVMInstruction().
   friend class ExtractElementInst; // For getTopmostLLVMInstruction().
   friend class InsertElementInst;  // For getTopmostLLVMInstruction().
+  friend class ShuffleVectorInst;  // For getTopmostLLVMInstruction().
   friend class BranchInst;         // For getTopmostLLVMInstruction().
   friend class LoadInst;           // For getTopmostLLVMInstruction().
   friend class StoreInst;          // For getTopmostLLVMInstruction().
@@ -945,6 +948,285 @@ class ExtractElementInst final
   }
 };
 
+class ShuffleVectorInst final
+  : public SingleLLVMInstructionImpl<llvm::ShuffleVectorInst> {
+  /// Use Context::createShuffleVectorInst() instead of this constructor.
+  ShuffleVectorInst(llvm::Instruction *I, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::ShuffleVector, Opcode::ShuffleVector,
+                                  I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*()
+
+public:
+  static Value *create(Value *V1, Value *V2, Value *Mask,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Value *V1, Value *V2, Value *Mask,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Value *V1, Value *V2, ArrayRef<int> Mask,
+                       Instruction *InsertBefore, Context &Ctx,
+                       const Twine &Name = "");
+  static Value *create(Value *V1, Value *V2, ArrayRef<int> Mask,
+                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       const Twine &Name = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ShuffleVector;
+  }
+
+  static bool isValidOperands(const Value *V1, const Value *V2,
+                              const Value *Mask) {
+    return llvm::ShuffleVectorInst::isValidOperands(V1->Val, V2->Val,
+                                                    Mask->Val);
+  }
+
+  static bool isValidOperands(const Value *V1, const Value *V2,
+                              ArrayRef<int> Mask) {
+    return llvm::ShuffleVectorInst::isValidOperands(V1->Val, V2->Val,
+                                                    Mask);
+  }
+
+  void commute() {
+    cast<llvm::ShuffleVectorInst>(Val)->commute();
+  }
+
+  VectorType *getType() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->getType();
+  }
+
+  int getMaskValue(unsigned Elt) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->getMaskValue(Elt);
+  }
+
+  static void getShuffleMask(const Constant *Mask,
+                             SmallVectorImpl<int> &Result) {
+    llvm::ShuffleVectorInst::getShuffleMask(cast<llvm::Constant>(Mask->Val),
+                                            Result);
+  }
+
+  void getShuffleMask(SmallVectorImpl<int> &Result) const {
+    cast<llvm::ShuffleVectorInst>(Val)->getShuffleMask(Result);
+  }
+
+  Constant *getShuffleMaskForBitcode() const;
+
+  static Constant *convertShuffleMaskForBitcode(ArrayRef<int> Mask,
+                                                Type *ResultTy, Context &Ctx);
+
+  void setShuffleMask(ArrayRef<int> Mask) {
+    cast<llvm::ShuffleVectorInst>(Val)->setShuffleMask(Mask);
+  }
+
+  ArrayRef<int> getShuffleMask() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->getShuffleMask();
+  }
+
+  bool changesLength() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->changesLength();
+  }
+
+  bool increasesLength() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->increasesLength();
+  }
+
+  static bool isSingleSourceMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isSingleSourceMask(Mask, NumSrcElts);
+  }
+
+  static bool isSingleSourceMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isSingleSourceMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isSingleSource() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isSingleSource();
+  }
+
+  static bool isIdentityMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts);
+  }
+
+  static bool isIdentityMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isIdentityMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isIdentity() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isIdentity();
+  }
+
+  bool isIdentityWithPadding() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isIdentityWithPadding();
+  }
+
+  bool isIdentityWithExtract() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isIdentityWithExtract();
+  }
+
+  bool isConcat() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isConcat();
+  }
+
+  static bool isSelectMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isSelectMask(Mask, NumSrcElts);
+  }
+
+  static bool isSelectMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isSelectMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isSelect() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isSelect();
+  }
+
+  static bool isReverseMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isReverseMask(Mask, NumSrcElts);
+  }
+
+  static bool isReverseMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isReverseMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isReverse() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isReverse();
+  }
+
+  static bool isZeroEltSplatMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isZeroEltSplatMask(Mask, NumSrcElts);
+  }
+
+  static bool isZeroEltSplatMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isZeroEltSplatMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isZeroEltSplat() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isZeroEltSplat();
+  }
+
+  static bool isTransposeMask(ArrayRef<int> Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isTransposeMask(Mask, NumSrcElts);
+  }
+
+  static bool isTransposeMask(const Constant *Mask, int NumSrcElts) {
+    return llvm::ShuffleVectorInst::isTransposeMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts);
+  }
+
+  bool isTranspose() const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isTranspose();
+  }
+
+  static bool isSpliceMask(ArrayRef<int> Mask, int NumSrcElts, int &Index) {
+    return llvm::ShuffleVectorInst::isSpliceMask(Mask, NumSrcElts, Index);
+  }
+
+  static bool isSpliceMask(const Constant *Mask, int NumSrcElts, int &Index) {
+    return llvm::ShuffleVectorInst::isSpliceMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts, Index);
+  }
+
+  bool isSplice(int &Index) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isSplice(Index);
+  }
+
+  static bool isExtractSubvectorMask(ArrayRef<int> Mask, int NumSrcElts,
+                                     int &Index) {
+    return llvm::ShuffleVectorInst::isExtractSubvectorMask(Mask, NumSrcElts,
+                                                           Index);
+  }
+
+  static bool isExtractSubvectorMask(const Constant *Mask, int NumSrcElts,
+                                     int &Index) {
+    return llvm::ShuffleVectorInst::isExtractSubvectorMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts, Index);
+  }
+
+  bool isExtractSubvectorMask(int &Index) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isExtractSubvectorMask(Index);
+  }
+
+  static bool isInsertSubvectorMask(ArrayRef<int> Mask, int NumSrcElts,
+                                    int &NumSubElts, int &Index) {
+    return llvm::ShuffleVectorInst::isInsertSubvectorMask(Mask, NumSrcElts,
+                                                          NumSubElts, Index);
+  }
+
+  static bool isInsertSubvectorMask(const Constant *Mask, int NumSrcElts,
+                                    int &NumSubElts, int &Index) {
+    return llvm::ShuffleVectorInst::isInsertSubvectorMask(
+        cast<llvm::Constant>(Mask->Val), NumSrcElts, NumSubElts, Index);
+  }
+
+  bool isInsertSubvectorMask(int &NumSubElts, int &Index) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isInsertSubvectorMask(NumSubElts,
+                                                                     Index);
+  }
+
+  static bool isReplicationMask(ArrayRef<int> Mask, int &ReplicationFactor,
+                                int &VF) {
+    return llvm::ShuffleVectorInst::isReplicationMask(Mask, ReplicationFactor,
+                                                      VF);
+  }
+  static bool isReplicationMask(const Constant *Mask, int &ReplicationFactor,
+                                int &VF) {
+    return llvm::ShuffleVectorInst::isReplicationMask(cast<llvm::Constant>(Mask->Val), ReplicationFactor, VF);
+  }
+
+  /// Return true if this shuffle mask is a replication mask.
+  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isReplicationMask(
+        ReplicationFactor, VF);
+  }
+
+  static bool isOneUseSingleSourceMask(ArrayRef<int> Mask, int VF) {
+    return llvm::ShuffleVectorInst::isOneUseSingleSourceMask(Mask, VF);
+  }
+
+  bool isOneUseSingleSourceMask(int VF) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isOneUseSingleSourceMask(VF);
+  }
+
+  static void commuteShuffleMask(MutableArrayRef<int> Mask,
+                                 unsigned InVecNumElts) {
+    llvm::ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts);
+  }
+
+  static bool isInterleaveMask(ArrayRef<int> Mask, unsigned Factor,
+                               unsigned NumInputElts,
+                               SmallVectorImpl<unsigned> &StartIndexes) {
+    return llvm::ShuffleVectorInst::isInterleaveMask(Mask, Factor, NumInputElts,
+                                                     StartIndexes);
+  }
+
+  static bool isInterleaveMask(ArrayRef<int> Mask, unsigned Factor,
+                               unsigned NumInputElts) {
+    return llvm::ShuffleVectorInst::isInterleaveMask(Mask, Factor,
+                                                     NumInputElts);
+  }
+
+  bool isInterleave(unsigned Factor) const {
+    return cast<llvm::ShuffleVectorInst>(Val)->isInterleave(Factor);
+  }
+
+  static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
+                                         unsigned &Index) {
+    return llvm::ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, Factor,
+                                                               Index);
+  }
+  static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor) {
+    return llvm::ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, Factor);
+  }
+
+  static bool isBitRotateMask(ArrayRef<int> Mask, unsigned EltSizeInBits,
+                              unsigned MinSubElts, unsigned MaxSubElts,
+                              unsigned &NumSubElts, unsigned &RotateAmt) {
+    return llvm::ShuffleVectorInst::isBitRotateMask(
+        Mask, EltSizeInBits, MinSubElts, MaxSubElts, NumSubElts, RotateAmt);
+  }
+};
+
 class BranchInst : public SingleLLVMInstructionImpl<llvm::BranchInst> {
   /// Use Context::createBranchInst(). Don't call the constructor directly.
   BranchInst(llvm::BranchInst *BI, Context &Ctx)
@@ -2185,6 +2467,8 @@ class Context {
   friend InsertElementInst; // For createInsertElementInst()
   ExtractElementInst *createExtractElementInst(llvm::ExtractElementInst *EEI);
   friend ExtractElementInst; // For createExtractElementInst()
+  ShuffleVectorInst *createShuffleVectorInst(llvm::ShuffleVectorInst *SVI);
+  friend ShuffleVectorInst; // For createShuffleVectorInst()
   BranchInst *createBranchInst(llvm::BranchInst *I);
   friend BranchInst; // For createBranchInst()
   LoadInst *createLoadInst(llvm::LoadInst *LI);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 2b9b44c529b30d..359f980de3cbd6 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -33,45 +33,46 @@ DEF_USER(ConstantInt, ConstantInt)
 #define OPCODES(...)
 #endif
 // clang-format off
-//       ClassID,        Opcode(s),         Class
-DEF_INSTR(Opaque,        OP(Opaque),        OpaqueInst)
+//        ClassID,        Opcode(s),         Class
+DEF_INSTR(Opaque,         OP(Opaque),        OpaqueInst)
 DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
-DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
-DEF_INSTR(Select,        OP(Select),        SelectInst)
-DEF_INSTR(Br,            OP(Br),            BranchInst)
-DEF_INSTR(Load,          OP(Load),          LoadInst)
-DEF_INSTR(Store,         OP(Store),         StoreInst)
-DEF_INSTR(Ret,           OP(Ret),           ReturnInst)
-DEF_INSTR(Call,          OP(Call),          CallInst)
-DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
-DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
-DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(Switch,        OP(Switch),        SwitchInst)
-DEF_INSTR(UnOp,          OPCODES( \
-                         OP(FNeg) \
-                         ),                 UnaryOperator)
+DEF_INSTR(InsertElement,  OP(InsertElement), InsertElementInst)
+DEF_INSTR(ShuffleVector,  OP(ShuffleVector),  ShuffleVectorInst)
+DEF_INSTR(Select,         OP(Select),        SelectInst)
+DEF_INSTR(Br,             OP(Br),            BranchInst)
+DEF_INSTR(Load,           OP(Load),          LoadInst)
+DEF_INSTR(Store,          OP(Store),         StoreInst)
+DEF_INSTR(Ret,            OP(Ret),           ReturnInst)
+DEF_INSTR(Call,           OP(Call),          CallInst)
+DEF_INSTR(Invoke,         OP(Invoke),        InvokeInst)
+DEF_INSTR(CallBr,         OP(CallBr),        CallBrInst)
+DEF_INSTR(GetElementPtr,  OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(Switch,         OP(Switch),        SwitchInst)
+DEF_INSTR(UnOp,           OPCODES( \
+                          OP(FNeg) \
+                          ),                 UnaryOperator)
 DEF_INSTR(BinaryOperator, OPCODES(\
-                         OP(Add)  \
-                         OP(FAdd) \
-                         OP(Sub)  \
-                         OP(FSub) \
-                         OP(Mul)  \
-                         OP(FMul) \
-                         OP(UDiv) \
-                         OP(SDiv) \
-                         OP(FDiv) \
-                         OP(URem) \
-                         OP(SRem) \
-                         OP(FRem) \
-                         OP(Shl)  \
-                         OP(LShr) \
-                         OP(AShr) \
-                         OP(And)  \
-                         OP(Or)   \
-                         OP(Xor)  \
-                         ),                 BinaryOperator)
-DEF_INSTR(AtomicRMW,     OP(AtomicRMW),     AtomicRMWInst)
-DEF_INSTR(AtomicCmpXchg, OP(AtomicCmpXchg), AtomicCmpXchgInst)
+                          OP(Add)  \
+                          OP(FAdd) \
+                          OP(Sub)  \
+                          OP(FSub) \
+                          OP(Mul)  \
+                          OP(FMul) \
+                          OP(UDiv) \
+                          OP(SDiv) \
+                          OP(FDiv) \
+                          OP(URem) \
+                          OP(SRem) \
+                          OP(FRem) \
+                          OP(Shl)  \
+                          OP(LShr) \
+                          OP(AShr) \
+                          OP(And)  \
+                          OP(Or)   \
+                          OP(Xor)  \
+                          ),                 BinaryOperator)
+DEF_INSTR(AtomicRMW,      OP(AtomicRMW),     AtomicRMWInst)
+DEF_INSTR(AtomicCmpXchg,  OP(AtomicCmpXchg), AtomicCmpXchgInst)
 DEF_INSTR(Alloca,         OP(Alloca),         AllocaInst)
 DEF_INSTR(Cast,   OPCODES(\
                           OP(ZExt)          \
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index c243df7fc864ee..8b695e51d39faf 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1773,6 +1773,67 @@ Value *ExtractElementInst::create(Value *Vec, Value *Idx,
   return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
 }
 
+Value *ShuffleVectorInst::create(Value *V1, Value *V2, Value *Mask,
+                                 Instruction *InsertBefore, Context &Ctx,
+                                 const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  llvm::Value *NewV =
+      Builder.CreateShuffleVector(V1->Val, V2->Val, Mask->Val, Name);
+  if (auto *NewShuffle = dyn_cast<llvm::ShuffleVectorInst>(NewV))
+    return Ctx.createShuffleVectorInst(NewShuffle);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *ShuffleVectorInst::create(Value *V1, Value *V2, Value *Mask,
+                                 BasicBlock *InsertAtEnd, Context &Ctx,
+                                 const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  llvm::Value *NewV =
+      Builder.CreateShuffleVector(V1->Val, V2->Val, Mask->Val, Name);
+  if (auto *NewShuffle = dyn_cast<llvm::ShuffleVectorInst>(NewV))
+    return Ctx.createShuffleVectorInst(NewShuffle);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *ShuffleVectorInst::create(Value *V1, Value *V2, ArrayRef<int> Mask,
+                                 Instruction *InsertBefore, Context &Ctx,
+                                 const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  llvm::Value *NewV = Builder.CreateShuffleVector(V1->Val, V2->Val, Mask, Name);
+  if (auto *NewShuffle = dyn_cast<llvm::ShuffleVectorInst>(NewV))
+    return Ctx.createShuffleVectorInst(NewShuffle);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *ShuffleVectorInst::create(Value *V1, Value *V2, ArrayRef<int> Mask,
+                                 BasicBlock *InsertAtEnd, Context &Ctx,
+                                 const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  llvm::Value *NewV = Builder.CreateShuffleVector(V1->Val, V2->Val, Mask, Name);
+  if (auto *NewShuffle = dyn_cast<llvm::ShuffleVectorInst>(NewV))
+    return Ctx.createShuffleVectorInst(NewShuffle);
+  assert(isa<llvm::Constant>(NewV) && "Expected constant");
+  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ShuffleVectorInst>(Val)->getShuffleMaskForBitcode());
+}
+
+Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(
+    llvm::ArrayRef<int> Mask, llvm::Type *ResultTy, Context &Ctx) {
+  return Ctx.getOrCreateConstant(
+      llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy));
+}
+
 #ifndef NDEBUG
 void Constant::dumpOS(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -1912,6 +1973,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new InsertElementInst(LLVMIns, *this));
     return It->second.get();
   }
+  case llvm::Instruction::ShuffleVector: {
+    auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV);
+    It->second = std::unique_ptr<ShuffleVectorInst>(
+        new ShuffleVectorInst(LLVMIns, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Br: {
     auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
     It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
@@ -2070,6 +2137,13 @@ Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
   return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
 }
 
+ShuffleVectorInst *
+Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
+  auto NewPtr =
+      std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
+  return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
+}
+
 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
   auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
   return cast<BranchInst>(registerValue(std::move(NewPtr)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 1d6a26728c9c56..b64b8cf1fafb3c 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock-matchers.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -739,6 +740,360 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
       llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero));
 }
 
+TEST_F(SandboxIRTest, ShuffleVectorInst) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
+  %ins0 = shufflevector <2 x i8> %v1, <2 x i8> %v2, <2 x i32> <i32 0, i32 2>
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *ArgV1 = F.getArg(0);
+  auto *ArgV2 = F.getArg(1);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *SI = cast<sandboxir::ShuffleVectorInst>(&*It++);
+  auto *Ret = &*It++;
+
+  EXPECT_EQ(SI->getOpcode(), sandboxir::Instruction::Opcode::ShuffleVector);
+  EXPECT_EQ(SI->getType(), ArgV1->getType());
+  EXPECT_EQ(SI->getOperand(0), ArgV1);
+  EXPECT_EQ(SI->getOperand(1), ArgV2);
+  EXPECT_EQ(SI->getMaskValue(0), 0);
+  EXPECT_EQ(SI->getMaskValue(1), 2);
+  SI->commute();
+  EXPECT_EQ(SI->getOperand(0), ArgV2);
+  EXPECT_EQ(SI->getOperand(1), ArgV1);
+  EXPECT_THAT(SI->getShuffleMask(),
+              testing::ContainerEq(ArrayRef<int>({2, 0})));
+
+  auto *NewI1 =
+      cast<sandboxir::ShuffleVectorInst>(sandboxir::ShuffleVectorInst::create(
+          ArgV1, ArgV2, ArrayRef<int>({0, 2, 1, 3}), Ret, Ctx,
+          "NewShuffleBeforeRet"));
+  EXPECT_EQ(NewI1->getOperand(0), ArgV1);
+  EXPECT_EQ(NewI1->getOperand(1), ArgV2);
+  EXPECT_EQ(NewI1->getNextNode(), Ret);
+  EXPECT_TRUE(NewI1->changesLength());
+  EXPECT_TRUE(NewI1->increasesLength());
+
+  auto *NewI2 =
+      cast<sandboxir::ShuffleVectorInst>(sandboxir::ShuffleVectorInst::create(
+          ArgV1, ArgV2, ArrayRef<int>({0, 1}), BB, Ctx, "NewShuffleAtEndOfBB"));
+  EXPECT_EQ(NewI2->getPrevNode(), Ret);
+
+  auto *LLVMArgV1 = LLVMF.getArg(0);
+  auto *LLVMArgV2 = LLVMF.getArg(1);
+  SmallVector<int, 2> Mask = {1, 2}; // Owns storage; ArrayRef would dangle.
+  EXPECT_EQ(
+      sandboxir::ShuffleVectorInst::isValidOperands(ArgV1, ArgV2, Mask),
+      llvm::ShuffleVectorInst::isValidOperands(LLVMArgV1, LLVMArgV2, Mask));
+  EXPECT_EQ(sandboxir::ShuffleVectorInst::isValidOperands(ArgV1, ArgV1, ArgV1),
+            llvm::ShuffleVectorInst::isValidOperands(LLVMArgV1, LLVMArgV1,
+                                                     LLVMArgV1));
+
+  // The following functions check different mask properties. Note that most
+  // of these come in three different flavors: a method that checks the mask
+  // of the current instruction and two static member functions that check
+  // a mask given as an ArrayRef<int> or Constant*, so there's quite a bit of
+  // repetition in order to check all of them.
+  //
+  // Because we need masks of different lengths we can't simply reuse one of the
+  // instructions we have created above, so we'll create a new instruction for
+  // each test using this helper.
+  auto shuffleWithMask = [&](auto &&...Indices) {
+    SmallVector<int, 4> Mask = {Indices...};
+    return cast<sandboxir::ShuffleVectorInst>(
+        sandboxir::ShuffleVectorInst::create(ArgV1, ArgV2, Mask, Ret, Ctx));
+  };
+
+  // changesLength / increasesLength
+  {
+    auto *I = shuffleWithMask(1);
+    EXPECT_TRUE(I->changesLength());
+    EXPECT_FALSE(I->increasesLength());
+  }
+  {
+    auto *I = shuffleWithMask(1, 1);
+    EXPECT_FALSE(I->changesLength());
+    EXPECT_FALSE(I->increasesLength());
+  }
+  {
+    auto *I = shuffleWithMask(1, 1, 1);
+    EXPECT_TRUE(I->changesLength());
+    EXPECT_TRUE(I->increasesLength());
+  }
+
+  // isSingleSourceMask
+  {
+    auto *I = shuffleWithMask(0, 1);
+    EXPECT_TRUE(I->isSingleSource());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(0, 2);
+    EXPECT_FALSE(I->isSingleSource());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // isIdentityMask
+  {
+    auto *I = shuffleWithMask(0, 1);
+    EXPECT_TRUE(I->isIdentity());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isIdentityMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isIdentityMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(1, 0);
+    EXPECT_FALSE(I->isIdentity());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isIdentityMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isIdentityMask(I->getShuffleMask(), 2));
+  }
+
+  // isIdentityWithPadding
+  EXPECT_TRUE(shuffleWithMask(0, 1, -1, -1)->isIdentityWithPadding());
+  EXPECT_FALSE(shuffleWithMask(0, 1)->isIdentityWithPadding());
+
+  // isIdentityWithExtract
+  EXPECT_TRUE(shuffleWithMask(0)->isIdentityWithExtract());
+  EXPECT_FALSE(shuffleWithMask(0, 1)->isIdentityWithExtract());
+  EXPECT_FALSE(shuffleWithMask(0, 1, 2)->isIdentityWithExtract());
+  EXPECT_FALSE(shuffleWithMask(1)->isIdentityWithExtract());
+
+  // isConcat
+  EXPECT_TRUE(shuffleWithMask(0, 1, 2, 3)->isConcat());
+  EXPECT_FALSE(shuffleWithMask(0, 3)->isConcat());
+
+  // isSelectMask
+  {
+    auto *I = shuffleWithMask(0, 3);
+    EXPECT_TRUE(I->isSelect());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSelectMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isSelectMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(0, 2);
+    EXPECT_FALSE(I->isSelect());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSelectMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isSelectMask(I->getShuffleMask(), 2));
+  }
+
+  // isReverseMask
+  {
+    auto *I = shuffleWithMask(1, 0);
+    EXPECT_TRUE(I->isReverse());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReverseMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isReverseMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(1, 2);
+    EXPECT_FALSE(I->isReverse());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReverseMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isReverseMask(I->getShuffleMask(), 2));
+  }
+
+  // isZeroEltSplatMask
+  {
+    auto *I = shuffleWithMask(0, 0);
+    EXPECT_TRUE(I->isZeroEltSplat());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(1, 1);
+    EXPECT_FALSE(I->isZeroEltSplat());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // isTransposeMask
+  {
+    auto *I = shuffleWithMask(0, 2);
+    EXPECT_TRUE(I->isTranspose());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isTransposeMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isTransposeMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(1, 1);
+    EXPECT_FALSE(I->isTranspose());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isTransposeMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isTransposeMask(I->getShuffleMask(), 2));
+  }
+
+  // isSpliceMask
+  {
+    auto *I = shuffleWithMask(1, 2);
+    int Index;
+    EXPECT_TRUE(I->isSplice(Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSpliceMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSpliceMask(I->getShuffleMask(),
+                                                           2, Index));
+  }
+  {
+    auto *I = shuffleWithMask(2, 1);
+    int Index;
+    EXPECT_FALSE(I->isSplice(Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSpliceMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSpliceMask(I->getShuffleMask(),
+                                                            2, Index));
+  }
+
+  // isExtractSubvectorMask
+  {
+    auto *I = shuffleWithMask(1);
+    int Index;
+    EXPECT_TRUE(I->isExtractSubvectorMask(Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMask(), 2, Index));
+  }
+  {
+    auto *I = shuffleWithMask(1, 2);
+    int Index;
+    EXPECT_FALSE(I->isExtractSubvectorMask(Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMask(), 2, Index));
+  }
+
+  // isInsertSubvectorMask
+  {
+    auto *I = shuffleWithMask(0, 2);
+    int NumSubElts, Index;
+    EXPECT_TRUE(I->isInsertSubvectorMask(NumSubElts, Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_EQ(NumSubElts, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, NumSubElts, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMask(), 2, NumSubElts, Index));
+  }
+  {
+    auto *I = shuffleWithMask(0, 1);
+    int NumSubElts, Index;
+    EXPECT_FALSE(I->isInsertSubvectorMask(NumSubElts, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, NumSubElts, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMask(), 2, NumSubElts, Index));
+  }
+
+  // isReplicationMask
+  {
+    auto *I = shuffleWithMask(0, 0, 0, 1, 1, 1);
+    int ReplicationFactor, VF;
+    EXPECT_TRUE(I->isReplicationMask(ReplicationFactor, VF));
+    EXPECT_EQ(ReplicationFactor, 3);
+    EXPECT_EQ(VF, 2);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMaskForBitcode(), ReplicationFactor, VF));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMask(), ReplicationFactor, VF));
+  }
+  {
+    auto *I = shuffleWithMask(1, 2);
+    int ReplicationFactor, VF;
+    EXPECT_FALSE(I->isReplicationMask(ReplicationFactor, VF));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMaskForBitcode(), ReplicationFactor, VF));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMask(), ReplicationFactor, VF));
+  }
+
+  // isOneUseSingleSourceMask
+  {
+    auto *I = shuffleWithMask(0, 1, 1, 0);
+    EXPECT_TRUE(I->isOneUseSingleSourceMask(2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isOneUseSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = shuffleWithMask(0, 1, 0, 0);
+    EXPECT_FALSE(I->isOneUseSingleSourceMask(2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isOneUseSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // commuteShuffleMask
+  // Exercise the sandboxir wrapper explicitly, consistent with every other
+  // static helper check in this test. NOTE(review): an unqualified
+  // `ShuffleVectorInst` presumably resolves to llvm::ShuffleVectorInst here,
+  // which would leave the sandboxir entry point untested - confirm.
+  {
+    SmallVector<int, 4> M = {0, 2, 1, 3};
+    // Second argument is the number of elements in each input vector (2).
+    sandboxir::ShuffleVectorInst::commuteShuffleMask(M, 2);
+    // The mask is rewritten in place; the expected result below has each
+    // index moved to the other input: {0, 2, 1, 3} -> {2, 0, 3, 1}.
+    EXPECT_THAT(M, testing::ContainerEq(ArrayRef<int>({2, 0, 3, 1})));
+  }
+
+  // isInterleaveMask
+  {
+    auto *I = shuffleWithMask(0, 2, 1, 3);
+    EXPECT_TRUE(I->isInterleave(2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4));
+    SmallVector<unsigned, 4> StartIndexes;
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4, StartIndexes));
+    EXPECT_THAT(StartIndexes, testing::ContainerEq(ArrayRef<unsigned>({0, 2})));
+  }
+  {
+    auto *I = shuffleWithMask(0, 3, 1, 2);
+    EXPECT_FALSE(I->isInterleave(2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4));
+  }
+
+  // isDeInterleaveMaskOfFactor
+  EXPECT_TRUE(sandboxir::ShuffleVectorInst::isDeInterleaveMaskOfFactor(
+      ArrayRef<int>({0, 2}), 2));
+  EXPECT_FALSE(sandboxir::ShuffleVectorInst::isDeInterleaveMaskOfFactor(
+      ArrayRef<int>({0, 1}), 2));
+
+  // isBitRotateMask
+  {
+    unsigned NumSubElts, RotateAmt;
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isBitRotateMask(
+        ArrayRef<int>({1, 0, 3, 2, 5, 4, 7, 6}), 8, 2, 2, NumSubElts,
+        RotateAmt));
+    EXPECT_EQ(NumSubElts, 2u);
+    EXPECT_EQ(RotateAmt, 8u);
+
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isBitRotateMask(
+        ArrayRef<int>({0, 7, 1, 6, 2, 5, 3, 4}), 8, 2, 2, NumSubElts,
+        RotateAmt));
+  }
+}
+
 TEST_F(SandboxIRTest, BranchInst) {
   parseIR(C, R"IR(
 define void @foo(i1 %cond0, i1 %cond2) {



More information about the llvm-commits mailing list