[Mlir-commits] [mlir] e17c913 - [mlir][python] Add `T.tf32` and missing tests for `tf32` (#116725)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 18 18:00:39 PST 2024


Author: Matthias Springer
Date: 2024-11-19T11:00:35+09:00
New Revision: e17c91341be2f6a2d229ab44a4290e7d0ef2e094

URL: https://github.com/llvm/llvm-project/commit/e17c91341be2f6a2d229ab44a4290e7d0ef2e094
DIFF: https://github.com/llvm/llvm-project/commit/e17c91341be2f6a2d229ab44a4290e7d0ef2e094.diff

LOG: [mlir][python] Add `T.tf32` and missing tests for `tf32` (#116725)

Added: 
    

Modified: 
    mlir/python/mlir/extras/types.py
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 34eee1edb57ff5..b875d639e9d406 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -21,6 +21,7 @@
     Float8E4M3Type,
     Float8E5M2Type,
     Float8E8M0FNUType,
+    FloatTF32Type,
     FunctionType,
     IndexType,
     IntegerType,
@@ -70,6 +71,7 @@ def ui(width):
 
 f16 = lambda: F16Type.get()
 f32 = lambda: F32Type.get()
+tf32 = lambda: FloatTF32Type.get()
 f64 = lambda: F64Type.get()
 bf16 = lambda: BF16Type.get()
 

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 48ddc8359ca0a1..6ce0fc12d80824 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -639,6 +639,7 @@ def testTypeIDs():
             (BF16Type, BF16Type.get()),
             (F16Type, F16Type.get()),
             (F32Type, F32Type.get()),
+            (FloatTF32Type, FloatTF32Type.get()),
             (F64Type, F64Type.get()),
             (NoneType, NoneType.get()),
             (ComplexType, ComplexType.get(f32)),
@@ -668,6 +669,7 @@ def testTypeIDs():
         # CHECK: BF16Type(bf16)
         # CHECK: F16Type(f16)
         # CHECK: F32Type(f32)
+        # CHECK: FloatTF32Type(tf32)
         # CHECK: F64Type(f64)
         # CHECK: NoneType(none)
         # CHECK: ComplexType(complex<f32>)
@@ -734,6 +736,9 @@ def print_downcasted(typ):
         # CHECK: F32Type
         # CHECK: F32Type(f32)
         print_downcasted(F32Type.get())
+        # CHECK: FloatTF32Type
+        # CHECK: FloatTF32Type(tf32)
+        print_downcasted(FloatTF32Type.get())
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())


        


More information about the Mlir-commits mailing list