[llvm-branch-commits] [mlir] [mlir][IR] Remove `isF...()` type API for low-precision FP types (PR #123326)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 17 03:35:50 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
Author: Matthias Springer (matthias-springer)
Remove the `type.isFloat4E2M1FN()` family of convenience predicates for low-precision FP types from the `Type` API. Use `isa<Float4E2M1FNType>(type)` and the corresponding type classes instead.
For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28
Depends on #123321.
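The migration in downstream code is mechanical; here is a minimal sketch of the before/after pattern (the `isAnyFp8` helper and `someType` variable are illustrative only, not part of this patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

// Hypothetical downstream helper showing the API change.
static bool isAnyFp8(mlir::Type someType) {
  // Before this change (removed API):
  //   return someType.isFloat8E5M2() || someType.isFloat8E4M3FN();
  // After this change, query the concrete type classes directly:
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(someType);
}
```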
---
Patch is 22.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123326.diff
11 Files Affected:
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+13-13)
- (modified) mlir/include/mlir/IR/Types.h (-11)
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+24-16)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+20-18)
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+2-2)
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+4-5)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-4)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+2-2)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+3-3)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+1-2)
- (modified) mlir/lib/IR/Types.cpp (-19)
``````````diff
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6f52195c1d7c92..e752cdfb47fbb1 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -329,31 +329,31 @@ def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;
-def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
+def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
BuildableType<"$_builder.getType<BFloat16Type>()">;
-def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
BuildableType<"$_builder.getType<FloatTF32Type>()">;
-def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
+def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
-def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
+def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
-def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
+def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
-def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
+def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
-def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
-def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
+def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
-def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
+def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
-def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
+def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
-def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
-def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType>($_self)">, "f6E3M2FN type">,
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
-def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index acd0f894abbbe6..0e82ad2be907ab 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,17 +125,6 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
- bool isFloat4E2M1FN() const;
- bool isFloat6E2M3FN() const;
- bool isFloat6E3M2FN() const;
- bool isFloat8E5M2() const;
- bool isFloat8E4M3() const;
- bool isFloat8E4M3FN() const;
- bool isFloat8E5M2FNUZ() const;
- bool isFloat8E4M3FNUZ() const;
- bool isFloat8E4M3B11FNUZ() const;
- bool isFloat8E3M4() const;
- bool isFloat8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 250e4a6bbf8dfd..313d6830b41b2a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
- return unwrap(type).isFloat4E2M1FN();
+ return llvm::isa<Float4E2M1FNType>(unwrap(type));
}
MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
- return unwrap(type).isFloat6E2M3FN();
+ return llvm::isa<Float6E2M3FNType>(unwrap(type));
}
MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
- return unwrap(type).isFloat6E3M2FN();
+ return llvm::isa<Float6E3M2FNType>(unwrap(type));
}
MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E5M2(MlirType type) {
- return unwrap(type).isFloat8E5M2();
+ return llvm::isa<Float8E5M2Type>(unwrap(type));
}
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3(MlirType type) {
- return unwrap(type).isFloat8E4M3();
+ return llvm::isa<Float8E4M3Type>(unwrap(type));
}
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
- return unwrap(type).isFloat8E4M3FN();
+ return llvm::isa<Float8E4M3FNType>(unwrap(type));
}
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
- return unwrap(type).isFloat8E5M2FNUZ();
+ return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
}
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
- return unwrap(type).isFloat8E4M3FNUZ();
+ return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
}
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
- return unwrap(type).isFloat8E4M3B11FNUZ();
+ return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
}
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E3M4(MlirType type) {
- return unwrap(type).isFloat8E3M4();
+ return llvm::isa<Float8E3M4Type>(unwrap(type));
}
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
- return unwrap(type).isFloat8E8M0FNU();
+ return llvm::isa<Float8E8M0FNUType>(unwrap(type));
}
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}
-bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+bool mlirTypeIsABF16(MlirType type) {
+ return llvm::isa<BFloat16Type>(unwrap(type));
+}
MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(BFloat16Type::get(unwrap(ctx)));
@@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
-bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+bool mlirTypeIsAF16(MlirType type) {
+ return llvm::isa<Float16Type>(unwrap(type));
+}
MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(Float16Type::get(unwrap(ctx)));
@@ -239,7 +243,7 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
return wrap(FloatTF32Type::getTypeID());
}
-bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+bool mlirTypeIsATF32(MlirType type) { return llvm::isa<FloatTF32Type>(unwrap(type)); }
MlirType mlirTF32TypeGet(MlirContext ctx) {
return wrap(FloatTF32Type::get(unwrap(ctx)));
@@ -247,7 +251,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
-bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+bool mlirTypeIsAF32(MlirType type) {
+ return llvm::isa<Float32Type>(unwrap(type));
+}
MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(Float32Type::get(unwrap(ctx)));
@@ -255,7 +261,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
-bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+bool mlirTypeIsAF64(MlirType type) {
+ return llvm::isa<Float64Type>(unwrap(type));
+}
MlirType mlirF64TypeGet(MlirContext ctx) {
return wrap(Float64Type::get(unwrap(ctx)));
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1564e417a7a48e..5d09d6f1d69523 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
+ chipset >= kGfx940) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
+ chipset >= kGfx940) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
+ if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
- if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
+ if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
return std::nullopt;
}
@@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (sourceElemType.isFloat8E5M2FNUZ()) {
+ if (isa<Float8E5M2FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+ } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index a8283023afc53d..33370566996eee 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 64bdb248dff430..247a8ab28a44be 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
return type;
// F4, F6, F8 types are converted to integer types with the same bit width.
- if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
- type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
- type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
- type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
- type.isFloat8E8M0FNU())
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+ Float8E8M0FNUType>(type))
return IntegerType::get(&getContext(), type.getWidth());
// Other floating-point types: A custom type conversion rule must be
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 34a6b1d506540d..7e97fb84434f89 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
- } else if (inputElemType.isFloat8E4M3FN() ||
- inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+ } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
+ inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
- if (elemType.isFloat8E4M3FN())
+ if (isa<Float8E4M3FNType>(elemType))
return NVVM::WGMMATypes::e4m3;
- if (elemType.isFloat8E5M2())
+ if (isa<Float8E5M2Type>(elemType))
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 492e4781f57810..5af0cb0c7ba1cc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index a027350e8a5f70..47d1b8492e06ec 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -525,8 +525,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return success();
// F16 += f8 + f8
// F32 += f8 + f8
- if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
- (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+ if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
+ isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
(typeD.isF32() || typeD.isF16()))
return success();
@@ -548,7 +548,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
- typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+ isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
if (llvm::is_contained(allowedN, sizeN))
return success();
diff --git a/mlir/lib/Dial...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/123326