[Mlir-commits] [mlir] 7a68225 - [mlir][sparse] Cleaning up code style for genCast

wren romano llvmlistbot at llvm.org
Tue Mar 7 14:43:47 PST 2023


Author: wren romano
Date: 2023-03-07T14:43:40-08:00
New Revision: 7a68225428ba62571a637d3924d81c22a4de7683

URL: https://github.com/llvm/llvm-project/commit/7a68225428ba62571a637d3924d81c22a4de7683
DIFF: https://github.com/llvm/llvm-project/commit/7a68225428ba62571a637d3924d81c22a4de7683.diff

LOG: [mlir][sparse] Cleaning up code style for genCast

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145432

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index c3344ba1af6e2..40836f4d77781 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -199,36 +199,37 @@ StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
 //===----------------------------------------------------------------------===//
 
 Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
-                             Type dstTy) {
-  Type srcTy = value.getType();
-  if (srcTy != dstTy) {
-    // int <=> index
-    if (dstTy.isa<IndexType>() || srcTy.isa<IndexType>())
-      return builder.create<arith::IndexCastOp>(loc, dstTy, value);
-
-    bool ext = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
-
-    // float => float.
-    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && ext)
-      return builder.create<arith::ExtFOp>(loc, dstTy, value);
-
-    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !ext)
-      return builder.create<arith::TruncFOp>(loc, dstTy, value);
-
-    // int => int
-    if (srcTy.isUnsignedInteger() && dstTy.isa<IntegerType>() && ext)
-      return builder.create<arith::ExtUIOp>(loc, dstTy, value);
-
-    if (srcTy.isSignedInteger() && dstTy.isa<IntegerType>() && ext)
-      return builder.create<arith::ExtSIOp>(loc, dstTy, value);
-
-    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !ext)
-      return builder.create<arith::TruncIOp>(loc, dstTy, value);
+                             Type dstTp) {
+  const Type srcTp = value.getType();
+  if (srcTp == dstTp)
+    return value;
+
+  // int <=> index
+  if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
+    return builder.create<arith::IndexCastOp>(loc, dstTp, value);
+
+  const bool ext =
+      srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
+
+  // float => float.
+  if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
+    if (ext)
+      return builder.create<arith::ExtFOp>(loc, dstTp, value);
+    return builder.create<arith::TruncFOp>(loc, dstTp, value);
+  }
 
-    llvm_unreachable("unhandled type casting");
+  // int => int
+  const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
+  if (srcIntTp && dstTp.isa<IntegerType>()) {
+    if (!ext)
+      return builder.create<arith::TruncIOp>(loc, dstTp, value);
+    if (srcIntTp.isUnsigned())
+      return builder.create<arith::ExtUIOp>(loc, dstTp, value);
+    if (srcIntTp.isSigned())
+      return builder.create<arith::ExtSIOp>(loc, dstTp, value);
   }
 
-  return value;
+  llvm_unreachable("unhandled type casting");
 }
 
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {


        


More information about the Mlir-commits mailing list