[Mlir-commits] [mlir] 221856f - [mlir][sparse] Moved a conditional from the RT library to the generated MLIR.

wren romano llvmlistbot at llvm.org
Thu Sep 23 12:44:23 PDT 2021


Author: wren romano
Date: 2021-09-23T12:44:17-07:00
New Revision: 221856f5cd13a877543ea6c5418330c1ee7fd715

URL: https://github.com/llvm/llvm-project/commit/221856f5cd13a877543ea6c5418330c1ee7fd715
DIFF: https://github.com/llvm/llvm-project/commit/221856f5cd13a877543ea6c5418330c1ee7fd715.diff

LOG: [mlir][sparse] Moved a conditional from the RT library to the generated MLIR.

When generating code to add an element to a SparseTensorCOO (e.g., during dense=>sparse conversion), we used to check for nonzero values on the runtime-library side; now the generated MLIR performs that check itself and only calls into the runtime for nonzero values.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110121

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c5997f6990f2..5df5477b6f2c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -182,11 +182,27 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   return call.getResult(0);
 }
 
+/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
+/// Floats use the unordered `UNE` predicate (so NaN counts as nonzero);
+/// integer and index types use `ne`; any other type is unreachable.
+static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
+                          Type t, Value v) {
+  Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+  if (t.isa<FloatType>())
+    return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
+  if (t.isIntOrIndex())
+    return rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, v, zero);
+  llvm_unreachable("Unknown element type");
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
+/// In particular, this generates code like the following:
+///   val = a[i1,..,ik];
+///   if val != 0
+///     t->add(val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Value ptr, Value tensor, Value ind, Value perm,
                           ValueRange ivs) {
-  Location loc = op->getLoc();
   StringRef name;
   Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
@@ -203,8 +219,11 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
+  Location loc = op->getLoc();
   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  // TODO: add if here?
+  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
     Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
@@ -321,6 +340,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     // Note that the dense tensor traversal code is actually implemented
     // using MLIR IR to avoid having to expose too much low-level
     // memref traversal details to the runtime support library.
+    // Also note that the code below only generates the "new" ops and
+    // the loop nest itself, whereas the entire body of the innermost
+    // loop is generated by genAddEltCall().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
     auto memTp =

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 6fab920fbcc4..f452a25912d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -1,4 +1,4 @@
-//===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===//
+//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -114,7 +114,8 @@ struct SparseTensorConversionPass
     });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
+    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp, CmpFOp,
+                      CmpIOp>();
     target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
                            memref::MemRefDialect>();
     // Populate with rules and apply rewriting rules.

diff  --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index a642a92c10f4..6a5943dc6acc 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -548,8 +548,6 @@ char *getTensorFilename(uint64_t id) {
   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                             StridedMemRefType<uint64_t, 1> *iref,              \
                             StridedMemRefType<uint64_t, 1> *pref) {            \
-    if (!value)                                                                \
-      return tensor;                                                           \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
     assert(iref->sizes[0] == pref->sizes[0]);                                  \
     const uint64_t *indx = iref->data + iref->offset;                          \


        


More information about the Mlir-commits mailing list