[llvm] 1a78c64 - Scalarizer: explicitly exclude scalable vectors

Nicolai Hähnle via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 8 11:48:25 PST 2022


Author: Nicolai Hähnle
Date: 2022-12-08T20:48:14+01:00
New Revision: 1a78c64654cde1960be70fe532e999aff2aaa5bc

URL: https://github.com/llvm/llvm-project/commit/1a78c64654cde1960be70fe532e999aff2aaa5bc
DIFF: https://github.com/llvm/llvm-project/commit/1a78c64654cde1960be70fe532e999aff2aaa5bc.diff

LOG: Scalarizer: explicitly exclude scalable vectors

They are unsupported and would previously crash; now we just skip them.

Hypothetically, one could consider "scalarizing" a <vscale x n x T> into
n copies of <vscale x 1 x T>. But (1) it's unclear how to do that,
because insertelement and similar instructions don't operate on scalable
vectors in the required way, and (2) there is no user of such functionality.

Differential Revision: https://reviews.llvm.org/D139358

Added: 
    llvm/test/Transforms/Scalarizer/ignore-scalable-vectors.ll

Modified: 
    llvm/lib/Transforms/Scalar/Scalarizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 3cab25649aca..bbce942965ea 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -174,7 +174,7 @@ struct VectorLayout {
   }
 
   // The type of the vector.
-  VectorType *VecTy = nullptr;
+  FixedVectorType *VecTy = nullptr;
 
   // The type of each element.
   Type *ElemTy = nullptr;
@@ -488,7 +488,7 @@ ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                    const DataLayout &DL) {
   VectorLayout Layout;
   // Make sure we're dealing with a vector.
-  Layout.VecTy = dyn_cast<VectorType>(Ty);
+  Layout.VecTy = dyn_cast<FixedVectorType>(Ty);
   if (!Layout.VecTy)
     return std::nullopt;
   // Check that we're dealing with full-byte elements.
@@ -504,11 +504,11 @@ ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
 // to create an instruction like I with operand X and name Name.
 template<typename Splitter>
 bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
-  VectorType *VT = dyn_cast<VectorType>(I.getType());
+  auto *VT = dyn_cast<FixedVectorType>(I.getType());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   IRBuilder<> Builder(&I);
   Scatterer Op = scatter(&I, I.getOperand(0));
   assert(Op.size() == NumElems && "Mismatched unary operation");
@@ -524,11 +524,11 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
 // to create an instruction like I with operands X and Y and name Name.
 template<typename Splitter>
 bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
-  VectorType *VT = dyn_cast<VectorType>(I.getType());
+  auto *VT = dyn_cast<FixedVectorType>(I.getType());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   IRBuilder<> Builder(&I);
   Scatterer VOp0 = scatter(&I, I.getOperand(0));
   Scatterer VOp1 = scatter(&I, I.getOperand(1));
@@ -559,7 +559,7 @@ static Function *getScalarIntrinsicDeclaration(Module *M,
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
-  VectorType *VT = dyn_cast<VectorType>(CI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(CI.getType());
   if (!VT)
     return false;
 
@@ -571,7 +571,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
   if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   unsigned NumArgs = CI.arg_size();
 
   ValueVector ScalarOperands(NumArgs);
@@ -624,11 +624,11 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 }
 
 bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
-  VectorType *VT = dyn_cast<VectorType>(SI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(SI.getType());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   IRBuilder<> Builder(&SI);
   Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
   Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
@@ -677,12 +677,12 @@ bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
 }
 
 bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-  VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(GEPI.getType());
   if (!VT)
     return false;
 
   IRBuilder<> Builder(&GEPI);
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   unsigned NumIndices = GEPI.getNumIndices();
 
   // The base pointer might be scalar even if it's a vector GEP. In those cases,
@@ -723,11 +723,11 @@ bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
 }
 
 bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
-  VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
+  auto *VT = dyn_cast<FixedVectorType>(CI.getDestTy());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   IRBuilder<> Builder(&CI);
   Scatterer Op0 = scatter(&CI, CI.getOperand(0));
   assert(Op0.size() == NumElems && "Mismatched cast");
@@ -741,13 +741,13 @@ bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
 }
 
 bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
-  VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
-  VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
+  auto *DstVT = dyn_cast<FixedVectorType>(BCI.getDestTy());
+  auto *SrcVT = dyn_cast<FixedVectorType>(BCI.getSrcTy());
   if (!DstVT || !SrcVT)
     return false;
 
-  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
-  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();
+  unsigned DstNumElems = DstVT->getNumElements();
+  unsigned SrcNumElems = SrcVT->getNumElements();
   IRBuilder<> Builder(&BCI);
   Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
   ValueVector Res;
@@ -796,11 +796,11 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
 }
 
 bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
-  VectorType *VT = dyn_cast<VectorType>(IEI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(IEI.getType());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   IRBuilder<> Builder(&IEI);
   Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
   Value *NewElt = IEI.getOperand(1);
@@ -831,11 +831,11 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
 }
 
 bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
-  VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType());
+  auto *VT = dyn_cast<FixedVectorType>(EEI.getOperand(0)->getType());
   if (!VT)
     return false;
 
-  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumSrcElems = VT->getNumElements();
   IRBuilder<> Builder(&EEI);
   Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
   Value *ExtIdx = EEI.getOperand(1);
@@ -863,11 +863,11 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
 }
 
 bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
-  VectorType *VT = dyn_cast<VectorType>(SVI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(SVI.getType());
   if (!VT)
     return false;
 
-  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+  unsigned NumElems = VT->getNumElements();
   Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
   Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
   ValueVector Res;
@@ -887,7 +887,7 @@ bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
 }
 
 bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
-  VectorType *VT = dyn_cast<VectorType>(PHI.getType());
+  auto *VT = dyn_cast<FixedVectorType>(PHI.getType());
   if (!VT)
     return false;
 
@@ -982,9 +982,9 @@ bool ScalarizerVisitor::finish() {
       // The value is still needed, so recreate it using a series of
       // InsertElements.
       Value *Res = PoisonValue::get(Op->getType());
-      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {
+      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
         BasicBlock *BB = Op->getParent();
-        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();
+        unsigned Count = Ty->getNumElements();
         IRBuilder<> Builder(Op);
         if (isa<PHINode>(Op))
           Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

diff  --git a/llvm/test/Transforms/Scalarizer/ignore-scalable-vectors.ll b/llvm/test/Transforms/Scalarizer/ignore-scalable-vectors.ll
new file mode 100644
index 000000000000..cb29457d6f3d
--- /dev/null
+++ b/llvm/test/Transforms/Scalarizer/ignore-scalable-vectors.ll
@@ -0,0 +1,11 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt %s -passes=scalarizer -S -o - | FileCheck %s
+
+define <vscale x 1 x i32> @test1(<vscale x 1 x i32> %a, <vscale x 1 x i32> %b) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    [[R:%.*]] = add <vscale x 1 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret <vscale x 1 x i32> [[R]]
+;
+  %r = add <vscale x 1 x i32> %a, %b
+  ret <vscale x 1 x i32> %r
+}


        


More information about the llvm-commits mailing list