[Mlir-commits] [mlir] 6071f6f - [mlir][sparse] Fix a problem in handling data type conversion.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 30 14:54:59 PDT 2023
Author: bixia1
Date: 2023-03-30T14:54:53-07:00
New Revision: 6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf
URL: https://github.com/llvm/llvm-project/commit/6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf
DIFF: https://github.com/llvm/llvm-project/commit/6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf.diff
LOG: [mlir][sparse] Fix a problem in handling data type conversion.
Previously, the genCast function generated arith.trunci when converting f32 to
i32. Fix the function to use mlir::convertScalarToDtype, which correctly handles
conversions beyond index casting (e.g. float <-> integer).
Add a test case for codegen of the sparse_tensor.convert op.
Reviewed By: aartbik, Peiming, wrengr
Differential Revision: https://reviews.llvm.org/D147272
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index bdd6020d9d0ac..957d41b82d23b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -208,28 +208,9 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
return builder.create<arith::IndexCastOp>(loc, dstTp, value);
- const bool ext =
- srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
-
- // float => float.
- if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
- if (ext)
- return builder.create<arith::ExtFOp>(loc, dstTp, value);
- return builder.create<arith::TruncFOp>(loc, dstTp, value);
- }
-
- // int => int
- const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
- if (srcIntTp && dstTp.isa<IntegerType>()) {
- if (!ext)
- return builder.create<arith::TruncIOp>(loc, dstTp, value);
- if (srcIntTp.isUnsigned())
- return builder.create<arith::ExtUIOp>(loc, dstTp, value);
- if (srcIntTp.isSigned())
- return builder.create<arith::ExtSIOp>(loc, dstTp, value);
- }
-
- llvm_unreachable("unhandled type casting");
+ const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+ const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
+ return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 87db4743026db..4a54212657373 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -663,6 +663,22 @@ func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?x
return %0 : tensor<?xf32, #SparseVector>
}
+// CHECK-LABEL: func.func @sparse_convert_element_type(
+// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: scf.for
+// CHECK: %[[FValue:.*]] = memref.load
+// CHECK: %[[IValue:.*]] = arith.fptosi %[[FValue]]
+// CHECK: memref.store %[[IValue]]
+// CHECK: return %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xi32>, !sparse_tensor.storage_specifier
+func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+ %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xi32, #SparseVector>
+ return %0 : tensor<?xi32, #SparseVector>
+}
+
// CHECK-LABEL: func.func @sparse_new_coo(
// CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
// CHECK-DAG: %[[A1:.*]] = arith.constant false
More information about the Mlir-commits
mailing list