[Mlir-commits] [mlir] [mlir][IR] Treat `tf32` as 19-bit float (PR #116738)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 18 19:46:29 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/116738
TF32 is a variant of F32 that is truncated to 19 bits: 1 sign bit, 8 exponent bits, and 10 mantissa bits. `FloatType::getWidth()` used to special-case TF32 and report it as 32 bits, so TF32 was treated as a 32-bit float in some places. (Some places use `FloatType::getWidth`, others query the `APFloat` semantics directly.) This caused problems because `FloatType::getWidth` (32 bits) did not agree with the underlying `APFloat` semantics (19 bits).
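To make the width mismatch concrete, here is a small self-contained C++ sketch. It does not use LLVM: `FloatLayout`, `oldGetWidth`, `newGetWidth` and `f32ToTF32Bits` are hypothetical names that model the removed special case, the patched behavior, and the f32-to-TF32 truncation (dropping the 13 low-order mantissa bits, with no rounding).

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical model of a binary float layout (not LLVM's fltSemantics):
// one sign bit, `exponentBits` exponent bits, `mantissaBits` stored fraction bits.
struct FloatLayout {
  unsigned exponentBits;
  unsigned mantissaBits;
  constexpr unsigned sizeInBits() const { return 1 + exponentBits + mantissaBits; }
};

constexpr FloatLayout kF32{8, 23};  // IEEE binary32
constexpr FloatLayout kTF32{8, 10}; // TF32: f32 exponent range, 10-bit mantissa

static_assert(kF32.sizeInBits() == 32, "f32 is 32 bits");
static_assert(kTF32.sizeInBits() == 19, "tf32 is 19 bits");

// Sketch of the removed special case: TF32 is reported as 32 bits.
unsigned oldGetWidth(bool isTF32, const FloatLayout &layout) {
  return isTF32 ? 32u : layout.sizeInBits();
}

// Sketch of the patched behavior: always use the layout's own size.
unsigned newGetWidth(const FloatLayout &layout) { return layout.sizeInBits(); }

// Produces the 19-bit TF32 encoding of an f32 value by dropping the 13
// low-order mantissa bits (plain truncation, no rounding).
std::uint32_t f32ToTF32Bits(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit float");
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  std::uint32_t tf32 = bits >> (kF32.mantissaBits - kTF32.mantissaBits);
  assert(tf32 < (1u << kTF32.sizeInBits())); // always fits in 19 bits
  return tf32;
}
```

For example, `f32ToTF32Bits(4.0f)` yields `0x20400`, a 19-bit pattern, while `oldGetWidth(true, kTF32)` still claims 32. Presumably this kind of disagreement, a 32-bit-sized value paired with semantics that expect 19 bits, is what trips the `APFloat` assertion in the first crash.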
In particular, creating a dense elements attribute or a dense array attribute with a `tf32` element type crashed, e.g.:
```
"foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()
mlir-opt: llvm-project/llvm/lib/Support/APFloat.cpp:4108: void llvm::detail::IEEEFloat::initFromAPInt(const fltSemantics *, const APInt &): Assertion `api.getBitWidth() == Sem->sizeInBits' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
```
```
"foo"() {f32attr = array<tf32: 1024.>} : () -> ()
mlir-opt: llvm-project/mlir/lib/AsmParser/AttributeParser.cpp:847: void (anonymous namespace)::DenseArrayElementParser::append(const APInt &): Assertion `data.getBitWidth() % 8 == 0' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
```
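The second crash can be modeled in the same way. The hypothetical C++ sketch below (`parseElement` and `Outcome` are illustrative names, not MLIR's API) assumes that the parser first validates the width reported by the element type and only then appends the parsed value, whose own bit width must be a whole number of bytes. With the old `getWidth`, TF32 passes the type check as 32 bits, but the value is 19 bits wide, so the internal invariant fails. With the patched `getWidth`, the type is rejected up front, which matches the `expected-error` added to `attribute.mlir`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical outcomes of parsing one dense-array element.
enum class Outcome { Stored, Diagnostic, AssertionFailure };

// Sketch (not MLIR's parser): `typeWidth` is what the element type reports
// via getWidth(); `valueWidth` is the bit width of the parsed value.
Outcome parseElement(std::vector<std::uint8_t> &bytes, unsigned typeWidth,
                     std::uint64_t value, unsigned valueWidth) {
  assert(valueWidth <= 64 && "sketch handles at most 64-bit values");
  // User-facing check on the declared type, reported as
  // "element type bitwidth must be a multiple of 8".
  if (typeWidth % 8 != 0)
    return Outcome::Diagnostic;
  // Internal invariant, modeled on `data.getBitWidth() % 8 == 0`.
  if (valueWidth % 8 != 0)
    return Outcome::AssertionFailure;
  // Store the value little-endian, one byte at a time.
  for (unsigned i = 0; i < valueWidth / 8; ++i)
    bytes.push_back(static_cast<std::uint8_t>(value >> (8 * i)));
  return Outcome::Stored;
}
```

Calling `parseElement(bytes, 32, 0x22400, 19)` models the old TF32 path and reaches `AssertionFailure`. The same call with `typeWidth = 19` returns `Diagnostic` instead, which is a clean error rather than a crash.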
It is unclear why the special handling for TF32 is needed. For reference: #105573
From fca9a88d78f6945d58c1ef8d6965942679028082 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 19 Nov 2024 04:35:21 +0100
Subject: [PATCH] [mlir][IR] Treat `tf32` as float with bitwidth 19
---
mlir/lib/IR/BuiltinTypes.cpp | 5 -----
mlir/test/IR/attribute.mlir | 16 ++++++++++++++++
2 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..e8e8f3cdfbfd73 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,11 +91,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
- // The actual width of TF32 is 19 bits. However, since it is a truncated
- // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
- // for compatibility.
- if (llvm::isa<FloatTF32Type>(*this))
- return 32;
return APFloat::semanticsSizeInBits(getFloatSemantics());
}
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..0085d64ae82b6b 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -561,6 +561,14 @@ func.func @correct_type_pass() {
// -----
+func.func @tf32_elements_attr() {
+ // CHECK: "foo"() {attr = dense<4.000000e+00> : tensor<tf32>} : () -> ()
+ "foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test StringElementsAttr
//===----------------------------------------------------------------------===//
@@ -675,6 +683,14 @@ func.func @dense_array_attr() attributes {
// -----
+func.func @test_invalid_bitwidth_type() {
+ // expected-error @below{{element type bitwidth must be a multiple of 8}}
+ "foo"() {tf32attr = array<tf32: 1024.0>} : () -> ()
+ return
+}
+
+// -----
+
func.func @testConfinedDenseArrayAttr() {
"test.confined_dense_array_attr"() {
i64attr = array<i64: 0, 2, 3>,