[llvm] [RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext. (PR #93578)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 01:53:16 PDT 2024


================
@@ -16128,23 +16128,26 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+// This would be benefit for the cases where X and Y are both the same value
+// type of low precision vectors. Since the truncate would be lowered into
+// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+// restriction, such pattern would be expanded into a series of "vsetvli"
+// and "vnsrl" instructions later to reach this point.
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
-  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-  // This would be benefit for the cases where X and Y are both the same value
-  // type of low precision vectors. Since the truncate would be lowered into
-  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-  // restriction, such pattern would be expanded into a series of "vsetvli"
-  // and "vnsrl" instructions later to reach this point.
-  auto IsTruncNode = [](SDValue V) {
-    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-      return false;
-    SDValue VL = V.getOperand(2);
-    auto *C = dyn_cast<ConstantSDNode>(VL);
-    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                           (isa<RegisterSDNode>(VL) &&
-                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  bool IsVLMAX = isAllOnesConstant(VL) ||
+                 (isa<RegisterSDNode>(VL) &&
+                  cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+  if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
+      Mask.getOperand(0) != VL)
----------------
lukel97 wrote:

It looks like IsVLMAXForVMSET didn't actually check the VMSET's VL in the first place. But since the tail of a VMSET_VL is undefined, we can interpret any VMSET_VL as all ones, so would it make sense to drop the VL check here?
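
To illustrate the reasoning about the undefined tail, here is a toy model in standalone C++. It is not LLVM code: `Mask`, `vmsetVL` and `canTreatAsAllOnes` are hypothetical names, and an undefined lane is modelled as an empty `std::optional`. The sketch assumes that a VMSET_VL with a given VL sets the first VL lanes and leaves the rest unspecified, so no choice of VL can yield a lane that is known to be false.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Toy stand-in for a mask value: each lane is true, false, or
// undefined (std::nullopt). Hypothetical type, not an LLVM API.
using Mask = std::vector<std::optional<bool>>;

// Model of VMSET_VL under the stated assumption: lanes [0, VL) are
// set to true, lanes at or past VL are left undefined.
Mask vmsetVL(unsigned NumLanes, unsigned VL) {
  Mask M(NumLanes);
  for (unsigned I = 0; I < NumLanes && I < VL; ++I)
    M[I] = true;
  return M;
}

// A mask may be read as all ones if no lane is *defined* to be false;
// undefined lanes are free to be chosen as true.
bool canTreatAsAllOnes(const Mask &M) {
  for (const auto &Lane : M)
    if (Lane.has_value() && !*Lane)
      return false;
  return true;
}

// Small self-check covering a short VL, a full VL, and a mask with
// an explicitly false lane.
void checkExamples() {
  assert(canTreatAsAllOnes(vmsetVL(8, 3)));
  assert(canTreatAsAllOnes(vmsetVL(8, 8)));
  Mask M = vmsetVL(4, 4);
  M[1] = false;
  assert(!canTreatAsAllOnes(M));
}
```

In this model the result of `canTreatAsAllOnes` never depends on the VL passed to `vmsetVL`, which is the intuition behind dropping the VL check on the mask.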

https://github.com/llvm/llvm-project/pull/93578

