[Mlir-commits] [mlir] ca01034 - [mlir][sparse] Factoring out getZero() and avoiding unnecessary Type params

wren romano llvmlistbot at llvm.org
Fri Oct 1 14:18:01 PDT 2021


Author: wren romano
Date: 2021-10-01T14:17:53-07:00
New Revision: ca010347145d2f03052f50a327bb84f4efd1fa49

URL: https://github.com/llvm/llvm-project/commit/ca010347145d2f03052f50a327bb84f4efd1fa49
DIFF: https://github.com/llvm/llvm-project/commit/ca010347145d2f03052f50a327bb84f4efd1fa49.diff

LOG: [mlir][sparse] Factoring out getZero() and avoiding unnecessary Type params

This is preliminary work towards D110790

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110882

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index cd6fea5355928..b5015931c44ba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -182,12 +182,19 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   return call.getResult(0);
 }
 
+/// Generates a constant zero of the given type.
+static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
+                     Type t) {
+  return rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+}
+
 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
 /// For floating types, we use the "unordered" comparator (i.e., returns
 /// true if `v` is NaN).
 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
-                          Type t, Value v) {
-  Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+                          Value v) {
+  Type t = v.getType();
+  Value zero = getZero(rewriter, loc, t);
   if (t.isa<FloatType>())
     return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
   if (t.isIntOrIndex())
@@ -203,11 +210,11 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
 ///    if (tensor[ivs]!=0) {
 ///      ind = ivs
 static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Type eltType, Value tensor,
-                                      Value ind, ValueRange ivs) {
+                                      Operation *op, Value tensor, Value ind,
+                                      ValueRange ivs) {
   Location loc = op->getLoc();
   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  Value cond = genIsNonzero(rewriter, loc, val);
   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
@@ -446,8 +453,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
                            val = genIndexAndValueForSparse(
                                rewriter, op, indices, values, ind, ivs, rank);
                          else
-                           val = genIndexAndValueForDense(rewriter, op, eltType,
-                                                          tensor, ind, ivs);
+                           val = genIndexAndValueForDense(rewriter, op, tensor,
+                                                          ind, ivs);
                          genAddEltCall(rewriter, op, eltType, ptr, val, ind,
                                        perm);
                          return {};


        


More information about the Mlir-commits mailing list