[Mlir-commits] [mlir] e17c913 - [mlir][python] Add `T.tf32` and missing tests for `tf32` (#116725)
llvmlistbot at llvm.org
Mon Nov 18 18:00:39 PST 2024
Author: Matthias Springer
Date: 2024-11-19T11:00:35+09:00
New Revision: e17c91341be2f6a2d229ab44a4290e7d0ef2e094
URL: https://github.com/llvm/llvm-project/commit/e17c91341be2f6a2d229ab44a4290e7d0ef2e094
DIFF: https://github.com/llvm/llvm-project/commit/e17c91341be2f6a2d229ab44a4290e7d0ef2e094.diff
LOG: [mlir][python] Add `T.tf32` and missing tests for `tf32` (#116725)
Added:
Modified:
mlir/python/mlir/extras/types.py
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 34eee1edb57ff5..b875d639e9d406 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -21,6 +21,7 @@
Float8E4M3Type,
Float8E5M2Type,
Float8E8M0FNUType,
+ FloatTF32Type,
FunctionType,
IndexType,
IntegerType,
@@ -70,6 +71,7 @@ def ui(width):
f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
+tf32 = lambda: FloatTF32Type.get()
f64 = lambda: F64Type.get()
bf16 = lambda: BF16Type.get()
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 48ddc8359ca0a1..6ce0fc12d80824 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -639,6 +639,7 @@ def testTypeIDs():
(BF16Type, BF16Type.get()),
(F16Type, F16Type.get()),
(F32Type, F32Type.get()),
+ (FloatTF32Type, FloatTF32Type.get()),
(F64Type, F64Type.get()),
(NoneType, NoneType.get()),
(ComplexType, ComplexType.get(f32)),
@@ -668,6 +669,7 @@ def testTypeIDs():
# CHECK: BF16Type(bf16)
# CHECK: F16Type(f16)
# CHECK: F32Type(f32)
+ # CHECK: FloatTF32Type(tf32)
# CHECK: F64Type(f64)
# CHECK: NoneType(none)
# CHECK: ComplexType(complex<f32>)
@@ -734,6 +736,9 @@ def print_downcasted(typ):
# CHECK: F32Type
# CHECK: F32Type(f32)
print_downcasted(F32Type.get())
+ # CHECK: FloatTF32Type
+ # CHECK: FloatTF32Type(tf32)
+ print_downcasted(FloatTF32Type.get())
# CHECK: F64Type
# CHECK: F64Type(f64)
print_downcasted(F64Type.get())
More information about the Mlir-commits mailing list