[Mlir-commits] [mlir] 05c7f45 - [mlir][sparse] add dense to sparse conversion implementation

Aart Bik llvmlistbot at llvm.org
Mon Aug 9 12:12:57 PDT 2021


Author: Aart Bik
Date: 2021-08-09T12:12:39-07:00
New Revision: 05c7f450dfce16dc1360eb8ec2fbd1858b8ede6a

URL: https://github.com/llvm/llvm-project/commit/05c7f450dfce16dc1360eb8ec2fbd1858b8ede6a
DIFF: https://github.com/llvm/llvm-project/commit/05c7f450dfce16dc1360eb8ec2fbd1858b8ede6a.diff

LOG: [mlir][sparse] add dense to sparse conversion implementation

Implements lowering of dense-to-sparse conversion, for static tensor types only.
First step towards general sparse_tensor.convert support.
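As a minimal sketch of what this enables (mirroring the new sparse_scale.mlir
integration test below; the function name @dense_to_sparse is illustrative only):

    #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

    // A statically shaped dense tensor is materialized as a CSR sparse tensor;
    // the convert op now lowers to a newSparseTensor() call into the runtime
    // support library, which copies the nonzeros into coordinate scheme.
    func @dense_to_sparse(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32, #CSR> {
      %0 = sparse_tensor.convert %arg0 : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
      return %0 : tensor<8x8xf32, #CSR>
    }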

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D107681

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ae09e3134be0b..94d5328e76e8d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -219,10 +219,13 @@ static LogicalResult verify(ConvertOp op) {
       assert(tp1.getRank() == tp2.getRank());
       auto shape1 = tp1.getShape();
       auto shape2 = tp2.getShape();
-      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
+      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
         if (shape1[d] != shape2[d])
           return op.emitError()
                  << "unexpected conversion mismatch in dimension " << d;
+        if (shape1[d] == MemRefType::kDynamicSize)
+          return op.emitError("unexpected dynamic size");
+      }
       return success();
     }
   }

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 0f3c811f1fe3e..c5f4b07fb4981 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -27,6 +27,10 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
 /// Returns internal type encoding for primary storage. Keep these
 /// values consistent with the sparse runtime support library.
 static unsigned getPrimaryTypeEncoding(Type tp) {
@@ -105,6 +109,109 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
   return SymbolRefAttr::get(context, name);
 }
 
+/// Generates a call into the "swiss army knife" method of the sparse runtime
+/// support library for materializing sparse tensors into the computation.
+static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+                       SparseTensorEncodingAttr &enc, uint32_t action,
+                       Value ptr) {
+  Location loc = op->getLoc();
+  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+  SmallVector<Value, 8> params;
+  // Sparsity annotations in tensor constant form.
+  SmallVector<APInt, 4> attrs;
+  unsigned sz = enc.getDimLevelType().size();
+  for (unsigned i = 0; i < sz; i++)
+    attrs.push_back(
+        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
+  params.push_back(getTensor(rewriter, 8, loc, attrs));
+  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
+  // verification of external data, or for construction of internal data.
+  auto shape = resType.getShape();
+  SmallVector<APInt, 4> sizes;
+  for (unsigned i = 0; i < sz; i++) {
+    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
+    sizes.push_back(APInt(64, s));
+  }
+  params.push_back(getTensor(rewriter, 64, loc, sizes));
+  // Dimension order permutation array. This is the "identity" permutation by
+  // default, or otherwise the "reverse" permutation of a given ordering, so
+  // that indices can be mapped quickly to the right position.
+  SmallVector<APInt, 4> perm(sz);
+  AffineMap p = enc.getDimOrdering();
+  if (p) {
+    assert(p.isPermutation() && p.getNumResults() == sz);
+    for (unsigned i = 0; i < sz; i++)
+      perm[p.getDimPosition(i)] = APInt(64, i);
+  } else {
+    for (unsigned i = 0; i < sz; i++)
+      perm[i] = APInt(64, i);
+  }
+  params.push_back(getTensor(rewriter, 64, loc, perm));
+  // Secondary and primary types encoding.
+  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
+  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
+  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
+  assert(primary);
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
+  // User action and pointer.
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
+  params.push_back(ptr);
+  // Generate the call to create new tensor.
+  Type ptrType =
+      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  StringRef name = "newSparseTensor";
+  rewriter.replaceOpWithNewOp<CallOp>(
+      op, ptrType, getFunc(op, name, ptrType, params), params);
+}
+
+/// Generates a call that exposes the data pointer as a void pointer.
+// TODO: probing the data pointer directly is a bit raw; we should replace
+//       this with proper memref util calls once they become available.
+static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op,
+                       Value val, Value &ptr) {
+  Location loc = op->getLoc();
+  ShapedType sType = op->getResult(0).getType().cast<ShapedType>();
+  Type eltType = sType.getElementType();
+  // Specialize name for the data type. Even though the final bufferized
+  // version only operates on pointers, different names are required to
+  // ensure type correctness for all intermediate states.
+  StringRef name;
+  if (eltType.isF64())
+    name = "getPtrF64";
+  else if (eltType.isF32())
+    name = "getPtrF32";
+  else if (eltType.isInteger(64))
+    name = "getPtrI64";
+  else if (eltType.isInteger(32))
+    name = "getPtrI32";
+  else if (eltType.isInteger(16))
+    name = "getPtrI16";
+  else if (eltType.isInteger(8))
+    name = "getPtrI8";
+  else
+    return false;
+  auto memRefTp = MemRefType::get(sType.getShape(), eltType);
+  auto unrankedTp = UnrankedMemRefType::get(eltType, 0);
+  Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val);
+  Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c);
+  Type ptrType =
+      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  auto call =
+      rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d);
+  ptr = call.getResult(0);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
 /// Sparse conversion rule for returns.
 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
 public:
@@ -141,56 +248,11 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   LogicalResult
   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    MLIRContext *context = op->getContext();
-    SmallVector<Value, 5> params;
-    // Sparse encoding.
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    // User pointer.
-    params.push_back(operands[0]);
-    // Sparsity annotations in tensor constant form.
-    SmallVector<APInt, 4> attrs;
-    unsigned sz = enc.getDimLevelType().size();
-    for (unsigned i = 0; i < sz; i++)
-      attrs.push_back(
-          APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
-    params.push_back(getTensor(rewriter, 8, loc, attrs));
-    // Dimension order permutation array. This is the "identity"
-    // permutation by default, or otherwise the "reverse" permutation
-    // of a given ordering, so that indices can be mapped quickly
-    // to the right position.
-    SmallVector<APInt, 4> perm(sz);
-    AffineMap p = enc.getDimOrdering();
-    if (p) {
-      assert(p.isPermutation() && p.getNumResults() == sz);
-      for (unsigned i = 0; i < sz; i++)
-        perm[p.getDimPosition(i)] = APInt(64, i);
-    } else {
-      for (unsigned i = 0; i < sz; i++)
-        perm[i] = APInt(64, i);
-    }
-    params.push_back(getTensor(rewriter, 64, loc, perm));
-    // Secondary and primary types encoding.
-    unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
-    unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
-    unsigned primary = getPrimaryTypeEncoding(eltType);
-    if (!primary)
-      return failure();
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
-    // Generate the call to create new tensor.
-    Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    StringRef name = "newSparseTensor";
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, ptrType, getFunc(op, name, ptrType, params), params);
+    genNewCall(rewriter, op, enc, 0, operands[0]);
     return success();
   }
 };
@@ -201,8 +263,19 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
   LogicalResult
   matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: implement conversions lowering
-    return failure();
+    Type resType = op.getType();
+    auto encDst = getSparseTensorEncoding(resType);
+    auto encSrc = getSparseTensorEncoding(op.source().getType());
+    // TODO: implement sparse => sparse
+    //             and sparse => dense
+    if (!encDst || encSrc)
+      return failure();
+    // This is a dense => sparse conversion.
+    Value ptr;
+    if (!genPtrCall(rewriter, op, operands[0], ptr))
+      return failure();
+    genNewCall(rewriter, op, encDst, 1, ptr);
+    return success();
   }
 };
 
@@ -325,6 +398,10 @@ class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index a779c6ef2aae5..379d0185fbf83 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -97,7 +97,8 @@ struct SparseTensorConversionPass
     RewritePatternSet patterns(ctx);
     SparseTensorTypeConverter converter;
     ConversionTarget target(*ctx);
-    target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>();
+    target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
+                        ToTensorOp>();
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -105,8 +106,8 @@ struct SparseTensorConversionPass
     });
     target.addDynamicallyLegalOp<ReturnOp>(
         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
-    target.addLegalOp<ConstantOp>();
-    target.addLegalOp<tensor::CastOp>();
+    target.addLegalOp<ConstantOp, tensor::CastOp, memref::BufferCastOp,
+                      memref::CastOp>();
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);

diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index b928842d431b9..faa36391c5198 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -71,14 +71,15 @@ struct Element {
 template <typename V>
 struct SparseTensor {
 public:
-  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
+  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity = 0)
       : sizes(szs), pos(0) {
-    elements.reserve(capacity);
+    if (capacity)
+      elements.reserve(capacity);
   }
   /// Adds element as indices and value.
   void add(const std::vector<uint64_t> &ind, V val) {
     assert(getRank() == ind.size());
-    for (int64_t r = 0, rank = getRank(); r < rank; r++)
+    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
       assert(ind[r] < sizes[r]); // within bounds
     elements.emplace_back(Element<V>(ind, val));
   }
@@ -97,7 +98,7 @@ struct SparseTensor {
   /// Returns true if indices of e1 < indices of e2.
   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
     assert(e1.indices.size() == e2.indices.size());
-    for (int64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
+    for (uint64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
       if (e1.indices[r] == e2.indices[r])
         continue;
       return e1.indices[r] < e2.indices[r];
@@ -332,7 +333,8 @@ static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
 template <typename V>
-static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
+static SparseTensor<V> *openTensor(char *filename, uint64_t size,
+                                   uint64_t *sizes, uint64_t *perm) {
   // Open the file.
   FILE *file = fopen(filename, "r");
   if (!file) {
@@ -351,16 +353,19 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
   }
   // Prepare sparse tensor object with per-rank dimension sizes
   // and the number of nonzeros as initial capacity.
-  uint64_t rank = idata[0];
+  assert(size == idata[0] && "rank mismatch");
   uint64_t nnz = idata[1];
-  std::vector<uint64_t> indices(rank);
-  for (uint64_t r = 0; r < rank; r++)
-    indices[perm[r]] = idata[2 + r];
+  std::vector<uint64_t> indices(size);
+  for (uint64_t r = 0; r < size; r++) {
+    uint64_t sz = idata[2 + r];
+    assert((sizes[r] == 0 || sizes[r] == sz) && "dimension size mismatch");
+    indices[perm[r]] = sz;
+  }
   SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
   // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
     uint64_t idx = -1;
-    for (uint64_t r = 0; r < rank; r++) {
+    for (uint64_t r = 0; r < size; r++) {
       if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
@@ -382,6 +387,39 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
   return tensor;
 }
 
+/// Helper to copy a linearized dense tensor.
+template <typename V>
+static V *copyTensorTraverse(SparseTensor<V> *tensor,
+                             std::vector<uint64_t> &indices, uint64_t r,
+                             uint64_t rank, uint64_t *sizes, uint64_t *perm,
+                             V *data) {
+  for (uint64_t i = 0, sz = sizes[r]; i < sz; i++) {
+    indices[perm[r]] = i;
+    if (r + 1 == rank) {
+      V d = *data++;
+      if (d)
+        tensor->add(indices, d);
+    } else {
+      data =
+          copyTensorTraverse(tensor, indices, r + 1, rank, sizes, perm, data);
+    }
+  }
+  return data;
+}
+
+/// Copies the nonzeros of a linearized dense tensor into a memory-resident
+/// sparse tensor in coordinate scheme.
+template <typename V>
+static SparseTensor<V> *copyTensor(uint64_t size, uint64_t *sizes,
+                                   uint64_t *perm, V *data) {
+  std::vector<uint64_t> indices(size);
+  for (uint64_t r = 0; r < size; r++)
+    indices[perm[r]] = sizes[r];
+  SparseTensor<V> *tensor = new SparseTensor<V>(indices);
+  copyTensorTraverse<V>(tensor, indices, 0, size, sizes, perm, data);
+  return tensor;
+}
+
 } // anonymous namespace
 
 extern "C" {
@@ -407,6 +445,11 @@ char *getTensorFilename(uint64_t id) {
 //
 //===----------------------------------------------------------------------===//
 
+struct UnrankedMemRef {
+  uint64_t rank;
+  void *descriptor;
+};
+
 #define TEMPLATE(NAME, TYPE)                                                   \
   struct NAME {                                                                \
     const TYPE *base;                                                          \
@@ -418,8 +461,11 @@ char *getTensorFilename(uint64_t id) {
 
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
-    SparseTensor<V> *tensor = openTensor<V>(filename, perm);                   \
-    assert(asize == tensor->getRank());                                        \
+    SparseTensor<V> *tensor;                                                   \
+    if (action == 0)                                                           \
+      tensor = openTensor<V>(static_cast<char *>(ptr), asize, sizes, perm);    \
+    else                                                                       \
+      tensor = copyTensor<V>(asize, sizes, perm, static_cast<V *>(ptr));       \
     return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity);    \
   }
 
@@ -437,6 +483,9 @@ char *getTensorFilename(uint64_t id) {
     return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
   }
 
+#define PTR(NAME)                                                              \
+  const void *NAME(int64_t sz, UnrankedMemRef *m) { return m->descriptor; }
+
 TEMPLATE(MemRef1DU64, uint64_t);
 TEMPLATE(MemRef1DU32, uint32_t);
 TEMPLATE(MemRef1DU16, uint16_t);
@@ -459,13 +508,18 @@ enum PrimaryTypeEnum : uint64_t {
   kI8 = 6
 };
 
-void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
-                      uint64_t aoff, uint64_t asize, uint64_t astride,
-                      uint64_t *pbase, uint64_t *pdata, uint64_t poff,
-                      uint64_t psize, uint64_t pstride, uint64_t ptrTp,
-                      uint64_t indTp, uint64_t valTp) {
-  assert(astride == 1 && pstride == 1);
+/// Constructs a new sparse tensor. This is the "swiss army knife"
+/// method for materializing sparse tensors into the computation.
+void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
+                      uint64_t asize, uint64_t astride, uint64_t *sbase,
+                      uint64_t *sdata, uint64_t soff, uint64_t ssize,
+                      uint64_t sstride, uint64_t *pbase, uint64_t *pdata,
+                      uint64_t poff, uint64_t psize, uint64_t pstride,
+                      uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
+                      uint32_t action, void *ptr) {
+  assert(astride == 1 && sstride == 1 && pstride == 1);
   uint8_t *sparsity = adata + aoff;
+  uint64_t *sizes = sdata + soff;
   uint64_t *perm = pdata + poff;
 
   // Double matrices with all combinations of overhead storage.
@@ -524,10 +578,12 @@ void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
   exit(1);
 }
 
+/// Returns size of sparse tensor in given dimension.
 uint64_t sparseDimSize(void *tensor, uint64_t d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
+/// Methods that provide direct access to pointers, indices, and values.
 IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
 IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
 IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
@@ -545,10 +601,19 @@ IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
 IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
 IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
 
+/// Releases sparse tensor storage.
 void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
+/// Helper to get pointer, one per value type.
+PTR(getPtrF64)
+PTR(getPtrF32)
+PTR(getPtrI64)
+PTR(getPtrI32)
+PTR(getPtrI16)
+PTR(getPtrI8)
+
 #undef TEMPLATE
 #undef CASE
 #undef IMPL1

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index f67cf4d69c163..a2a2b4d15ce25 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize | FileCheck %s
 
 #DenseVector = #sparse_tensor.encoding<{
   dimLevelType = ["dense"]
@@ -42,11 +42,13 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
 
 // CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//       CHECK: %[[D:.*]] = constant dense<1> : tensor<1xi8>
-//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi8> to tensor<?xi8>
-//       CHECK: %[[P:.*]] = constant dense<0> : tensor<1xi64>
-//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<1xi64> to tensor<?xi64>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[U:.*]] = constant dense<1> : tensor<1xi8>
+//   CHECK-DAG: %[[V:.*]] = constant dense<128> : tensor<1xi64>
+//   CHECK-DAG: %[[W:.*]] = constant dense<0> : tensor<1xi64>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<1xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<1xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<1xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -55,11 +57,13 @@ func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 
 // CHECK-LABEL: func @sparse_new2d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//       CHECK: %[[D:.*]] = constant dense<[0, 1]> : tensor<2xi8>
-//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi8> to tensor<?xi8>
-//       CHECK: %[[P:.*]] = constant dense<[0, 1]> : tensor<2xi64>
-//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<2xi64> to tensor<?xi64>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+//   CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<2xi64>
+//   CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
@@ -68,17 +72,37 @@ func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
 
 // CHECK-LABEL: func @sparse_new3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//       CHECK: %[[D:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
-//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<3xi8> to tensor<?xi8>
-//       CHECK: %[[P:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
-//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<3xi64> to tensor<?xi64>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[U:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+//   CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<3xi64>
+//   CHECK-DAG: %[[W:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
   return %0 : tensor<?x?x?xf32, #SparseTensor>
 }
 
+// CHECK-LABEL: func @sparse_convert(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+//   CHECK-DAG: %[[V:.*]] = constant dense<[2, 4]> : tensor<2xi64>
+//   CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+//       CHECK: %[[C:.*]] = memref.buffer_cast %arg0 : memref<2x4xf64>
+//       CHECK: %[[M:.*]] = memref.cast %[[C]] : memref<2x4xf64> to memref<*xf64>
+//       CHECK: %[[C:.*]] = call @getPtrF64(%[[M]]) : (memref<*xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
+  return %0 : tensor<2x4xf64, #SparseMatrix>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
new file mode 100644
index 0000000000000..25b0f59647992
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize  \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = X(i,j) * 2"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to actual sparse
+// code, initializes a matching sparse storage scheme from a dense tensor,
+// and runs the resulting code with the JIT compiler.
+//
+module {
+  //
+  // A kernel that scales a sparse matrix A by a factor of 2.0.
+  //
+  func @sparse_scale(%argx: tensor<8x8xf32, #CSR>
+                     {linalg.inplaceable = true}) -> tensor<8x8xf32, #CSR> {
+    %c = constant 2.0 : f32
+    %0 = linalg.generic #trait_scale
+      outs(%argx: tensor<8x8xf32, #CSR>) {
+        ^bb(%x: f32):
+          %1 = mulf %x, %c : f32
+          linalg.yield %1 : f32
+    } -> tensor<8x8xf32, #CSR>
+    return %0 : tensor<8x8xf32, #CSR>
+  }
+
+  //
+  // Main driver that converts a dense tensor into a sparse tensor
+  // and then calls the sparse scaling kernel with the sparse tensor
+  // as input argument.
+  //
+  func @entry() {
+    %c0 = constant 0 : index
+    %f0 = constant 0.0 : f32
+
+    // Initialize a dense tensor.
+    %0 = constant dense<[
+       [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
+       [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0],
+       [0.0, 1.0, 1.0, 0.0, 0.0, 6.0, 0.0, 0.0],
+       [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 7.0, 1.0],
+       [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 8.0]
+    ]> : tensor<8x8xf32>
+
+    // Convert dense tensor to sparse tensor and call sparse kernel.
+    %1 = sparse_tensor.convert %0 : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+    %2 = call @sparse_scale(%1)
+      : (tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+
+    // Print the resulting compacted values for verification.
+    //
+    // CHECK: ( 2, 2, 2, 4, 6, 8, 2, 10, 2, 2, 12, 2, 14, 2, 2, 16 )
+    //
+    %m = sparse_tensor.values %2 : tensor<8x8xf32, #CSR> to memref<?xf32>
+    %v = vector.transfer_read %m[%c0], %f0: memref<?xf32>, vector<16xf32>
+    vector.print %v : vector<16xf32>
+
+    return
+  }
+}

