[Mlir-commits] [mlir] efa15f4 - [mlir][sparse] add ability for sparse tensor output

Aart Bik llvmlistbot at llvm.org
Fri Jan 21 15:43:38 PST 2022


Author: Aart Bik
Date: 2022-01-21T15:43:29-08:00
New Revision: efa15f417847439eb8884f1ef47e163fa450ac7c

URL: https://github.com/llvm/llvm-project/commit/efa15f417847439eb8884f1ef47e163fa450ac7c
DIFF: https://github.com/llvm/llvm-project/commit/efa15f417847439eb8884f1ef47e163fa450ac7c.diff

LOG: [mlir][sparse] add ability for sparse tensor output

Rationale:
Although file I/O is a bit alien to MLIR itself, we provide two convenient ops
for sparse tensor I/O. The input part was already there (behind the Swiss Army
knife sparse_tensor.new). Now we also have sparse_tensor.out to write out data.
As before, the ops are kept deliberately vague and may change in the future. For
now, this allows us to compare TACO vs. MLIR very easily.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D117850

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b7fce5b3137c1..1209a70a72c22 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -351,4 +351,24 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
 
+def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
+    Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
+  string summary = "Outputs a sparse tensor to the given destination";
+  string description = [{
+    Outputs the contents of a sparse tensor to the destination defined by an
+    opaque pointer provided by `dest`. For targets that have access to a file
+    system, for example, this pointer may specify a filename (or file) for output.
+    The form of the operation is kept deliberately very general to allow for
+    alternative implementations in the future, such as sending the contents to
+    a buffer defined by a pointer.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.out %t, %dest : tensor<1024x1024xf64, #CSR>, !Dest
+    ```
+  }];
+  let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
+}
+
 #endif // SPARSETENSOR_OPS

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8f44b5b1b9e22..8a7942c8d666a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -329,6 +329,12 @@ static LogicalResult verify(ReleaseOp op) {
   return success();
 }
 
+static LogicalResult verify(OutOp op) {
+  if (getSparseTensorEncoding(op.tensor().getType()))
+    return success(); // The tensor operand carries a sparse encoding.
+  return op.emitError("expected a sparse tensor for output");
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a28f9ac70b318..94e87b3b79b7f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -15,6 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -189,8 +190,8 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
 /// computation.
 static void newParams(ConversionPatternRewriter &rewriter,
                       SmallVector<Value, 8> &params, Operation *op,
-                      SparseTensorEncodingAttr &enc, Action action,
-                      ValueRange szs, Value ptr = Value()) {
+                      ShapedType stp, SparseTensorEncodingAttr &enc,
+                      Action action, ValueRange szs, Value ptr = Value()) {
   Location loc = op->getLoc();
   ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
   unsigned sz = dlt.size();
@@ -218,7 +219,7 @@ static void newParams(ConversionPatternRewriter &rewriter,
   }
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
-  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
+  Type elemTp = stp.getElementType();
   params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
   params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
   params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
@@ -420,9 +421,10 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     // inferred from the result type of the new operator.
     SmallVector<Value, 4> sizes;
     SmallVector<Value, 8> params;
-    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
+    ShapedType stp = resType.cast<ShapedType>();
+    sizesFromType(rewriter, sizes, op.getLoc(), stp);
     Value ptr = adaptor.getOperands()[0];
-    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
+    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -441,7 +443,9 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
     // Generate the call to construct empty tensor. The sizes are
     // explicitly defined by the arguments to the init operator.
     SmallVector<Value, 8> params;
-    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
+    ShapedType stp = resType.cast<ShapedType>();
+    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
+              adaptor.getOperands());
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -472,15 +476,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       }
       SmallVector<Value, 4> sizes;
       SmallVector<Value, 8> params;
-      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
-                   src);
+      ShapedType stp = srcType.cast<ShapedType>();
+      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
       // Set up encoding with right mix of src and dst so that the two
       // method calls can share most parameters, while still providing
       // the correct sparsity information to either of them.
       auto enc = SparseTensorEncodingAttr::get(
           op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
           encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
+      newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
       Value coo = genNewCall(rewriter, op, params);
       params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
       params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
@@ -512,7 +516,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       SmallVector<Value, 4> sizes;
       SmallVector<Value, 8> params;
       sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
-      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
+      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
+                sizes, src);
       Value iter = genNewCall(rewriter, op, params);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
@@ -567,7 +572,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value, 4> sizes;
     SmallVector<Value, 8> params;
     sizesFromSrc(rewriter, sizes, loc, src);
-    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
+    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
     Value ptr = genNewCall(rewriter, op, params);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value perm = params[2];
@@ -771,6 +776,45 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
+/// Lowers sparse_tensor.out into a COO conversion followed by a call to the
+/// runtime's outSparseTensor<suffix> function for the tensor's element type.
+class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ValueRange operands = adaptor.getOperands();
+    ShapedType tensorTp = op.tensor().getType().cast<ShapedType>();
+    auto origEnc = getSparseTensorEncoding(tensorTp);
+    // Step 1: convert the source into a COO object. The encoding passed along
+    // keeps the level types and bit widths but has a null dimension ordering.
+    SmallVector<Value, 4> dimSizes;
+    sizesFromPtr(rewriter, dimSizes, op, origEnc, tensorTp, operands[0]);
+    auto cooEnc = SparseTensorEncodingAttr::get(
+        op->getContext(), origEnc.getDimLevelType(), AffineMap(),
+        origEnc.getPointerBitWidth(), origEnc.getIndexBitWidth());
+    SmallVector<Value, 8> args;
+    newParams(rewriter, args, op, tensorTp, cooEnc, Action::kToCOO, dimSizes,
+              operands[0]);
+    Value coo = genNewCall(rewriter, op, args);
+    // Step 2: write the COO object to the destination. A sort is requested
+    // whenever the source has a non-identity dimension ordering, so that the
+    // indices appear in the externally visible lexicographic order (a format
+    // that does not care about the order could drop the sort altogether).
+    AffineMap order = origEnc.getDimOrdering();
+    bool needsSort = order && !order.isIdentity();
+    Type eltTp = tensorTp.getElementType();
+    SmallString<18> fname{"outSparseTensor", primaryTypeFunctionSuffix(eltTp)};
+    SmallVector<Value, 3> outArgs{coo, operands[1],
+                                  constantI1(rewriter, loc, needsSort)};
+    replaceOpWithFuncCall(rewriter, op, fname, TypeRange(), outArgs,
+                          EmitCInterface::Off);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -787,6 +831,6 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                SparseTensorReleaseConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
                SparseTensorLoadConverter, SparseTensorLexInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter>(
-      typeConverter, patterns.getContext());
+               SparseTensorExpandConverter, SparseTensorCompressConverter,
+               SparseTensorOutConverter>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 20cd1b53d31b4..605e17764773f 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -26,6 +26,8 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
+#include <fstream>
+#include <iostream>
 #include <limits>
 #include <numeric>
 #include <vector>
@@ -713,6 +715,31 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
   return tensor;
 }
 
+/// Writes the sparse tensor to extended FROSTT format.
+template <typename V>
+void outSparseTensor(const SparseTensorCOO<V> &tensor, char *filename) {
+  auto &sizes = tensor.getSizes();
+  auto &elements = tensor.getElements();
+  uint64_t rank = tensor.getRank();
+  uint64_t nnz = elements.size();
+  std::fstream file;
+  file.open(filename, std::ios_base::out | std::ios_base::trunc);
+  assert(file.is_open());
+  file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
+  for (uint64_t r = 0; r < rank - 1; r++)
+    file << sizes[r] << " ";
+  file << sizes[rank - 1] << std::endl;
+  for (uint64_t i = 0; i < nnz; i++) {
+    auto &idx = elements[i].indices;
+    for (uint64_t r = 0; r < rank; r++)
+      file << (idx[r] + 1) << " ";
+    file << elements[i].value << std::endl;
+  }
+  file.flush();
+  file.close();
+  assert(file.good());
+}
+
 } // namespace
 
 extern "C" {
@@ -845,6 +872,17 @@ extern "C" {
         cursor, values, filled, added, count);                                 \
   }
 
+#define IMPL_OUT(NAME, V)                                                      \
+  /* Optionally sorts the COO, writes it to file dest, and frees it. */        \
+  void NAME(void *tensor, void *dest, bool sort) {                             \
+    assert(tensor && dest);                                                    \
+    auto *cooTensor = static_cast<SparseTensorCOO<V> *>(tensor);               \
+    if (sort)                                                                  \
+      cooTensor->sort();                                                       \
+    outSparseTensor<V>(*cooTensor, static_cast<char *>(dest));                 \
+    delete cooTensor;                                                          \
+  }
+
 // Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor
 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
 // that this file cannot get out of sync with its header.
@@ -1026,6 +1064,14 @@ IMPL_EXPINSERT(expInsertI32, int32_t)
 IMPL_EXPINSERT(expInsertI16, int16_t)
 IMPL_EXPINSERT(expInsertI8, int8_t)
 
+/// Helper to output a sparse tensor, one per value type.
+IMPL_OUT(outSparseTensorF64, double)
+IMPL_OUT(outSparseTensorF32, float)
+IMPL_OUT(outSparseTensorI64, int64_t)
+IMPL_OUT(outSparseTensorI32, int32_t)
+IMPL_OUT(outSparseTensorI16, int16_t)
+IMPL_OUT(outSparseTensorI8, int8_t)
+
 #undef CASE
 #undef IMPL_SPARSEVALUES
 #undef IMPL_GETOVERHEAD
@@ -1033,6 +1079,7 @@ IMPL_EXPINSERT(expInsertI8, int8_t)
 #undef IMPL_GETNEXT
 #undef IMPL_LEXINSERT
 #undef IMPL_EXPINSERT
+#undef IMPL_OUT
 
 //===----------------------------------------------------------------------===//
 //
@@ -1162,6 +1209,7 @@ void convertFromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
   *pValues = values;
   *pIndices = indices;
 }
+
 } // extern "C"
 
 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 89ee0d5b7c816..04c8a51817813 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -468,3 +468,27 @@ func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
     : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
 }
+
+// CHECK-LABEL: func @sparse_out1(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
+//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
+//  CHECK-DAG:  %[[C:.*]] = arith.constant false
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: call @outSparseTensorF64(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//       CHECK: return
+func @sparse_out1(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
+  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
+  return
+}
+
+// CHECK-LABEL: func @sparse_out2(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
+//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
+//  CHECK-DAG:  %[[C:.*]] = arith.constant true
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: call @outSparseTensorF32(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//       CHECK: return
+func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8>) {
+  sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 06d662127174c..84990221e4df4 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -204,3 +204,11 @@ func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10xf32,
   %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR>
   return %0 : tensor<10x10xf32, #CSR>
 }
+
+// -----
+
+func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
+  // expected-error at +1 {{expected a sparse tensor for output}}
+  sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 853befc1cdef4..5457e55f57e6a 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -179,3 +179,17 @@ func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
     : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_out(
+//  CHECK-SAME: %[[A:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{{.*}}>>,
+//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
+//       CHECK: sparse_tensor.out %[[A]], %[[B]] : tensor<?x?xf64, #sparse_tensor.encoding<{{.*}}>>, !llvm.ptr<i8>
+//       CHECK: return
+func @sparse_out(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
+  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
+  return
+}


        


More information about the Mlir-commits mailing list