[Mlir-commits] [mlir] bdeae1f - [mlir][sparse][taco] Support f16.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 21 09:08:31 PDT 2022


Author: bixia1
Date: 2022-06-21T09:08:26-07:00
New Revision: bdeae1f57b261c9a6f1cb1cc08c80086907cff93

URL: https://github.com/llvm/llvm-project/commit/bdeae1f57b261c9a6f1cb1cc08c80086907cff93
DIFF: https://github.com/llvm/llvm-project/commit/bdeae1f57b261c9a6f1cb1cc08c80086907cff93.diff

LOG: [mlir][sparse][taco] Support f16.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D128105

Added: 
    (none)

Modified: 
    mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py

Removed: 
    (none)


################################################################################
diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
index b4fee5269c427..211e727c13e9e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
@@ -12,7 +12,9 @@
 dense = pt.dense
 
 passed = 0
-all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64]
+all_types = [
+    pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
+]
 for t in all_types:
   i, j = pt.get_index_vars(2)
   A = pt.tensor([2, 3], dtype=t)
@@ -29,5 +31,5 @@
   passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
   passed += np.allclose(values, [20.0, 10.0, 70.0])
 
-# CHECK: Number of passed: 18
+# CHECK: Number of passed: 21
 print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 48b6b552bb110..6f82016a40c53 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -72,7 +72,7 @@ class Type(enum.Enum):
   INT16 = np.int16
   INT32 = np.int32
   INT64 = np.int64
-  # numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
+  FLOAT16 = np.float16
   FLOAT32 = np.float32
   FLOAT64 = np.float64
   COMPLEX64 = np.complex64
@@ -80,15 +80,15 @@ class Type(enum.Enum):
 
 
 # All floating point type enums.
-_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
+_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64)
 # All integral type enums.
 _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
 # All complex type enums.
 _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
-_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
-                        np.float64, np.complex64, np.complex128]
+_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
+                        np.float32, np.float64, np.complex64, np.complex128]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -132,6 +132,7 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
       Type.INT16: "i16",
       Type.INT32: "i32",
       Type.INT64: "i64",
+      Type.FLOAT16: "f16",
       Type.FLOAT32: "f32",
       Type.FLOAT64: "f64",
       Type.COMPLEX64: "complex<f32>",
@@ -147,6 +148,7 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
       np.int16: Type.INT16,
       np.int32: Type.INT32,
       np.int64: Type.INT64,
+      np.float16: Type.FLOAT16,
       np.float32: Type.FLOAT32,
       np.float64: Type.FLOAT64,
       np.complex64: Type.COMPLEX64,
@@ -162,6 +164,7 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
       Type.INT16: ir.IntegerType.get_signless(16),
       Type.INT32: ir.IntegerType.get_signless(32),
       Type.INT64: ir.IntegerType.get_signless(64),
+      Type.FLOAT16: ir.F16Type.get(),
       Type.FLOAT32: ir.F32Type.get(),
       Type.FLOAT64: ir.F64Type.get(),
       Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
index 8300dfef5bc63..d11eb76edca93 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
@@ -39,6 +39,7 @@
 int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
 int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
 int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
+float16 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT16)
 float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32)
 float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64)
 complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index f5ec14aa80b03..969d78b2e5887 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -89,6 +89,8 @@ def _get_support_func_locator() -> _SupportFuncLocator:
                       c_lib.convertFromMLIRSparseTensorI32),
                      (np.int64, c_lib.convertToMLIRSparseTensorI64,
                       c_lib.convertFromMLIRSparseTensorI64),
+                     (np.float16, c_lib.convertToMLIRSparseTensorF16,
+                      c_lib.convertFromMLIRSparseTensorF16),
                      (np.float32, c_lib.convertToMLIRSparseTensorF32,
                       c_lib.convertFromMLIRSparseTensorF32),
                      (np.float64, c_lib.convertToMLIRSparseTensorF64,


        


More information about the Mlir-commits mailing list