[llvm] [NFC][RISCV] Unify all zvfbfa vl patterns and sd node patterns (PR #171072)
Brandon Wu via llvm-commits
llvm-commits at lists.llvm.org
Sun Dec 7 19:19:37 PST 2025
https://github.com/4vtomat created https://github.com/llvm/llvm-project/pull/171072
This patch moves all Zvfbfa VL patterns and SDNode patterns into
RISCVInstrInfoVVLPatterns.td and RISCVInstrInfoVSDPatterns.td,
respectively. This removes the redefinitions of pattern classes for Zvfbfa
and makes the patterns easier to maintain and change.
Note: this does not include intrinsic patterns. Unifying the intrinsic
patterns as well would require also moving the Zvfbfa pseudo-instruction
definitions to RISCVInstrInfoVPseudos.td.
>From 35fdda9ffe0cce788616c47ec44c5899f7e175ce Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Sun, 7 Dec 2025 19:10:14 -0800
Subject: [PATCH] [NFC][RISCV] Unify all zvfbfa vl patterns and sd node
patterns
This patch moves all Zvfbfa VL patterns and SDNode patterns into
RISCVInstrInfoVVLPatterns.td and RISCVInstrInfoVSDPatterns.td,
respectively. This removes the redefinitions of pattern classes for Zvfbfa
and makes the patterns easier to maintain and change.
Note: this does not include intrinsic patterns. Unifying the intrinsic
patterns as well would require also moving the Zvfbfa pseudo-instruction
definitions to RISCVInstrInfoVPseudos.td.
---
llvm/lib/Target/RISCV/RISCVFeatures.td | 1 +
llvm/lib/Target/RISCV/RISCVInstrInfoV.td | 3 +
.../Target/RISCV/RISCVInstrInfoVPseudos.td | 44 ++--
.../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 98 +++++---
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 138 ++++++++---
llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td | 223 ------------------
6 files changed, 195 insertions(+), 312 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index 0c75312847c87..05f50cba6e9be 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -908,6 +908,7 @@ def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minima
def HasVInstructionsBF16Minimal : Predicate<"Subtarget->hasVInstructionsBF16Minimal()">;
def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">;
+def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">;
def HasVInstructionsF64 : Predicate<"Subtarget->hasVInstructionsF64()">;
def HasVInstructionsFullMultiply : Predicate<"Subtarget->hasVInstructionsFullMultiply()">;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td
index 594a75a4746d4..9354b63bced53 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td
@@ -1840,3 +1840,6 @@ let Predicates = [HasVInstructionsI64, IsRV64] in {
include "RISCVInstrInfoVPseudos.td"
include "RISCVInstrInfoZvfbf.td"
+// Include the non-intrinsic ISel patterns
+include "RISCVInstrInfoVVLPatterns.td"
+include "RISCVInstrInfoVSDPatterns.td"
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index e36204c536c0d..cdbeb0c1046d2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -473,17 +473,27 @@ defset list<VTypeInfoToWide> AllWidenableIntVectors = {
def : VTypeInfoToWide<VI32M4, VI64M8>;
}
-defset list<VTypeInfoToWide> AllWidenableFloatVectors = {
- def : VTypeInfoToWide<VF16MF4, VF32MF2>;
- def : VTypeInfoToWide<VF16MF2, VF32M1>;
- def : VTypeInfoToWide<VF16M1, VF32M2>;
- def : VTypeInfoToWide<VF16M2, VF32M4>;
- def : VTypeInfoToWide<VF16M4, VF32M8>;
+defset list<VTypeInfoToWide> AllWidenableFloatAndBF16Vectors = {
+ defset list<VTypeInfoToWide> AllWidenableFloatVectors = {
+ def : VTypeInfoToWide<VF16MF4, VF32MF2>;
+ def : VTypeInfoToWide<VF16MF2, VF32M1>;
+ def : VTypeInfoToWide<VF16M1, VF32M2>;
+ def : VTypeInfoToWide<VF16M2, VF32M4>;
+ def : VTypeInfoToWide<VF16M4, VF32M8>;
- def : VTypeInfoToWide<VF32MF2, VF64M1>;
- def : VTypeInfoToWide<VF32M1, VF64M2>;
- def : VTypeInfoToWide<VF32M2, VF64M4>;
- def : VTypeInfoToWide<VF32M4, VF64M8>;
+ def : VTypeInfoToWide<VF32MF2, VF64M1>;
+ def : VTypeInfoToWide<VF32M1, VF64M2>;
+ def : VTypeInfoToWide<VF32M2, VF64M4>;
+ def : VTypeInfoToWide<VF32M4, VF64M8>;
+ }
+
+ defset list<VTypeInfoToWide> AllWidenableBF16ToFloatVectors = {
+ def : VTypeInfoToWide<VBF16MF4, VF32MF2>;
+ def : VTypeInfoToWide<VBF16MF2, VF32M1>;
+ def : VTypeInfoToWide<VBF16M1, VF32M2>;
+ def : VTypeInfoToWide<VBF16M2, VF32M4>;
+ def : VTypeInfoToWide<VBF16M4, VF32M8>;
+ }
}
defset list<VTypeInfoToFraction> AllFractionableVF2IntVectors = {
@@ -543,14 +553,6 @@ defset list<VTypeInfoToWide> AllWidenableIntToFloatVectors = {
def : VTypeInfoToWide<VI32M4, VF64M8>;
}
-defset list<VTypeInfoToWide> AllWidenableBF16ToFloatVectors = {
- def : VTypeInfoToWide<VBF16MF4, VF32MF2>;
- def : VTypeInfoToWide<VBF16MF2, VF32M1>;
- def : VTypeInfoToWide<VBF16M1, VF32M2>;
- def : VTypeInfoToWide<VBF16M2, VF32M4>;
- def : VTypeInfoToWide<VBF16M4, VF32M8>;
-}
-
// This class holds the record of the RISCVVPseudoTable below.
// This represents the information we need in codegen for each pseudo.
// The definition should be consistent with `struct PseudoInfo` in
@@ -780,7 +782,7 @@ class GetVRegNoV0<VReg VRegClass> {
class GetVTypePredicates<VTypeInfo vti> {
list<Predicate> Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16],
- !eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal],
+ !eq(vti.Scalar, bf16) : [HasVInstructionsBF16],
!eq(vti.Scalar, f32) : [HasVInstructionsAnyF],
!eq(vti.Scalar, f64) : [HasVInstructionsF64],
!eq(vti.SEW, 64) : [HasVInstructionsI64],
@@ -7326,7 +7328,3 @@ defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
// 16.5. Vector Compress Instruction
//===----------------------------------------------------------------------===//
defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllVectors>;
-
-// Include the non-intrinsic ISel patterns
-include "RISCVInstrInfoVVLPatterns.td"
-include "RISCVInstrInfoVSDPatterns.td"
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index a67112b9981b8..6b72a584acb00 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -215,13 +215,16 @@ multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_na
}
multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name,
- bit isSEWAware = 0, bit isBF16 = 0> {
- foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
+ bit isSEWAware = 0, bit supportBF16 = 0> {
+ foreach vti = !if(supportBF16, AllFloatAndBF16Vectors, AllFloatVectors) in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : VPatBinarySDNode_VV_RM<vop, instruction_name,
+ def : VPatBinarySDNode_VV_RM<vop, instruction_name #
+ !if(!eq(vti.Scalar, bf16), "_ALT", ""),
vti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, vti.AVL, vti.RegClass, isSEWAware>;
- def : VPatBinarySDNode_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
+ def : VPatBinarySDNode_VF_RM<vop, instruction_name#
+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
+ "_V"#vti.ScalarSuffix,
vti.Vector, vti.Vector, vti.Scalar,
vti.Log2SEW, vti.LMul, vti.AVL, vti.RegClass,
vti.ScalarRegClass, isSEWAware>;
@@ -246,14 +249,16 @@ multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_nam
}
multiclass VPatBinaryFPSDNode_R_VF_RM<SDPatternOperator vop, string instruction_name,
- bit isSEWAware = 0, bit isBF16 = 0> {
- foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in
+ bit isSEWAware = 0, bit supportBF16 = 0> {
+ foreach fvti = !if(supportBF16, AllFloatAndBF16Vectors, AllFloatVectors) in
let Predicates = GetVTypePredicates<fvti>.Predicates in
def : Pat<(fvti.Vector (vop (fvti.Vector (SplatFPOp fvti.Scalar:$rs2)),
(fvti.Vector fvti.RegClass:$rs1))),
(!cast<Instruction>(
!if(isSEWAware,
- instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
+ instruction_name#
+ !if(!eq(fvti.Scalar, bf16), "_ALT", "")#
+ "_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX))
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs1,
@@ -664,11 +669,10 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates,
+ let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
!if(!eq(vti.Scalar, bf16),
[HasStdExtZvfbfwma],
- [])) in {
+ GetVTypePredicates<vti>.Predicates)) in {
def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
(vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask), (XLenVT srcvalue))),
@@ -676,7 +680,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
- (!cast<Instruction>(instruction_name#"_VV_"#suffix)
+ (!cast<Instruction>(instruction_name#
+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
+ "_VV_"#suffix)
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
@@ -688,7 +694,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
- (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix)
+ (!cast<Instruction>(instruction_name#
+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
+ "_V"#vti.ScalarSuffix#"_"#suffix)
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
@@ -1201,16 +1209,20 @@ foreach mti = AllMasks in {
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
-defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", isSEWAware=1>;
-defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", isSEWAware=1>;
-defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", isSEWAware=1>;
+defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", isSEWAware=1,
+ supportBF16=1>;
+defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", isSEWAware=1,
+ supportBF16=1>;
+defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", isSEWAware=1,
+ supportBF16=1>;
// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fadd, "PseudoVFWADD">;
defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fsub, "PseudoVFWSUB">;
// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
-defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", isSEWAware=1>;
+defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", isSEWAware=1,
+ supportBF16=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fdiv, "PseudoVFDIV", isSEWAware=1>;
defm : VPatBinaryFPSDNode_R_VF_RM<any_fdiv, "PseudoVFRDIV", isSEWAware=1>;
@@ -1314,14 +1326,15 @@ foreach fvti = AllFloatVectors in {
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC",
- AllWidenableFloatVectors>;
+ AllWidenableFloatAndBF16Vectors>;
defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;
-foreach vti = AllFloatVectors in {
+foreach vti = AllFloatAndBF16Vectors in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
// 13.8. Vector Floating-Point Square-Root Instruction
+ if !ne(vti.Scalar, bf16) then
def : Pat<(any_fsqrt (vti.Vector vti.RegClass:$rs2)),
(!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
@@ -1333,34 +1346,46 @@ foreach vti = AllFloatVectors in {
// 13.12. Vector Floating-Point Sign-Injection Instructions
def : Pat<(fabs (vti.Vector vti.RegClass:$rs)),
- (!cast<Instruction>("PseudoVFSGNJX_VV_"# vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJX_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
// Handle fneg with VFSGNJN using the same input for both operands.
def : Pat<(fneg (vti.Vector vti.RegClass:$rs)),
- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJN_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector vti.RegClass:$rs2))),
- (!cast<Instruction>("PseudoVFSGNJ_VV_"# vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJ_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))),
- (!cast<Instruction>("PseudoVFSGNJ_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJ_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (fneg vti.RegClass:$rs2)))),
- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJN_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))),
- (!cast<Instruction>("PseudoVFSGNJN_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJN_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
}
@@ -1446,13 +1471,26 @@ defm : VPatNConvertFP2ISDNode_W<any_fp_to_sint, "PseudoVFNCVT_RTZ_X_F_W">;
defm : VPatNConvertFP2ISDNode_W<any_fp_to_uint, "PseudoVFNCVT_RTZ_XU_F_W">;
defm : VPatNConvertI2FPSDNode_W_RM<any_sint_to_fp, "PseudoVFNCVT_F_X_W">;
defm : VPatNConvertI2FPSDNode_W_RM<any_uint_to_fp, "PseudoVFNCVT_F_XU_W">;
-foreach fvtiToFWti = AllWidenableFloatVectors in {
+foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
GetVTypeMinimalPredicates<fwti>.Predicates) in
def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
- (!cast<Instruction>("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
+ (!cast<Instruction>("PseudoVFNCVT"#
+ !if(!eq(fvti.Scalar, bf16), "BF16", "")#
+ "_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
+ (fvti.Vector (IMPLICIT_DEF)),
+ fwti.RegClass:$rs1,
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ fvti.AVL, fvti.Log2SEW, TA_MA)>;
+ // Define vfncvt.f.f.w for bf16 when Zvfbfa is enable.
+ if !eq(fvti.Scalar, bf16) then
+ let Predicates = [HasStdExtZvfbfa] in
+ def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
+ (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
(fvti.Vector (IMPLICIT_DEF)),
fwti.RegClass:$rs1,
// Value to indicate no rounding mode change in
@@ -1464,10 +1502,10 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
//===----------------------------------------------------------------------===//
// Vector Element Extracts
//===----------------------------------------------------------------------===//
-foreach vti = NoGroupFloatVectors in {
- defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_",
- vti.ScalarSuffix,
- "_S"));
+foreach vti = !listconcat(NoGroupFloatVectors, NoGroupBF16Vectors) in {
+ defvar vfmv_f_s_inst =
+ !cast<Instruction>(!strconcat("PseudoVFMV_", vti.ScalarSuffix,
+ "_S", !if(!eq(vti.Scalar, bf16), "_ALT", "")));
// Only pattern-match extract-element operations where the index is 0. Any
// other index will have been custom-lowered to slide the vector correctly
// into place.
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 38edab5400291..651070b8624e8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1058,14 +1058,17 @@ multiclass VPatBinaryFPVL_VV_VF<SDPatternOperator vop, string instruction_name,
}
multiclass VPatBinaryFPVL_VV_VF_RM<SDPatternOperator vop, string instruction_name,
- bit isSEWAware = 0, bit isBF16 = 0> {
- foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
+ bit isSEWAware = 0, bit supportBF16 = 0> {
+ foreach vti = !if(supportBF16, AllFloatAndBF16Vectors, AllFloatVectors) in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : VPatBinaryVL_V_RM<vop, instruction_name, "VV",
+ def : VPatBinaryVL_V_RM<vop, instruction_name #
+ !if(!eq(vti.Scalar, bf16), "_ALT", ""), "VV",
vti.Vector, vti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
vti.RegClass, isSEWAware>;
- def : VPatBinaryVL_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
+ def : VPatBinaryVL_VF_RM<vop, instruction_name#
+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
+ "_V"#vti.ScalarSuffix,
vti.Vector, vti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
vti.ScalarRegClass, isSEWAware>;
@@ -1093,8 +1096,8 @@ multiclass VPatBinaryFPVL_R_VF<SDPatternOperator vop, string instruction_name,
}
multiclass VPatBinaryFPVL_R_VF_RM<SDPatternOperator vop, string instruction_name,
- bit isSEWAware = 0, bit isBF16 = 0> {
- foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
+ bit isSEWAware = 0, bit supportBF16 = 0> {
+ foreach fvti = !if(supportBF16, AllFloatAndBF16Vectors, AllFloatVectors) in {
let Predicates = GetVTypePredicates<fvti>.Predicates in
def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
fvti.RegClass:$rs1,
@@ -1103,7 +1106,9 @@ multiclass VPatBinaryFPVL_R_VF_RM<SDPatternOperator vop, string instruction_name
VLOpFrag)),
(!cast<Instruction>(
!if(isSEWAware,
- instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK",
+ instruction_name#
+ !if(!eq(fvti.Scalar, bf16), "_ALT", "")#
+ "_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK",
instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_MASK"))
fvti.RegClass:$passthru,
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
@@ -1832,16 +1837,17 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates,
+ let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
!if(!eq(vti.Scalar, bf16),
[HasStdExtZvfbfwma],
- [])) in {
+ GetVTypePredicates<vti>.Predicates)) in {
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>(instruction_name#"_VV_"#suffix#"_MASK")
+ (!cast<Instruction>(instruction_name#
+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
+ "_VV_"#suffix#"_MASK")
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask VMV0:$vm),
// Value to indicate no rounding mode change in
@@ -1852,7 +1858,9 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
+ (!cast<Instruction>(instruction_name#
+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
+ "_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask VMV0:$vm),
// Value to indicate no rounding mode change in
@@ -2296,9 +2304,12 @@ foreach vtiTowti = AllWidenableIntVectors in {
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
-defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD", isSEWAware=1>;
-defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fsub_vl, "PseudoVFSUB", isSEWAware=1>;
-defm : VPatBinaryFPVL_R_VF_RM<any_riscv_fsub_vl, "PseudoVFRSUB", isSEWAware=1>;
+defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD", isSEWAware=1,
+ supportBF16=1>;
+defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fsub_vl, "PseudoVFSUB", isSEWAware=1,
+ supportBF16=1>;
+defm : VPatBinaryFPVL_R_VF_RM<any_riscv_fsub_vl, "PseudoVFRSUB", isSEWAware=1,
+ supportBF16=1>;
// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
defm : VPatBinaryFPWVL_VV_VF_WV_WF_RM<riscv_vfwadd_vl, riscv_vfwadd_w_vl,
@@ -2307,7 +2318,8 @@ defm : VPatBinaryFPWVL_VV_VF_WV_WF_RM<riscv_vfwsub_vl, riscv_vfwsub_w_vl,
"PseudoVFWSUB", isSEWAware=1>;
// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
-defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fmul_vl, "PseudoVFMUL", isSEWAware=1>;
+defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fmul_vl, "PseudoVFMUL", isSEWAware=1,
+ supportBF16=1>;
defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fdiv_vl, "PseudoVFDIV", isSEWAware=1>;
defm : VPatBinaryFPVL_R_VF_RM<any_riscv_fdiv_vl, "PseudoVFRDIV", isSEWAware=1>;
@@ -2321,7 +2333,8 @@ defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmadd_vl, "PseudoVFNMADD">;
defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmsub_vl, "PseudoVFNMSUB">;
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
-defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC",
+ AllWidenableFloatAndBF16Vectors>;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
@@ -2348,9 +2361,10 @@ defm : VPatFPSetCCVL_VV_VF_FV<any_riscv_fsetccs_vl, SETLE,
defm : VPatFPSetCCVL_VV_VF_FV<any_riscv_fsetccs_vl, SETOLE,
"PseudoVMFLE", "PseudoVMFGE">;
-foreach vti = AllFloatVectors in {
+foreach vti = AllFloatAndBF16Vectors in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
// 13.8. Vector Floating-Point Square-Root Instruction
+ if !ne(vti.Scalar, bf16) then
def : Pat<(any_riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask VMV0:$vm),
VLOpFrag),
(!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX # "_E" # vti.SEW # "_MASK")
@@ -2364,14 +2378,18 @@ foreach vti = AllFloatVectors in {
// 13.12. Vector Floating-Point Sign-Injection Instructions
def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJX_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK")
+ (!cast<Instruction>("PseudoVFSGNJX_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK")
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
TA_MA)>;
// Handle fneg with VFSGNJN using the same input for both operands.
def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK")
+ (!cast<Instruction>("PseudoVFSGNJN_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK")
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
TA_MA)>;
@@ -2381,7 +2399,9 @@ foreach vti = AllFloatVectors in {
vti.RegClass:$passthru,
(vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJ_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
+ (!cast<Instruction>("PseudoVFSGNJ_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
vti.RegClass:$passthru, vti.RegClass:$rs1,
vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
TAIL_AGNOSTIC)>;
@@ -2393,7 +2413,9 @@ foreach vti = AllFloatVectors in {
srcvalue,
(vti.Mask true_mask),
VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
+ (!cast<Instruction>("PseudoVFSGNJN_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>;
@@ -2402,11 +2424,15 @@ foreach vti = AllFloatVectors in {
vti.RegClass:$passthru,
(vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJ_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
+ (!cast<Instruction>("PseudoVFSGNJ_"#
+ !if(!eq(vti.Scalar, bf16), "ALT_", "")#
+ "V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
vti.RegClass:$passthru, vti.RegClass:$rs1,
vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
TAIL_AGNOSTIC)>;
+ // TODO: Support for Zvfbfa
+ if !ne(vti.Scalar, bf16) then {
// Rounding without exception to implement nearbyint.
def : Pat<(any_riscv_vfround_noexcept_vl (vti.Vector vti.RegClass:$rs1),
(vti.Mask VMV0:$vm), VLOpFrag),
@@ -2420,6 +2446,7 @@ foreach vti = AllFloatVectors in {
(!cast<Instruction>("PseudoVFCLASS_V_"# vti.LMul.MX #"_MASK")
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
(vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, TA_MA)>;
+ }
}
}
@@ -2476,7 +2503,7 @@ foreach fvti = AllFloatVectors in {
}
}
-foreach fvti = AllFloatVectors in {
+foreach fvti = AllFloatAndBF16Vectors in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
// 13.16. Vector Floating-Point Move Instruction
@@ -2492,11 +2519,13 @@ foreach fvti = AllFloatVectors in {
}
}
-foreach fvti = AllFloatVectors in {
+foreach fvti = AllFloatAndBF16Vectors in {
let Predicates = GetVTypePredicates<fvti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)),
- (!cast<Instruction>("PseudoVFMV_V_" # fvti.ScalarSuffix # "_" #
+ (!cast<Instruction>("PseudoVFMV_V_" #
+ !if(!eq(fvti.Scalar, bf16), "ALT_", "") #
+ fvti.ScalarSuffix # "_" #
fvti.LMul.MX)
$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2),
GPR:$vl, fvti.Log2SEW, TU_MU)>;
@@ -2526,20 +2555,37 @@ defm : VPatWConvertFP2IVL_V<any_riscv_vfcvt_rtz_x_f_vl, "PseudoVFWCVT_RTZ_X_F_V"
defm : VPatWConvertI2FPVL_V<any_riscv_uint_to_fp_vl, "PseudoVFWCVT_F_XU_V">;
defm : VPatWConvertI2FPVL_V<any_riscv_sint_to_fp_vl, "PseudoVFWCVT_F_X_V">;
-foreach fvtiToFWti = AllWidenableFloatVectors in {
+foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
// Define vfwcvt.f.f.v for f16 when Zvfhmin is enable.
- let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
- GetVTypeMinimalPredicates<fwti>.Predicates) in
+ // Define vfwcvtbf16.f.f.v for bf16 when Zvfbfmin is enabled.
+ let Predicates = !listconcat(GetVTypeMinimalPredicates<fwti>.Predicates,
+ !if(!eq(fvti.Scalar, bf16),
+ [HasStdExtZvfbfmin],
+ GetVTypeMinimalPredicates<fvti>.Predicates)) in
def : Pat<(fwti.Vector (any_riscv_fpextend_vl
(fvti.Vector fvti.RegClass:$rs1),
(fvti.Mask VMV0:$vm),
VLOpFrag)),
- (!cast<Instruction>("PseudoVFWCVT_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
+ (!cast<Instruction>("PseudoVFWCVT"#
+ !if(!eq(fvti.Scalar, bf16), "BF16", "")#
+ "_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
(fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1,
(fvti.Mask VMV0:$vm),
GPR:$vl, fvti.Log2SEW, TA_MA)>;
+
+ // Define vfwcvt.f.f.v for bf16 when Zvfbfa is enabled.
+ if !eq(fvti.Scalar, bf16) then
+ let Predicates = [HasStdExtZvfbfa] in
+ def : Pat<(fwti.Vector (any_riscv_fpextend_vl
+ (fvti.Vector fvti.RegClass:$rs1),
+ (fvti.Mask VMV0:$vm),
+ VLOpFrag)),
+ (!cast<Instruction>("PseudoVFWCVT_F_F_ALT_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
+ (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1,
+ (fvti.Mask VMV0:$vm),
+ GPR:$vl, fvti.Log2SEW, TA_MA)>;
}
// 13.19 Narrowing Floating-Point/Integer Type-Convert Instructions
@@ -2555,16 +2601,21 @@ defm : VPatNConvertI2FPVL_W_RM<any_riscv_sint_to_fp_vl, "PseudoVFNCVT_F_X_W">;
defm : VPatNConvertI2FP_RM_VL_W<riscv_vfcvt_rm_f_xu_vl, "PseudoVFNCVT_F_XU_W">;
defm : VPatNConvertI2FP_RM_VL_W<riscv_vfcvt_rm_f_x_vl, "PseudoVFNCVT_F_X_W">;
-foreach fvtiToFWti = AllWidenableFloatVectors in {
+foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
// Define vfncvt.f.f.w for f16 when Zvfhmin is enable.
- let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
- GetVTypeMinimalPredicates<fwti>.Predicates) in {
+ // Define vfncvtbf16.f.f.w for bf16 when Zvfbfmin is enabled.
+ let Predicates = !listconcat(GetVTypeMinimalPredicates<fwti>.Predicates,
+ !if(!eq(fvti.Scalar, bf16),
+ [HasStdExtZvfbfmin],
+ GetVTypeMinimalPredicates<fvti>.Predicates)) in {
def : Pat<(fvti.Vector (any_riscv_fpround_vl
(fwti.Vector fwti.RegClass:$rs1),
(fwti.Mask VMV0:$vm), VLOpFrag)),
- (!cast<Instruction>("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
+ (!cast<Instruction>("PseudoVFNCVT"#
+ !if(!eq(fvti.Scalar, bf16), "BF16", "")#
+ "_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
(fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
(fwti.Mask VMV0:$vm),
// Value to indicate no rounding mode change in
@@ -2581,6 +2632,20 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
(fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
(fwti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TA_MA)>;
}
+
+ // Define vfncvt.f.f.w for bf16 when Zvfbfa is enabled.
+ if !eq(fvti.Scalar, bf16) then
+ let Predicates = [HasStdExtZvfbfa] in
+ def : Pat<(fvti.Vector (any_riscv_fpround_vl
+ (fwti.Vector fwti.RegClass:$rs1),
+ (fwti.Mask VMV0:$vm), VLOpFrag)),
+ (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
+ (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
+ (fwti.Mask VMV0:$vm),
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ GPR:$vl, fvti.Log2SEW, TA_MA)>;
}
// 14. Vector Reduction Operations
@@ -2751,7 +2816,7 @@ foreach vti = AllIntegerVectors in {
}
// 16.2. Floating-Point Scalar Move Instructions
-foreach vti = NoGroupFloatVectors in {
+foreach vti = !listconcat(NoGroupFloatVectors, NoGroupBF16Vectors) in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru),
(vti.Scalar (fpimm0)),
@@ -2764,7 +2829,8 @@ foreach vti = NoGroupFloatVectors in {
def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru),
vti.ScalarRegClass:$rs1,
VLOpFrag)),
- (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix)
+ (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#
+ !if(!eq(vti.Scalar, bf16), "_ALT", ""))
vti.RegClass:$passthru,
(vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td
index e24e4a33288f7..866e831fdcd94 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td
@@ -406,47 +406,11 @@ let Predicates = [HasStdExtZvfbfmin] in {
"PseudoVFWCVTBF16_F_F", isSEWAware=1>;
defm : VPatConversionVF_WF_BF_RM<"int_riscv_vfncvtbf16_f_f_w",
"PseudoVFNCVTBF16_F_F", isSEWAware=1>;
-
- foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in {
- defvar fvti = fvtiToFWti.Vti;
- defvar fwti = fvtiToFWti.Wti;
- def : Pat<(fwti.Vector (any_riscv_fpextend_vl
- (fvti.Vector fvti.RegClass:$rs1),
- (fvti.Mask VMV0:$vm),
- VLOpFrag)),
- (!cast<Instruction>("PseudoVFWCVTBF16_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
- (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1,
- (fvti.Mask VMV0:$vm),
- GPR:$vl, fvti.Log2SEW, TA_MA)>;
-
- def : Pat<(fvti.Vector (any_riscv_fpround_vl
- (fwti.Vector fwti.RegClass:$rs1),
- (fwti.Mask VMV0:$vm), VLOpFrag)),
- (!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
- (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
- (fwti.Mask VMV0:$vm),
- // Value to indicate no rounding mode change in
- // RISCVInsertReadWriteCSR
- FRM_DYN,
- GPR:$vl, fvti.Log2SEW, TA_MA)>;
- def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
- (!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
- (fvti.Vector (IMPLICIT_DEF)),
- fwti.RegClass:$rs1,
- // Value to indicate no rounding mode change in
- // RISCVInsertReadWriteCSR
- FRM_DYN,
- fvti.AVL, fvti.Log2SEW, TA_MA)>;
- }
}
let Predicates = [HasStdExtZvfbfwma] in {
defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmaccbf16", "PseudoVFWMACCBF16",
AllWidenableBF16ToFloatVectors, isSEWAware=1>;
- defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACCBF16",
- AllWidenableBF16ToFloatVectors>;
- defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16",
- AllWidenableBF16ToFloatVectors>;
}
multiclass VPatConversionVI_VF_BF16<string intrinsic, string instruction> {
@@ -614,191 +578,4 @@ defm : VPatConversionVF_WF_BF16<"int_riscv_vfncvt_rod_f_f_w", "PseudoVFNCVT_ROD_
isSEWAware=1>;
defm : VPatBinaryV_VX<"int_riscv_vfslide1up", "PseudoVFSLIDE1UP_ALT", AllBF16Vectors>;
defm : VPatBinaryV_VX<"int_riscv_vfslide1down", "PseudoVFSLIDE1DOWN_ALT", AllBF16Vectors>;
-
-foreach fvti = AllBF16Vectors in {
- defvar ivti = GetIntVTypeInfo<fvti>.Vti;
- let Predicates = GetVTypePredicates<ivti>.Predicates in {
- // 13.16. Vector Floating-Point Move Instruction
- // If we're splatting fpimm0, use vmv.v.x vd, x0.
- def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
- fvti.Vector:$passthru, (fvti.Scalar (fpimm0)), VLOpFrag)),
- (!cast<Instruction>("PseudoVMV_V_I_"#fvti.LMul.MX)
- $passthru, 0, GPR:$vl, fvti.Log2SEW, TU_MU)>;
- def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
- fvti.Vector:$passthru, (fvti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), VLOpFrag)),
- (!cast<Instruction>("PseudoVMV_V_X_"#fvti.LMul.MX)
- $passthru, GPR:$imm, GPR:$vl, fvti.Log2SEW, TU_MU)>;
- }
-
- let Predicates = GetVTypePredicates<fvti>.Predicates in {
- def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
- fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)),
- (!cast<Instruction>("PseudoVFMV_V_ALT_" # fvti.ScalarSuffix # "_" #
- fvti.LMul.MX)
- $passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2),
- GPR:$vl, fvti.Log2SEW, TU_MU)>;
- }
-}
-
-foreach vti = NoGroupBF16Vectors in {
- let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru),
- (vti.Scalar (fpimm0)),
- VLOpFrag)),
- (PseudoVMV_S_X $passthru, (XLenVT X0), GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru),
- (vti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))),
- VLOpFrag)),
- (PseudoVMV_S_X $passthru, GPR:$imm, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru),
- vti.ScalarRegClass:$rs1,
- VLOpFrag)),
- (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#"_ALT")
- vti.RegClass:$passthru,
- (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
- }
-
- defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_",
- vti.ScalarSuffix,
- "_S_ALT"));
- // Only pattern-match extract-element operations where the index is 0. Any
- // other index will have been custom-lowered to slide the vector correctly
- // into place.
- let Predicates = GetVTypePredicates<vti>.Predicates in
- def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)),
- (vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>;
-}
-
-let Predicates = [HasStdExtZvfbfa] in {
- foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in {
- defvar fvti = fvtiToFWti.Vti;
- defvar fwti = fvtiToFWti.Wti;
- def : Pat<(fwti.Vector (any_riscv_fpextend_vl
- (fvti.Vector fvti.RegClass:$rs1),
- (fvti.Mask VMV0:$vm),
- VLOpFrag)),
- (!cast<Instruction>("PseudoVFWCVT_F_F_ALT_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
- (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1,
- (fvti.Mask VMV0:$vm),
- GPR:$vl, fvti.Log2SEW, TA_MA)>;
-
- def : Pat<(fvti.Vector (any_riscv_fpround_vl
- (fwti.Vector fwti.RegClass:$rs1),
- (fwti.Mask VMV0:$vm), VLOpFrag)),
- (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
- (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
- (fwti.Mask VMV0:$vm),
- // Value to indicate no rounding mode change in
- // RISCVInsertReadWriteCSR
- FRM_DYN,
- GPR:$vl, fvti.Log2SEW, TA_MA)>;
- def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
- (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
- (fvti.Vector (IMPLICIT_DEF)),
- fwti.RegClass:$rs1,
- // Value to indicate no rounding mode change in
- // RISCVInsertReadWriteCSR
- FRM_DYN,
- fvti.AVL, fvti.Log2SEW, TA_MA)>;
- }
-
- foreach vti = AllBF16Vectors in {
- // 13.12. Vector Floating-Point Sign-Injection Instructions
- def : Pat<(fabs (vti.Vector vti.RegClass:$rs)),
- (!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
- // Handle fneg with VFSGNJN using the same input for both operands.
- def : Pat<(fneg (vti.Vector vti.RegClass:$rs)),
- (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
- (vti.Vector vti.RegClass:$rs2))),
- (!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
- def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
- (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))),
- (!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
- (vti.Vector (fneg vti.RegClass:$rs2)))),
- (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
- def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
- (vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))),
- (!cast<Instruction>("PseudoVFSGNJN_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)),
- vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
-
- // 13.12. Vector Floating-Point Sign-Injection Instructions
- def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
- VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
- vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
- TA_MA)>;
- // Handle fneg with VFSGNJN using the same input for both operands.
- def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
- VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK")
- (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
- vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
- TA_MA)>;
-
- def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
- (vti.Vector vti.RegClass:$rs2),
- vti.RegClass:$passthru,
- (vti.Mask VMV0:$vm),
- VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
- vti.RegClass:$passthru, vti.RegClass:$rs1,
- vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
- TAIL_AGNOSTIC)>;
-
- def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
- (riscv_fneg_vl vti.RegClass:$rs2,
- (vti.Mask true_mask),
- VLOpFrag),
- srcvalue,
- (vti.Mask true_mask),
- VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
- (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1,
- vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
- (SplatFPOp vti.ScalarRegClass:$rs2),
- vti.RegClass:$passthru,
- (vti.Mask VMV0:$vm),
- VLOpFrag),
- (!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
- vti.RegClass:$passthru, vti.RegClass:$rs1,
- vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
- TAIL_AGNOSTIC)>;
- }
- }
-
- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB_ALT",
- isSEWAware=1, isBF16=1>;
-
- defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fsub_vl, "PseudoVFSUB_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fmul_vl, "PseudoVFMUL_ALT",
- isSEWAware=1, isBF16=1>;
- defm : VPatBinaryFPVL_R_VF_RM<any_riscv_fsub_vl, "PseudoVFRSUB_ALT",
- isSEWAware=1, isBF16=1>;
} // Predicates = [HasStdExtZvfbfa]
More information about the llvm-commits
mailing list