[Mlir-commits] [mlir] 6071f6f - [mlir][sparse] Fix a problem in handling data type conversion.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 30 14:54:59 PDT 2023


Author: bixia1
Date: 2023-03-30T14:54:53-07:00
New Revision: 6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf

URL: https://github.com/llvm/llvm-project/commit/6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf
DIFF: https://github.com/llvm/llvm-project/commit/6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf.diff

LOG: [mlir][sparse] Fix a problem in handling data type conversion.

Previously, the genCast function generated arith.trunci for converting f32 to
i32. Fix the function to use mlir::convertScalarToDtype to correctly handle
conversion cases beyond index casting.

Add a test case for codegen of the sparse_tensor.convert op.

Reviewed By: aartbik, Peiming, wrengr

Differential Revision: https://reviews.llvm.org/D147272

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index bdd6020d9d0ac..957d41b82d23b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -208,28 +208,9 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
     return builder.create<arith::IndexCastOp>(loc, dstTp, value);
 
-  const bool ext =
-      srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
-
-  // float => float.
-  if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
-    if (ext)
-      return builder.create<arith::ExtFOp>(loc, dstTp, value);
-    return builder.create<arith::TruncFOp>(loc, dstTp, value);
-  }
-
-  // int => int
-  const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
-  if (srcIntTp && dstTp.isa<IntegerType>()) {
-    if (!ext)
-      return builder.create<arith::TruncIOp>(loc, dstTp, value);
-    if (srcIntTp.isUnsigned())
-      return builder.create<arith::ExtUIOp>(loc, dstTp, value);
-    if (srcIntTp.isSigned())
-      return builder.create<arith::ExtSIOp>(loc, dstTp, value);
-  }
-
-  llvm_unreachable("unhandled type casting");
+  const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
+  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 87db4743026db..4a54212657373 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -663,6 +663,22 @@ func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?x
   return %0 : tensor<?xf32, #SparseVector>
 }
 
+// CHECK-LABEL: func.func @sparse_convert_element_type(
+//  CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+//  CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: scf.for
+//       CHECK:   %[[FValue:.*]] = memref.load
+//       CHECK:   %[[IValue:.*]] = arith.fptosi %[[FValue]]
+//       CHECK:   memref.store %[[IValue]]
+//       CHECK: return  %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
+//  CHECK-SAME:   memref<?xi32>, memref<?xi64>, memref<?xi32>, !sparse_tensor.storage_specifier
+func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xi32, #SparseVector>
+  return %0 : tensor<?xi32, #SparseVector>
+}
+
 // CHECK-LABEL: func.func @sparse_new_coo(
 // CHECK-SAME:  %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
 //   CHECK-DAG: %[[A1:.*]] = arith.constant false


        


More information about the Mlir-commits mailing list