[Mlir-commits] [mlir] [mlir][IR] Treat `tf32` as 19-bit float (PR #116738)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 18 19:46:29 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/116738

TF32 is a variant of F32 that is truncated to 19 bits: it keeps the F32 sign bit and 8-bit exponent but only 10 mantissa bits. `FloatType::getWidth()` used to special-case TF32 and report a width of 32 bits. Because some places use `FloatType::getWidth` while others query the `APFloat` semantics directly (which give 19 bits), TF32 was treated as a 32-bit float in some places and as a 19-bit float in others. This caused problems because `FloatType::getWidth` did not agree with the underlying `APFloat` semantics.
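As a minimal sketch of the mismatch, the snippet below models a float format by its sign, exponent and stored mantissa bits, and compares the width reported before and after this patch against that layout. `FloatLayout`, `widthBeforePatch`, `widthAfterPatch` and `widthMatchesSemantics` are hypothetical names for illustration, not MLIR or LLVM APIs; the real `llvm::fltSemantics` records precision and size differently.

```cpp
// Hypothetical stand-in for llvm::fltSemantics: only the bit fields of an
// IEEE-style binary float are recorded.
struct FloatLayout {
  unsigned signBits;
  unsigned exponentBits;
  unsigned storedMantissaBits; // fraction bits actually stored (no implicit bit)

  // Total storage width, analogous to APFloat::semanticsSizeInBits().
  constexpr unsigned sizeInBits() const {
    return signBits + exponentBits + storedMantissaBits;
  }
};

// f32: 1 sign + 8 exponent + 23 mantissa bits.
constexpr FloatLayout kF32{1, 8, 23};
// tf32: same sign and exponent as f32, but only 10 mantissa bits.
constexpr FloatLayout kTF32{1, 8, 10};

// Behavior before the patch: TF32 was special-cased to report 32 bits.
constexpr unsigned widthBeforePatch(const FloatLayout &layout, bool isTF32) {
  return isTF32 ? 32u : layout.sizeInBits();
}

// Behavior after the patch: the width always comes from the semantics.
constexpr unsigned widthAfterPatch(const FloatLayout &layout) {
  return layout.sizeInBits();
}

// The invariant the patch restores: the reported width equals the
// storage size of the float semantics.
constexpr bool widthMatchesSemantics(unsigned reportedWidth,
                                     const FloatLayout &layout) {
  return reportedWidth == layout.sizeInBits();
}

static_assert(kF32.sizeInBits() == 32, "f32 is 32 bits wide");
static_assert(kTF32.sizeInBits() == 19, "tf32 is 19 bits wide");
```

For `f32` both versions agree with the semantics; for `tf32` only the post-patch width does.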

In particular, creating a dense elements attribute or a dense array attribute with a `tf32` element type crashed. For example:
```
"foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()

mlir-opt: llvm-project/llvm/lib/Support/APFloat.cpp:4108: void llvm::detail::IEEEFloat::initFromAPInt(const fltSemantics *, const APInt &): Assertion `api.getBitWidth() == Sem->sizeInBits' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
```
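The first assertion comes from `IEEEFloat::initFromAPInt`, which requires the integer bit pattern to be exactly as wide as the float semantics. The sketch below assumes (the crash output alone does not confirm it) that each element was turned back into an integer whose width came from `FloatType::getWidth()`, i.e. 32 bits for `tf32` before this patch, and that this 32-bit integer was then handed to the 19-bit TF32 semantics. `BitPattern`, `readElement`, `canInitFromBits` and `tf32Bits` are hypothetical helpers, not LLVM APIs.

```cpp
#include <cstdint>

// Hypothetical stand-in for llvm::APInt: a bit pattern and its width.
struct BitPattern {
  std::uint64_t bits;
  unsigned width;
};

// Storage size of tf32 according to its APFloat semantics.
constexpr unsigned kTF32SemanticsBits = 19;

// Mirrors the precondition asserted in IEEEFloat::initFromAPInt:
// api.getBitWidth() == Sem->sizeInBits.
constexpr bool canInitFromBits(BitPattern p, unsigned semanticsBits) {
  return p.width == semanticsBits;
}

// Hypothetical model of turning a stored element back into an integer
// whose width is taken from FloatType::getWidth().
constexpr BitPattern readElement(std::uint64_t raw, unsigned elementWidth) {
  return BitPattern{raw, elementWidth};
}

// Assembles a tf32 bit pattern: 1 sign bit, 8 exponent bits (bias 127),
// 10 mantissa bits, 19 bits in total.
constexpr std::uint64_t tf32Bits(unsigned sign, unsigned biasedExponent,
                                 unsigned mantissa) {
  return (std::uint64_t(sign) << 18) | (std::uint64_t(biasedExponent) << 10) |
         mantissa;
}

// 4.0 = 1.0 * 2^2, so the biased exponent is 2 + 127 = 129.
constexpr std::uint64_t kFourTF32 = tf32Bits(0, 129, 0);
static_assert(kFourTF32 == 0x20400, "tf32 encoding of 4.0");
```

The value itself fits in 19 bits; what trips the assertion in this model is only the width (32) attached to it.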

```
"foo"() {f32attr = array<tf32: 1024.>} : () -> ()

mlir-opt: llvm-project/mlir/lib/AsmParser/AttributeParser.cpp:847: void (anonymous namespace)::DenseArrayElementParser::append(const APInt &): Assertion `data.getBitWidth() % 8 == 0' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
```
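The second assertion fires in `DenseArrayElementParser::append`, which requires each element's bit pattern to span whole bytes. A plausible reading, which this sketch assumes rather than confirms, is that an up-front element-type check saw the width reported by `FloatType::getWidth()` (32 for `tf32` before this patch) and passed, while the bit pattern obtained from the `APFloat` was only 19 bits wide. With the patch the width is 19 everywhere, so the input is rejected with the diagnostic that the new `@test_invalid_bitwidth_type` test expects. `denseArrayWidthError`, `passesUpFrontCheck` and `satisfiesAppendAssertion` are hypothetical names.

```cpp
// Returns nullptr if a dense array element of this width is acceptable,
// otherwise the diagnostic text expected by the new test.
constexpr const char *denseArrayWidthError(unsigned bitwidth) {
  return bitwidth % 8 == 0 ? nullptr
                           : "element type bitwidth must be a multiple of 8";
}

// Up-front check on the width reported for the element type.
constexpr bool passesUpFrontCheck(unsigned reportedWidth) {
  return denseArrayWidthError(reportedWidth) == nullptr;
}

// Mirrors the assertion in DenseArrayElementParser::append:
// data.getBitWidth() % 8 == 0, where data is the element's bit pattern.
constexpr bool satisfiesAppendAssertion(unsigned actualBits) {
  return actualBits % 8 == 0;
}

constexpr unsigned kTF32SemanticsBits = 19;    // width of the APFloat bits
constexpr unsigned kTF32WidthBeforePatch = 32; // old FloatType::getWidth()
constexpr unsigned kTF32WidthAfterPatch = 19;  // new FloatType::getWidth()
```

Before the patch the two checks disagree (the first passes, the assertion fires); after it, the first check already rejects `tf32`.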

It is unclear why the special handling for TF32 is needed. For reference: #105573


From fca9a88d78f6945d58c1ef8d6965942679028082 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 19 Nov 2024 04:35:21 +0100
Subject: [PATCH] [mlir][IR] Treat `tf32` as float with bitwidth 19

---
 mlir/lib/IR/BuiltinTypes.cpp |  5 -----
 mlir/test/IR/attribute.mlir  | 16 ++++++++++++++++
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..e8e8f3cdfbfd73 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,11 +91,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  // The actual width of TF32 is 19 bits. However, since it is a truncated
-  // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
-  // for compatibility.
-  if (llvm::isa<FloatTF32Type>(*this))
-    return 32;
   return APFloat::semanticsSizeInBits(getFloatSemantics());
 }
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..0085d64ae82b6b 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -561,6 +561,14 @@ func.func @correct_type_pass() {
 
 // -----
 
+func.func @tf32_elements_attr() {
+  // CHECK: "foo"() {attr = dense<4.000000e+00> : tensor<tf32>} : () -> ()
+  "foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test StringElementsAttr
 //===----------------------------------------------------------------------===//
@@ -675,6 +683,14 @@ func.func @dense_array_attr() attributes {
 
 // -----
 
+func.func @test_invalid_bitwidth_type() {
+  // expected-error @below{{element type bitwidth must be a multiple of 8}}
+  "foo"() {tf32attr = array<tf32: 1024.0>} : () -> ()
+  return
+}
+
+// -----
+
 func.func @testConfinedDenseArrayAttr() {
   "test.confined_dense_array_attr"() {
     i64attr = array<i64: 0, 2, 3>,


