[llvm] [RISCV][TTI] Implement getPartialReductionCost for the vqdotq cases (PR #140974)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu May 22 08:57:59 PDT 2025


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/140974

>From 507aa5e2ed170c470a6093ad49f1fc9d30bcd5a1 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 21 May 2025 13:16:37 -0700
Subject: [PATCH 1/2] [RISCV][TTI] Implement getPartialReductionCost for the
 vqdotq cases

Doing so tells the vectorizer that the partial.reduce intrinsic is
profitable to use over the plain extend/multiply/reduce.add sequence.
---
 .../Target/RISCV/RISCVTargetTransformInfo.cpp |  22 ++
 .../Target/RISCV/RISCVTargetTransformInfo.h   |   7 +
 .../RISCV/partial-reduce-dot-product.ll       | 223 ++++++++++++------
 3 files changed, 176 insertions(+), 76 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index db2f1141ee4b7..a0cfef70f5b0e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -294,6 +294,28 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
              : TTI::PSK_Software;
 }
 
+InstructionCost RISCVTTIImpl::getPartialReductionCost(
+    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+    TTI::PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+
+  // FIXME: Guard zve32x properly here
+  if (!ST->hasStdExtZvqdotq() || Opcode != Instruction::Add || !BinOp ||
+      *BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
+      !InputTypeA->isIntegerTy(8) || OpAExtend != OpBExtend ||
+      !AccumType->isIntegerTy(32) || !VF.isKnownMultipleOf(4) ||
+      !VF.isScalable())
+    return InstructionCost::getInvalid();
+
+  Type *Tp = VectorType::get(AccumType, VF);
+  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
+  // Note: Assuming all vqdot* variants are equal cost.  TODO: Thread CostKind
+  // through this API.  FIXME: should Tp be VF/4 x i32 (the result type)?
+  return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
+                                            TTI::TCK_RecipThroughput);
+}
+
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
   // Currently, the ExpandReductions pass can't expand scalable-vector
   // reductions, but we still request expansion as RVV doesn't support certain
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 53529d077fd54..f7a40e9bdedbf 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -107,6 +107,13 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   TargetTransformInfo::PopcntSupportKind
   getPopcntSupport(unsigned TyWidth) const override;
 
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+                          Type *AccumType, ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const override;
+
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {
     return ST->hasVInstructions();
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/RISCV/partial-reduce-dot-product.ll
index 61eec9332b857..23534143ed3a9 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/partial-reduce-dot-product.ll
@@ -5,42 +5,79 @@
 target triple = "riscv64-none-unknown-elf"
 
 define i32 @vqdot(ptr %a, ptr %b) #0 {
-; CHECK-LABEL: define i32 @vqdot(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
-; CHECK-NEXT:    [[TMP13]] = add <vscale x 4 x i32> [[TMP12]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP13]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
+; V-LABEL: define i32 @vqdot(
+; V-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; V-NEXT:  entry:
+; V-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; V-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; V-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; V:       vector.ph:
+; V-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; V-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; V-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; V-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; V-NEXT:    br label [[VECTOR_BODY:%.*]]
+; V:       vector.body:
+; V-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; V-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; V-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; V-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
+; V-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; V-NEXT:    [[TMP8:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; V-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; V-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
+; V-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; V-NEXT:    [[TMP11:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; V-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
+; V-NEXT:    [[TMP13]] = add <vscale x 4 x i32> [[TMP12]], [[VEC_PHI]]
+; V-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; V-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; V-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; V:       middle.block:
+; V-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP13]])
+; V-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; V-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; V:       scalar.ph:
+;
+; ZVQDOTQ-LABEL: define i32 @vqdot(
+; ZVQDOTQ-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; ZVQDOTQ-NEXT:  entry:
+; ZVQDOTQ-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; ZVQDOTQ-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; ZVQDOTQ-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; ZVQDOTQ:       vector.ph:
+; ZVQDOTQ-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; ZVQDOTQ-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; ZVQDOTQ-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; ZVQDOTQ-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; ZVQDOTQ-NEXT:    br label [[VECTOR_BODY:%.*]]
+; ZVQDOTQ:       vector.body:
+; ZVQDOTQ-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; ZVQDOTQ-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 1 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; ZVQDOTQ-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; ZVQDOTQ-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
+; ZVQDOTQ-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; ZVQDOTQ-NEXT:    [[TMP8:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; ZVQDOTQ-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; ZVQDOTQ-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
+; ZVQDOTQ-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; ZVQDOTQ-NEXT:    [[TMP11:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; ZVQDOTQ-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
+; ZVQDOTQ-NEXT:    [[PARTIAL_REDUCE]] = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add.nxv1i32.nxv4i32(<vscale x 1 x i32> [[VEC_PHI]], <vscale x 4 x i32> [[TMP12]])
+; ZVQDOTQ-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; ZVQDOTQ-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; ZVQDOTQ-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; ZVQDOTQ:       middle.block:
+; ZVQDOTQ-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv1i32(<vscale x 1 x i32> [[PARTIAL_REDUCE]])
+; ZVQDOTQ-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; ZVQDOTQ-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; ZVQDOTQ:       scalar.ph:
 ;
 entry:
   br label %for.body
@@ -66,42 +103,79 @@ for.exit:                        ; preds = %for.body
 
 
 define i32 @vqdotu(ptr %a, ptr %b) #0 {
-; CHECK-LABEL: define i32 @vqdotu(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
-; CHECK-NEXT:    [[TMP11:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
-; CHECK-NEXT:    [[TMP13]] = add <vscale x 4 x i32> [[TMP12]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP13]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
+; V-LABEL: define i32 @vqdotu(
+; V-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
+; V-NEXT:  entry:
+; V-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; V-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; V-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; V:       vector.ph:
+; V-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; V-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; V-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; V-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; V-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; V-NEXT:    br label [[VECTOR_BODY:%.*]]
+; V:       vector.body:
+; V-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; V-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; V-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; V-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
+; V-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; V-NEXT:    [[TMP8:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; V-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; V-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
+; V-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; V-NEXT:    [[TMP11:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; V-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
+; V-NEXT:    [[TMP13]] = add <vscale x 4 x i32> [[TMP12]], [[VEC_PHI]]
+; V-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; V-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; V-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; V:       middle.block:
+; V-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP13]])
+; V-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; V-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; V:       scalar.ph:
+;
+; ZVQDOTQ-LABEL: define i32 @vqdotu(
+; ZVQDOTQ-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
+; ZVQDOTQ-NEXT:  entry:
+; ZVQDOTQ-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; ZVQDOTQ-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; ZVQDOTQ-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; ZVQDOTQ:       vector.ph:
+; ZVQDOTQ-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; ZVQDOTQ-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; ZVQDOTQ-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; ZVQDOTQ-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; ZVQDOTQ-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; ZVQDOTQ-NEXT:    br label [[VECTOR_BODY:%.*]]
+; ZVQDOTQ:       vector.body:
+; ZVQDOTQ-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; ZVQDOTQ-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 1 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; ZVQDOTQ-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; ZVQDOTQ-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
+; ZVQDOTQ-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; ZVQDOTQ-NEXT:    [[TMP8:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; ZVQDOTQ-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; ZVQDOTQ-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP9]], i32 0
+; ZVQDOTQ-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; ZVQDOTQ-NEXT:    [[TMP11:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; ZVQDOTQ-NEXT:    [[TMP12:%.*]] = mul <vscale x 4 x i32> [[TMP11]], [[TMP8]]
+; ZVQDOTQ-NEXT:    [[PARTIAL_REDUCE]] = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add.nxv1i32.nxv4i32(<vscale x 1 x i32> [[VEC_PHI]], <vscale x 4 x i32> [[TMP12]])
+; ZVQDOTQ-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; ZVQDOTQ-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; ZVQDOTQ-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; ZVQDOTQ:       middle.block:
+; ZVQDOTQ-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv1i32(<vscale x 1 x i32> [[PARTIAL_REDUCE]])
+; ZVQDOTQ-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; ZVQDOTQ-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; ZVQDOTQ:       scalar.ph:
 ;
 entry:
   br label %for.body
@@ -128,7 +202,7 @@ for.exit:                        ; preds = %for.body
 
 define i32 @vqdotsu(ptr %a, ptr %b) #0 {
 ; CHECK-LABEL: define i32 @vqdotsu(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
@@ -245,6 +319,3 @@ for.body:                                         ; preds = %for.body, %entry
 for.exit:                        ; preds = %for.body
   ret i32 %add
 }
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; V: {{.*}}
-; ZVQDOTQ: {{.*}}

>From 7d4d4a65f1c0402a7270e0df220c2b6c07c23606 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 22 May 2025 08:35:29 -0700
Subject: [PATCH 2/2] Guard zve32x

---
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index a0cfef70f5b0e..d54ad63404578 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -300,12 +300,13 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend,
     std::optional<unsigned> BinOp) const {
 
-  // FIXME: Guard zve32x properly here
-  if (!ST->hasStdExtZvqdotq() || Opcode != Instruction::Add || !BinOp ||
-      *BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-      !InputTypeA->isIntegerTy(8) || OpAExtend != OpBExtend ||
-      !AccumType->isIntegerTy(32) || !VF.isKnownMultipleOf(4) ||
-      !VF.isScalable())
+  // partial_reduce_umla lowering is broken for zve32x (ELEN == 32), so
+  // bail out when ELEN < 64 to make sure we never generate these nodes.
+  if (!ST->hasStdExtZvqdotq() || ST->getELen() < 64 ||
+      Opcode != Instruction::Add || !BinOp || *BinOp != Instruction::Mul ||
+      InputTypeA != InputTypeB || !InputTypeA->isIntegerTy(8) ||
+      OpAExtend != OpBExtend || !AccumType->isIntegerTy(32) ||
+      !VF.isKnownMultipleOf(4) || !VF.isScalable())
     return InstructionCost::getInvalid();
 
   Type *Tp = VectorType::get(AccumType, VF);



More information about the llvm-commits mailing list