[llvm] [SandboxIR] Add ShuffleVectorInst (PR #104891)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 17:22:39 PDT 2024


================
@@ -739,6 +740,401 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
       llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero));
 }
 
+TEST_F(SandboxIRTest, ShuffleVectorInst) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
+  %ins0 = shufflevector <2 x i8> %v1, <2 x i8> %v2, <2 x i32> <i32 0, i32 2>
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *ArgV1 = F.getArg(0);
+  auto *ArgV2 = F.getArg(1);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *SI = cast<sandboxir::ShuffleVectorInst>(&*It++);
+  auto *Ret = &*It++;
+
+  EXPECT_EQ(SI->getOpcode(), sandboxir::Instruction::Opcode::ShuffleVector);
+  EXPECT_EQ(SI->getOperand(0), ArgV1);
+  EXPECT_EQ(SI->getOperand(1), ArgV2);
+
+  // In order to test all the methods we need masks of different lengths, so we
+  // can't simply reuse one of the instructions created above. This helper
+  // creates a new `shufflevector %v1, %2, <mask>` with the given mask indices.
+  auto CreateShuffleWithMask = [&](auto &&...Indices) {
+    SmallVector<int, 4> Mask = {Indices...};
+    return cast<sandboxir::ShuffleVectorInst>(
+        sandboxir::ShuffleVectorInst::create(ArgV1, ArgV2, Mask, Ret, Ctx));
+  };
+
+  // create (InsertBefore)
+  auto *NewI1 =
+      cast<sandboxir::ShuffleVectorInst>(sandboxir::ShuffleVectorInst::create(
+          ArgV1, ArgV2, ArrayRef<int>({0, 2, 1, 3}), Ret, Ctx,
+          "NewShuffleBeforeRet"));
+  EXPECT_EQ(NewI1->getOperand(0), ArgV1);
+  EXPECT_EQ(NewI1->getOperand(1), ArgV2);
+  EXPECT_EQ(NewI1->getNextNode(), Ret);
+#ifndef NDEBUG
+  EXPECT_EQ(NewI1->getName(), "NewShuffleBeforeRet");
+#endif
+
+  // create (InsertAtEnd)
+  auto *NewI2 =
+      cast<sandboxir::ShuffleVectorInst>(sandboxir::ShuffleVectorInst::create(
+          ArgV1, ArgV2, ArrayRef<int>({0, 1}), BB, Ctx, "NewShuffleAtEndOfBB"));
+  EXPECT_EQ(NewI2->getPrevNode(), Ret);
+
+  // isValidOperands
+  auto *LLVMArgV1 = LLVMF.getArg(0);
+  auto *LLVMArgV2 = LLVMF.getArg(1);
+  ArrayRef<int> Mask({1, 2});
+  EXPECT_EQ(
+      sandboxir::ShuffleVectorInst::isValidOperands(ArgV1, ArgV2, Mask),
+      llvm::ShuffleVectorInst::isValidOperands(LLVMArgV1, LLVMArgV2, Mask));
+  EXPECT_EQ(sandboxir::ShuffleVectorInst::isValidOperands(ArgV1, ArgV1, ArgV1),
+            llvm::ShuffleVectorInst::isValidOperands(LLVMArgV1, LLVMArgV1,
+                                                     LLVMArgV1));
+
+  // commute
+  {
+    auto *I = CreateShuffleWithMask(0, 2);
+    I->commute();
+    EXPECT_EQ(I->getOperand(0), ArgV2);
+    EXPECT_EQ(I->getOperand(1), ArgV1);
+    EXPECT_THAT(I->getShuffleMask(),
+                testing::ContainerEq(ArrayRef<int>({2, 0})));
+  }
+
+  // getType
+  EXPECT_EQ(SI->getType(), ArgV1->getType());
+
+  // getMaskValue
+  EXPECT_EQ(SI->getMaskValue(0), 0);
+  EXPECT_EQ(SI->getMaskValue(1), 2);
+
+  // getShuffleMask/getShuffleMaskForBitcode
+  {
+    EXPECT_THAT(SI->getShuffleMask(),
+                testing::ContainerEq(ArrayRef<int>({0, 2})));
+
+    SmallVector<int, 2> Result;
+    sandboxir::ShuffleVectorInst::getShuffleMask(SI->getShuffleMaskForBitcode(),
+                                                 Result);
+    EXPECT_THAT(Result, testing::ContainerEq(ArrayRef<int>({0, 2})));
+  }
+
+  // convertShuffleMaskForBitcode
+  {
+    auto *C = sandboxir::ShuffleVectorInst::convertShuffleMaskForBitcode(
+        ArrayRef<int>({2, 3}), ArgV1->getType(), Ctx);
+    SmallVector<int, 2> Result;
+    sandboxir::ShuffleVectorInst::getShuffleMask(C, Result);
+    EXPECT_THAT(Result, testing::ContainerEq(ArrayRef<int>({2, 3})));
+  }
+
+  // setShuffleMask
+  {
+    auto *I = CreateShuffleWithMask(0, 1);
+    I->setShuffleMask(ArrayRef<int>({2, 3}));
+    EXPECT_THAT(I->getShuffleMask(),
+                testing::ContainerEq(ArrayRef<int>({2, 3})));
+  }
+
+  // The following functions check different mask properties. Note that most
+  // of these come in three different flavors: a method that checks the mask
+  // in the current instructions and two static member functions that check
+  // a mask given as an ArrayRef<int> or Constant*, so there's quite a bit of
+  // repetition in order to check all of them.
+
+  // changesLength / increasesLength
+  {
+    auto *I = CreateShuffleWithMask(1);
+    EXPECT_TRUE(I->changesLength());
+    EXPECT_FALSE(I->increasesLength());
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 1);
+    EXPECT_FALSE(I->changesLength());
+    EXPECT_FALSE(I->increasesLength());
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 1, 1);
+    EXPECT_TRUE(I->changesLength());
+    EXPECT_TRUE(I->increasesLength());
+  }
+
+  // isSingleSourceMask
+  {
+    auto *I = CreateShuffleWithMask(0, 1);
+    EXPECT_TRUE(I->isSingleSource());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(0, 2);
+    EXPECT_FALSE(I->isSingleSource());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // isIdentityMask
+  {
+    auto *I = CreateShuffleWithMask(0, 1);
+    EXPECT_TRUE(I->isIdentity());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isIdentityMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isIdentityMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 0);
+    EXPECT_FALSE(I->isIdentity());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isIdentityMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isIdentityMask(I->getShuffleMask(), 2));
+  }
+
+  // isIdentityWithPadding
+  EXPECT_TRUE(CreateShuffleWithMask(0, 1, -1, -1)->isIdentityWithPadding());
+  EXPECT_FALSE(CreateShuffleWithMask(0, 1)->isIdentityWithPadding());
+
+  // isIdentityWithExtract
+  EXPECT_TRUE(CreateShuffleWithMask(0)->isIdentityWithExtract());
+  EXPECT_FALSE(CreateShuffleWithMask(0, 1)->isIdentityWithExtract());
+  EXPECT_FALSE(CreateShuffleWithMask(0, 1, 2)->isIdentityWithExtract());
+  EXPECT_FALSE(CreateShuffleWithMask(1)->isIdentityWithExtract());
+
+  // isConcat
+  EXPECT_TRUE(CreateShuffleWithMask(0, 1, 2, 3)->isConcat());
+  EXPECT_FALSE(CreateShuffleWithMask(0, 3)->isConcat());
+
+  // isSelectMask
+  {
+    auto *I = CreateShuffleWithMask(0, 3);
+    EXPECT_TRUE(I->isSelect());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSelectMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isSelectMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(0, 2);
+    EXPECT_FALSE(I->isSelect());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSelectMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isSelectMask(I->getShuffleMask(), 2));
+  }
+
+  // isReverseMask
+  {
+    auto *I = CreateShuffleWithMask(1, 0);
+    EXPECT_TRUE(I->isReverse());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReverseMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isReverseMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 2);
+    EXPECT_FALSE(I->isReverse());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReverseMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isReverseMask(I->getShuffleMask(), 2));
+  }
+
+  // isZeroEltSplatMask
+  {
+    auto *I = CreateShuffleWithMask(0, 0);
+    EXPECT_TRUE(I->isZeroEltSplat());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 1);
+    EXPECT_FALSE(I->isZeroEltSplat());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isZeroEltSplatMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // isTransposeMask
+  {
+    auto *I = CreateShuffleWithMask(0, 2);
+    EXPECT_TRUE(I->isTranspose());
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isTransposeMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_TRUE(
+        sandboxir::ShuffleVectorInst::isTransposeMask(I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 1);
+    EXPECT_FALSE(I->isTranspose());
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isTransposeMask(
+        I->getShuffleMaskForBitcode(), 2));
+    EXPECT_FALSE(
+        sandboxir::ShuffleVectorInst::isTransposeMask(I->getShuffleMask(), 2));
+  }
+
+  // isSpliceMask
+  {
+    auto *I = CreateShuffleWithMask(1, 2);
+    int Index;
+    EXPECT_TRUE(I->isSplice(Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSpliceMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isSpliceMask(I->getShuffleMask(),
+                                                           2, Index));
+  }
+  {
+    auto *I = CreateShuffleWithMask(2, 1);
+    int Index;
+    EXPECT_FALSE(I->isSplice(Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSpliceMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isSpliceMask(I->getShuffleMask(),
+                                                            2, Index));
+  }
+
+  // isExtractSubvectorMask
+  {
+    auto *I = CreateShuffleWithMask(1);
+    int Index;
+    EXPECT_TRUE(I->isExtractSubvectorMask(Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMask(), 2, Index));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 2);
+    int Index;
+    EXPECT_FALSE(I->isExtractSubvectorMask(Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isExtractSubvectorMask(
+        I->getShuffleMask(), 2, Index));
+  }
+
+  // isInsertSubvectorMask
+  {
+    auto *I = CreateShuffleWithMask(0, 2);
+    int NumSubElts, Index;
+    EXPECT_TRUE(I->isInsertSubvectorMask(NumSubElts, Index));
+    EXPECT_EQ(Index, 1);
+    EXPECT_EQ(NumSubElts, 1);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, NumSubElts, Index));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMask(), 2, NumSubElts, Index));
+  }
+  {
+    auto *I = CreateShuffleWithMask(0, 1);
+    int NumSubElts, Index;
+    EXPECT_FALSE(I->isInsertSubvectorMask(NumSubElts, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMaskForBitcode(), 2, NumSubElts, Index));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInsertSubvectorMask(
+        I->getShuffleMask(), 2, NumSubElts, Index));
+  }
+
+  // isReplicationMask
+  {
+    auto *I = CreateShuffleWithMask(0, 0, 0, 1, 1, 1);
+    int ReplicationFactor, VF;
+    EXPECT_TRUE(I->isReplicationMask(ReplicationFactor, VF));
+    EXPECT_EQ(ReplicationFactor, 3);
+    EXPECT_EQ(VF, 2);
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMaskForBitcode(), ReplicationFactor, VF));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMask(), ReplicationFactor, VF));
+  }
+  {
+    auto *I = CreateShuffleWithMask(1, 2);
+    int ReplicationFactor, VF;
+    EXPECT_FALSE(I->isReplicationMask(ReplicationFactor, VF));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMaskForBitcode(), ReplicationFactor, VF));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isReplicationMask(
+        I->getShuffleMask(), ReplicationFactor, VF));
+  }
+
+  // isOneUseSingleSourceMask
+  {
+    auto *I = CreateShuffleWithMask(0, 1, 1, 0);
+    EXPECT_TRUE(I->isOneUseSingleSourceMask(2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isOneUseSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+  {
+    auto *I = CreateShuffleWithMask(0, 1, 0, 0);
+    EXPECT_FALSE(I->isOneUseSingleSourceMask(2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isOneUseSingleSourceMask(
+        I->getShuffleMask(), 2));
+  }
+
+  // commuteShuffleMask
+  {
+    SmallVector<int, 4> M = {0, 2, 1, 3};
+    ShuffleVectorInst::commuteShuffleMask(M, 2);
+    EXPECT_THAT(M, testing::ContainerEq(ArrayRef<int>({2, 0, 3, 1})));
+  }
+
+  // isInterleaveMask
+  {
+    auto *I = CreateShuffleWithMask(0, 2, 1, 3);
+    EXPECT_TRUE(I->isInterleave(2));
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4));
+    SmallVector<unsigned, 4> StartIndexes;
+    EXPECT_TRUE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4, StartIndexes));
+    EXPECT_THAT(StartIndexes, testing::ContainerEq(ArrayRef<unsigned>({0, 2})));
+  }
+  {
+    auto *I = CreateShuffleWithMask(0, 3, 1, 2);
+    EXPECT_FALSE(I->isInterleave(2));
+    EXPECT_FALSE(sandboxir::ShuffleVectorInst::isInterleaveMask(
+        I->getShuffleMask(), 2, 4));
+  }
+
+  // isDeInterleaveMaskOfFactor
+  EXPECT_TRUE(sandboxir::ShuffleVectorInst::isDeInterleaveMaskOfFactor(
+      ArrayRef<int>({0, 2}), 2));
+  EXPECT_FALSE(sandboxir::ShuffleVectorInst::isDeInterleaveMaskOfFactor(
+      ArrayRef<int>({0, 1}), 2));
+
----------------
vporpo wrote:

Could you also add a check for the three-argument overload `isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor, unsigned &Index)`?

https://github.com/llvm/llvm-project/pull/104891


More information about the llvm-commits mailing list