[clang] [llvm] [RISCV] Use BF16 in Xsfvfwmaccqqq intrinsics (PR #71140)
via cfe-commits
cfe-commits at lists.llvm.org
Thu Nov 2 21:34:58 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Shao-Ce SUN (sunshaoce)
<details>
<summary>Changes</summary>
The BF16 implementation is based on @joshua-arch1's https://reviews.llvm.org/D152498.
This fixes the incorrect f16 type introduced in https://github.com/llvm/llvm-project/pull/68296.
---------
Co-authored-by: Jun Sha (Joshua) <cooper.joshua@linux.alibaba.com>
---
Patch is 46.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71140.diff
19 Files Affected:
- (modified) clang/include/clang/AST/Type.h (+2-2)
- (modified) clang/include/clang/Basic/RISCVVTypes.def (+29-6)
- (modified) clang/include/clang/Basic/riscv_sifive_vector.td (+2-2)
- (modified) clang/include/clang/Basic/riscv_vector_common.td (+1)
- (modified) clang/include/clang/Support/RISCVVIntrinsicUtils.h (+7-4)
- (modified) clang/lib/AST/ASTContext.cpp (+9-3)
- (modified) clang/lib/AST/Type.cpp (+4-3)
- (modified) clang/lib/Sema/SemaRISCVVectorLookup.cpp (+3)
- (modified) clang/lib/Support/RISCVVIntrinsicUtils.cpp (+25)
- (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c (+15-15)
- (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/overloaded/sf_vfwmacc_4x4x4.c (+15-15)
- (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/policy/non-overloaded/sf_vfwmacc_4x4x4.c (+15-15)
- (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/policy/overloaded/sf_vfwmacc_4x4x4.c (+15-15)
- (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-handcrafted/rvv-intrinsic-datatypes.cpp (+13)
- (modified) clang/test/Sema/riscv-types.c (+19)
- (modified) clang/test/Sema/rvv-required-features.c (+1-1)
- (modified) clang/utils/TableGen/RISCVVEmitter.cpp (+5-3)
- (modified) llvm/lib/Support/RISCVISAInfo.cpp (+1-1)
- (modified) llvm/lib/Target/RISCV/RISCVFeatures.td (+1-1)
``````````diff
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index f64cd5e0ef64910..f99c4faa7170527 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -7294,7 +7294,7 @@ inline bool Type::isRVVType() const {
inline bool Type::isRVVType(unsigned ElementCount) const {
bool Ret = false;
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (NumEls == ElementCount) \
Ret |= isSpecificBuiltinType(BuiltinType::Id);
#include "clang/Basic/RISCVVTypes.def"
@@ -7305,7 +7305,7 @@ inline bool Type::isRVVType(unsigned Bitwidth, bool IsFloat) const {
bool Ret = false;
#define RVV_TYPE(Name, Id, SingletonId)
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (ElBits == Bitwidth && IsFloat == IsFP) \
Ret |= isSpecificBuiltinType(BuiltinType::Id);
#include "clang/Basic/RISCVVTypes.def"
diff --git a/clang/include/clang/Basic/RISCVVTypes.def b/clang/include/clang/Basic/RISCVVTypes.def
index 575bca58b51e023..af44cdcd53e5bd0 100644
--- a/clang/include/clang/Basic/RISCVVTypes.def
+++ b/clang/include/clang/Basic/RISCVVTypes.def
@@ -12,7 +12,8 @@
// A builtin type that has not been covered by any other #define
// Defining this macro covers all the builtins.
//
-// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP)
+// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP,
+// IsBF)
// A RISC-V V scalable vector.
//
// - RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)
@@ -45,7 +46,8 @@
#endif
#ifndef RVV_VECTOR_TYPE
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP)\
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
+ IsFP, IsBF) \
RVV_TYPE(Name, Id, SingletonId)
#endif
@@ -55,13 +57,20 @@
#endif
#ifndef RVV_VECTOR_TYPE_INT
-#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned) \
- RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false)
+#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
+ IsSigned) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false, \
+ false)
#endif
#ifndef RVV_VECTOR_TYPE_FLOAT
-#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
- RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true)
+#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
+#endif
+
+#ifndef RVV_VECTOR_TYPE_BFLOAT
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
#endif
//===- Vector types -------------------------------------------------------===//
@@ -125,6 +134,19 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float16m2_t", RvvFloat16m2, RvvFloat16m2Ty, 8, 16,
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf4_t", RvvBFloat16mf4, RvvBFloat16mf4Ty,
+ 1, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf2_t", RvvBFloat16mf2, RvvBFloat16mf2Ty,
+ 2, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32,
+ 16, 1)
+
RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1, 32, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2, 32, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4, 32, 1)
@@ -430,6 +452,7 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float64m2x4_t", RvvFloat64m2x4, RvvFloat64m2x4Ty, 2
RVV_VECTOR_TYPE_FLOAT("__rvv_float64m4x2_t", RvvFloat64m4x2, RvvFloat64m4x2Ty, 4, 64, 2)
+#undef RVV_VECTOR_TYPE_BFLOAT
#undef RVV_VECTOR_TYPE_FLOAT
#undef RVV_VECTOR_TYPE_INT
#undef RVV_VECTOR_TYPE
diff --git a/clang/include/clang/Basic/riscv_sifive_vector.td b/clang/include/clang/Basic/riscv_sifive_vector.td
index 1e081c734d4941b..d4c22769d9b95ae 100644
--- a/clang/include/clang/Basic/riscv_sifive_vector.td
+++ b/clang/include/clang/Basic/riscv_sifive_vector.td
@@ -109,7 +109,7 @@ multiclass RVVVFWMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
Name = NAME,
HasMasked = false,
Log2LMUL = [-2, -1, 0, 1, 2] in
- defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "x", suffixes_prototypes>;
+ defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "b", suffixes_prototypes>;
}
multiclass RVVVQMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
@@ -146,7 +146,7 @@ let UnMaskedPolicyScheme = HasPolicyOperand in
let UnMaskedPolicyScheme = HasPolicyOperand in
let RequiredFeatures = ["Xsfvfwmaccqqq"] in
- defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "w", "wwSvv"]]>;
+ defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "Fw", "FwFwSvv"]]>;
let UnMaskedPolicyScheme = HasPassthruOperand, RequiredFeatures = ["Xsfvfnrclipxfqf"] in {
let ManualCodegen = [{
diff --git a/clang/include/clang/Basic/riscv_vector_common.td b/clang/include/clang/Basic/riscv_vector_common.td
index 326c3883f0a8409..4036ce8e6903f42 100644
--- a/clang/include/clang/Basic/riscv_vector_common.td
+++ b/clang/include/clang/Basic/riscv_vector_common.td
@@ -41,6 +41,7 @@
// x: float16_t (half)
// f: float32_t (float)
// d: float64_t (double)
+// b: bfloat16_t (bfloat16)
//
// This way, given an LMUL, a record with a TypeRange "sil" will cause the
// definition of 3 builtins. Each type "t" in the TypeRange (in this example
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
index 7904658576e5d50..cd620a8fb2b5c14 100644
--- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -207,10 +207,11 @@ enum class BasicType : uint8_t {
Int16 = 1 << 1,
Int32 = 1 << 2,
Int64 = 1 << 3,
- Float16 = 1 << 4,
- Float32 = 1 << 5,
- Float64 = 1 << 6,
- MaxOffset = 6,
+ BFloat16 = 1 << 4,
+ Float16 = 1 << 5,
+ Float32 = 1 << 6,
+ Float64 = 1 << 7,
+ MaxOffset = 7,
LLVM_MARK_AS_BITMASK_ENUM(Float64),
};
@@ -225,6 +226,7 @@ enum ScalarTypeKind : uint8_t {
SignedInteger,
UnsignedInteger,
Float,
+ BFloat,
Invalid,
Undefined,
};
@@ -300,6 +302,7 @@ class RVVType {
return isVector() && ElementBitwidth == Width;
}
bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+ bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; }
bool isSignedInteger() const {
return ScalarType == ScalarTypeKind::SignedInteger;
}
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 1cb81cffd37ea58..a781a7d5a8638cc 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2177,7 +2177,7 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
break;
#include "clang/Basic/PPCTypes.def"
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, ElKind, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
case BuiltinType::Id: \
Width = 0; \
Align = ElBits; \
@@ -3939,6 +3939,9 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
case BuiltinType::Id: \
return {ElBits == 16 ? Float16Ty : (ElBits == 32 ? FloatTy : DoubleTy), \
llvm::ElementCount::getScalable(NumEls), NF};
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ case BuiltinType::Id: \
+ return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return {BoolTy, llvm::ElementCount::getScalable(NumEls), 1};
@@ -3986,11 +3989,14 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
} else if (Target->hasRISCVVTypes()) {
uint64_t EltTySize = getTypeSize(EltTy);
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (!EltTy->isBooleanType() && \
((EltTy->hasIntegerRepresentation() && \
EltTy->hasSignedIntegerRepresentation() == IsSigned) || \
- (EltTy->hasFloatingRepresentation() && IsFP)) && \
+ (EltTy->hasFloatingRepresentation() && !EltTy->isBFloat16Type() && \
+ IsFP && !IsBF) || \
+ (EltTy->hasFloatingRepresentation() && EltTy->isBFloat16Type() && \
+ IsBF && !IsFP)) && \
EltTySize == ElBits && NumElts == NumEls && NumFields == NF) \
return SingletonId;
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index d1cbfbd150ba53f..df56544b871e22a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2475,9 +2475,10 @@ QualType Type::getSveEltType(const ASTContext &Ctx) const {
bool Type::isRVVVLSBuiltinType() const {
if (const BuiltinType *BT = getAs<BuiltinType>()) {
switch (BT->getKind()) {
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP) \
- case BuiltinType::Id: \
- return NF == 1;
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
+ IsFP, IsBF) \
+ case BuiltinType::Id: \
+ return NF == 1;
#include "clang/Basic/RISCVVTypes.def"
default:
return false;
diff --git a/clang/lib/Sema/SemaRISCVVectorLookup.cpp b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
index 8e72eba1ac4c56f..9a5aecf669a07df 100644
--- a/clang/lib/Sema/SemaRISCVVectorLookup.cpp
+++ b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
@@ -117,6 +117,9 @@ static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
case ScalarTypeKind::UnsignedInteger:
QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
break;
+ case ScalarTypeKind::BFloat:
+ QT = Context.BFloat16Ty;
+ break;
case ScalarTypeKind::Float:
switch (Type->getElementBitwidth()) {
case 64:
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
index 751d0aedacc9a1f..78d49f15732a11e 100644
--- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -101,6 +101,7 @@ RVVType::RVVType(BasicType BT, int Log2LMUL,
// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
+// bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
// clang-format on
bool RVVType::verifyType() const {
@@ -112,6 +113,8 @@ bool RVVType::verifyType() const {
return false;
if (isFloat() && ElementBitwidth == 8)
return false;
+ if (isBFloat() && ElementBitwidth != 16)
+ return false;
if (IsTuple && (NF == 1 || NF > 8))
return false;
if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
@@ -199,6 +202,9 @@ void RVVType::initBuiltinStr() {
llvm_unreachable("Unhandled ElementBitwidth!");
}
break;
+ case ScalarTypeKind::BFloat:
+ BuiltinStr += "b";
+ break;
default:
llvm_unreachable("ScalarType is invalid!");
}
@@ -234,6 +240,9 @@ void RVVType::initClangBuiltinStr() {
case ScalarTypeKind::Float:
ClangBuiltinStr += "float";
break;
+ case ScalarTypeKind::BFloat:
+ ClangBuiltinStr += "bfloat";
+ break;
case ScalarTypeKind::SignedInteger:
ClangBuiltinStr += "int";
break;
@@ -300,6 +309,15 @@ void RVVType::initTypeStr() {
} else
Str += getTypeString("float");
break;
+ case ScalarTypeKind::BFloat:
+ if (isScalar()) {
+ if (ElementBitwidth == 16)
+ Str += "__bf16";
+ else
+ llvm_unreachable("Unhandled floating type.");
+ } else
+ Str += getTypeString("bfloat");
+ break;
case ScalarTypeKind::SignedInteger:
Str += getTypeString("int");
break;
@@ -322,6 +340,9 @@ void RVVType::initShortStr() {
case ScalarTypeKind::Float:
ShortStr = "f" + utostr(ElementBitwidth);
break;
+ case ScalarTypeKind::BFloat:
+ ShortStr = "bf" + utostr(ElementBitwidth);
+ break;
case ScalarTypeKind::SignedInteger:
ShortStr = "i" + utostr(ElementBitwidth);
break;
@@ -373,6 +394,10 @@ void RVVType::applyBasicType() {
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::Float;
break;
+ case BasicType::BFloat16:
+ ElementBitwidth = 16;
+ ScalarType = ScalarTypeKind::BFloat;
+ break;
default:
llvm_unreachable("Unhandled type code!");
}
diff --git a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
index 185b8f236b62a8d..0a08798ef50371c 100644
--- a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
+++ b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
@@ -7,51 +7,51 @@
#include <sifive_vector.h>
// CHECK-RV64-LABEL: define dso_local <vscale x 1 x float> @test_sf_vfwmacc_4x4x4_f32mf2
-// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 1 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 1 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4f16.nxv1f16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 1 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4bf16.nxv1bf16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 1 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 1 x float> [[TMP0]]
//
-vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vfloat16m1_t vs1, vfloat16mf4_t vs2, size_t vl) {
+vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vbfloat16m1_t vs1, vbfloat16mf4_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32mf2(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 2 x float> @test_sf_vfwmacc_4x4x4_f32m1
-// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 2 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 2 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4f16.nxv2f16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 2 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4bf16.nxv2bf16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 2 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 2 x float> [[TMP0]]
//
-vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vfloat16m1_t vs1, vfloat16mf2_t vs2, size_t vl) {
+vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vbfloat16m1_t vs1, vbfloat16mf2_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m1(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 4 x float> @test_sf_vfwmacc_4x4x4_f32m2
-// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 4 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 4 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4f16.nxv4f16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 4 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4bf16.nxv4bf16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 4 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 4 x float> [[TMP0]]
//
-vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vfloat16m1_t vs1, vfloat16m1_t vs2, size_t vl) {
+vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vbfloat16m1_t vs1, vbfloat16m1_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m2(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 8 x float> @test_sf_vfwmacc_4x4x4_f32m4
-// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 8 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 8 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4f16.nxv8f16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 8 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4bf16.nxv8bf16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 8 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 8 x float> [[TMP0]]
//
-vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vfloat16m1_t vs1, vfloat16m2_t vs2, size_t vl) {
+vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vbfloat16m1_t vs1, vbfloat16m2_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m4(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 16 x float> @test_sf_vfwmacc_4x4x4_f32m8
-// CHECK-RV64-SAME: (<vscale x 16 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 16 x half> ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/71140
More information about the cfe-commits mailing list