[llvm] [LV][VPlan] Add initial support for CSA vectorization (PR #106560)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 18:51:51 PDT 2024


================
@@ -668,6 +672,94 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
     }
     return NewPhi;
   }
+  case VPInstruction::CSAInitMask: {
+    if (Part == 0) {
+      Value *InitMask = ConstantAggregateZero::get(VectorType::get(
+          Type::getInt1Ty(State.Builder.getContext()), State.VF));
+      State.set(this, InitMask, Part);
+      return InitMask;
+    }
+    Value *V = State.get(this, Part - 1);
+    return V;
+  }
+  case VPInstruction::CSAInitData: {
+    if (Part == 0) {
+      Type *ElemTyp = getOperand(0)->getUnderlyingValue()->getType();
+      Value *InitData = PoisonValue::get(VectorType::get(ElemTyp, State.VF));
+      State.set(this, InitData, Part);
+      return InitData;
+    }
+    Value *V = State.get(this, Part - 1);
+    return V;
+  }
+  case VPInstruction::CSAMaskPhi: {
+    if (Part == 0) {
+      IRBuilder<>::InsertPointGuard Guard(State.Builder);
+      State.Builder.SetInsertPoint(State.CFG.PrevBB->getFirstNonPHI());
+      BasicBlock *PreheaderBB = State.CFG.getPreheaderBBFor(this);
+      Value *InitMask = State.get(getOperand(0), Part);
+      PHINode *MaskPhi =
+          State.Builder.CreatePHI(InitMask->getType(), 2, "csa.mask.phi");
+      MaskPhi->addIncoming(InitMask, PreheaderBB);
+      State.set(this, MaskPhi, Part);
+      return MaskPhi;
+    }
+    Value *V = State.get(this, Part - 1);
+    return V;
+  }
+  case VPInstruction::CSAMaskSel: {
+    Value *WidenedCond = State.get(getOperand(0), Part);
+    Value *MaskPhi = State.get(getOperand(1), Part);
+    Value *AnyActive = State.get(getOperand(2), Part, /*NeedsScalar=*/true);
+    // If not the first Part, use the mask from the previous unrolled Part
+    Value *OldMask = Part == 0 ? MaskPhi : State.get(this, Part - 1);
+    Value *MaskSel = State.Builder.CreateSelect(AnyActive, WidenedCond, OldMask,
+                                                "csa.mask.sel");
+    // MaskPhi wants to use the most recently updated mask. That's the one
+    // that corresponds to the last Part.
+    if (Part == State.UF - 1)
+      cast<PHINode>(MaskPhi)->addIncoming(MaskSel, State.CFG.PrevBB);
+    return MaskSel;
+  }
+  case VPInstruction::CSAAnyActive: {
+    Value *WidenedCond = State.get(getOperand(0), Part);
+    return Builder.CreateOrReduce(WidenedCond);
+  }
+  case VPInstruction::CSAAnyActiveEVL: {
+    Value *WidenedCond = State.get(getOperand(0), Part);
+    Value *AllOnesMask = Constant::getAllOnesValue(
+        VectorType::get(Type::getInt1Ty(State.Builder.getContext()), State.VF));
+    Value *EVL = State.get(getOperand(1), Part, /*NeedsScalar=*/true);
+
+    Value *StartValue =
+        ConstantInt::get(WidenedCond->getType()->getScalarType(), 0);
+    Value *AnyActive = State.Builder.CreateIntrinsic(
+        WidenedCond->getType()->getScalarType(), Intrinsic::vp_reduce_or,
+        {StartValue, WidenedCond, AllOnesMask, EVL}, nullptr,
+        "csa.cond.anyactive");
+    return AnyActive;
+  }
+  case VPInstruction::CSAVLPhi: {
+    IRBuilder<>::InsertPointGuard Guard(State.Builder);
+    State.Builder.SetInsertPoint(State.CFG.PrevBB->getFirstNonPHI());
+    BasicBlock *PreheaderBB = State.CFG.getPreheaderBBFor(this);
+
+    // InitVL can be anything since it won't be used if no mask was active
+    Value *InitVL = ConstantInt::get(State.Builder.getInt32Ty(), 0);
----------------
michaelmaitland wrote:

According to the [LangRef](https://llvm.org/docs/LangRef.html#vector-predication-intrinsics):

> The explicit vector length parameter always has the type i32 and is an unsigned integer value. The explicit vector length parameter (%evl) is in the range:
>
> `0 <= %evl <= W`, where W is the number of vector elements

https://github.com/llvm/llvm-project/pull/106560


More information about the llvm-commits mailing list