[Mlir-commits] [mlir] 236a908 - [mlir][sparse] replace support lib conversion with actual MLIR codegen

Aart Bik llvmlistbot at llvm.org
Mon Aug 23 14:26:13 PDT 2021


Author: Aart Bik
Date: 2021-08-23T14:26:05-07:00
New Revision: 236a90802d5a7f6823685990fe76fd9beec9b4a5

URL: https://github.com/llvm/llvm-project/commit/236a90802d5a7f6823685990fe76fd9beec9b4a5
DIFF: https://github.com/llvm/llvm-project/commit/236a90802d5a7f6823685990fe76fd9beec9b4a5.diff

LOG: [mlir][sparse] replace support lib conversion with actual MLIR codegen

Rationale:
Passing a raw pointer to the memref data in order to implement the
dense-to-sparse conversion was a bit too low-level. This revision
replaces that approach with a cleaner solution: a loop nest generated
in MLIR itself prepares the COO object before it is passed to our
"swiss army knife" setup. This is much more intuitive *and* now also
supports dynamic shapes.
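
For illustration, the code generated for a 1-d dense to sparse conversion
now looks roughly as follows. This is a simplified sketch based on the new
cases in conversion.mlir; the function name is made up, and most scalar
operands and types of the runtime calls are elided with "...":

    func @dense_to_sparse_1d(%arg0: tensor<?xi32>) -> !llvm.ptr<i8> {
      %c0 = constant 0 : index
      %c1 = constant 1 : index
      %p = constant dense<0> : tensor<1xi64>
      %perm = tensor.cast %p : tensor<1xi64> to tensor<?xi64>
      // Action 2: ask the support library for an empty COO scheme.
      %coo = call @newSparseTensor(..., %perm, ...)
      %buf = memref.alloca() : memref<1xindex>
      %ind = memref.cast %buf : memref<1xindex> to memref<?xindex>
      %d = tensor.dim %arg0, %c0 : tensor<?xi32>
      scf.for %i = %c0 to %d step %c1 {
        %v = tensor.extract %arg0[%i] : tensor<?xi32>
        memref.store %i, %buf[%c0] : memref<1xindex>
        // The library only adds nonzero values to the COO scheme.
        call @addEltI32(%coo, %v, %ind, %perm)
      }
      // Action 1: convert the filled COO scheme into the final sparse storage.
      %s = call @newSparseTensor(..., %perm, ..., %coo)
      return %s : !llvm.ptr<i8>
    }

Because the loop bounds are obtained with tensor.dim, the same code handles
both static and dynamic dimension sizes.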

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D108491

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 94d5328e76e8d..4987e5faf0e4c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -223,8 +223,6 @@ static LogicalResult verify(ConvertOp op) {
         if (shape1[d] != shape2[d])
           return op.emitError()
                  << "unexpected conversion mismatch in dimension " << d;
-        if (shape1[d] == MemRefType::kDynamicSize)
-          return op.emitError("unexpected dynamic size");
       }
       return success();
     }

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 01b180084c21a..9b55b777fbc82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -14,8 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -110,10 +112,11 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation.
-static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
-                       SparseTensorEncodingAttr &enc, uint32_t action,
-                       Value ptr) {
+/// support library for materializing sparse tensors into the computation. The
+/// method returns the call value and assigns the permutation to 'perm'.
+static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+                        SparseTensorEncodingAttr &enc, uint32_t action,
+                        Value &perm, Value ptr = Value()) {
   Location loc = op->getLoc();
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
   SmallVector<Value, 8> params;
@@ -136,17 +139,16 @@ static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   // Dimension order permutation array. This is the "identity" permutation by
   // default, or otherwise the "reverse" permutation of a given ordering, so
   // that indices can be mapped quickly to the right position.
-  SmallVector<APInt, 4> perm(sz);
-  AffineMap p = enc.getDimOrdering();
-  if (p) {
-    assert(p.isPermutation() && p.getNumResults() == sz);
+  SmallVector<APInt, 4> rev(sz);
+  if (AffineMap p = enc.getDimOrdering()) {
     for (unsigned i = 0; i < sz; i++)
-      perm[p.getDimPosition(i)] = APInt(64, i);
+      rev[p.getDimPosition(i)] = APInt(64, i);
   } else {
     for (unsigned i = 0; i < sz; i++)
-      perm[i] = APInt(64, i);
+      rev[i] = APInt(64, i);
   }
-  params.push_back(getTensor(rewriter, 64, loc, perm));
+  perm = getTensor(rewriter, 64, loc, rev);
+  params.push_back(perm);
   // Secondary and primary types encoding.
   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
@@ -159,53 +161,54 @@ static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   params.push_back(
       rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
   // User action and pointer.
+  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  if (!ptr)
+    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
   params.push_back(
       rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
   params.push_back(ptr);
   // Generate the call to create new tensor.
-  Type ptrType =
-      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
   StringRef name = "newSparseTensor";
-  rewriter.replaceOpWithNewOp<CallOp>(
-      op, ptrType, getFunc(op, name, ptrType, params), params);
+  auto call =
+      rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
+  return call.getResult(0);
 }
 
-/// Generates a call that exposes the data pointer as a void pointer.
-// TODO: probing the data pointer directly is a bit raw; we should replace
-//       this with proper memref util calls once they become available.
-static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op,
-                       Value val, Value &ptr) {
+/// Generates a call that adds one element to a coordinate scheme.
+static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
+                          Value ptr, Value tensor, Value ind, Value perm,
+                          ValueRange ivs) {
   Location loc = op->getLoc();
-  ShapedType sType = op->getResult(0).getType().cast<ShapedType>();
-  Type eltType = sType.getElementType();
-  // Specialize name for the data type. Even though the final buffferized
-  // version only operates on pointers, different names are required to
-  // ensure type correctness for all intermediate states.
   StringRef name;
+  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
-    name = "getPtrF64";
+    name = "addEltF64";
   else if (eltType.isF32())
-    name = "getPtrF32";
+    name = "addEltF32";
   else if (eltType.isInteger(64))
-    name = "getPtrI64";
+    name = "addEltI64";
   else if (eltType.isInteger(32))
-    name = "getPtrI32";
+    name = "addEltI32";
   else if (eltType.isInteger(16))
-    name = "getPtrI16";
+    name = "addEltI16";
   else if (eltType.isInteger(8))
-    name = "getPtrI8";
+    name = "addEltI8";
   else
-    return false;
-  auto memRefTp = MemRefType::get(sType.getShape(), eltType);
-  auto unrankedTp = UnrankedMemRefType::get(eltType, 0);
-  Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val);
-  Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c);
-  Type ptrType =
-      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
-  auto call =
-      rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d);
-  ptr = call.getResult(0);
-  return true;
+    llvm_unreachable("Unknown element type");
+  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
+  // TODO: add if here?
+  unsigned i = 0;
+  for (auto iv : ivs) {
+    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
+    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
+  }
+  SmallVector<Value, 8> params;
+  params.push_back(ptr);
+  params.push_back(val);
+  params.push_back(ind);
+  params.push_back(perm);
+  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +276,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    genNewCall(rewriter, op, enc, 0, operands[0]);
+    Value perm;
+    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
     return success();
   }
 };
@@ -291,11 +295,46 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     //             and sparse => dense
     if (!encDst || encSrc)
       return failure();
-    // This is a dense => sparse conversion.
-    Value ptr;
-    if (!genPtrCall(rewriter, op, operands[0], ptr))
-      return failure();
-    genNewCall(rewriter, op, encDst, 1, ptr);
+    // This is a dense => sparse conversion, that is handled as follows:
+    //   t = newSparseCOO()
+    //   for i1 in dim1
+    //    ..
+    //     for ik in dimk
+    //       val = a[i1,..,ik]
+    //       if val != 0
+    //         t->add(val, [i1,..,ik], [p1,..,pk])
+    //   s = newSparseTensor(t)
+    // Note that the dense tensor traversal code is actually implemented
+    // using MLIR IR to avoid having to expose too much low-level
+    // memref traversal details to the runtime support library.
+    Location loc = op->getLoc();
+    ShapedType shape = resType.cast<ShapedType>();
+    auto memTp =
+        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
+    Value perm;
+    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
+    Value tensor = operands[0];
+    Value arg = rewriter.create<ConstantOp>(
+        loc, rewriter.getIndexAttr(shape.getRank()));
+    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    SmallVector<Value> lo;
+    SmallVector<Value> hi;
+    SmallVector<Value> st;
+    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+      lo.push_back(zero);
+      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+      st.push_back(one);
+    }
+    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
+                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                           ValueRange args) -> scf::ValueVector {
+                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
+                                       ivs);
+                         return {};
+                       });
+    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
     return success();
   }
 };

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 379d0185fbf83..6fab920fbcc4c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -99,6 +99,9 @@ struct SparseTensorConversionPass
     ConversionTarget target(*ctx);
     target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
                         ToTensorOp>();
+    // All dynamic rules below accept new function, call, return, and dimop
+    // operations as legal output of the rewriting provided that all sparse
+    // tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -106,8 +109,15 @@ struct SparseTensorConversionPass
     });
     target.addDynamicallyLegalOp<ReturnOp>(
         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
-    target.addLegalOp<ConstantOp, tensor::CastOp, memref::BufferCastOp,
-                      memref::CastOp>();
+    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // The following operations and dialects may be introduced by the
+    // rewriting rules, and are therefore marked as legal.
+    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
+    target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
+                           memref::MemRefDialect>();
+    // Populate with rules and apply rewriting rules.
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);

diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index faa36391c5198..e8f4567f904b8 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -36,7 +36,7 @@
 // (a) A coordinate scheme for temporarily storing and lexicographically
 //     sorting a sparse tensor by index.
 //
-// (b) A "one-size-fits-all" sparse storage scheme defined by per-rank
+// (b) A "one-size-fits-all" sparse tensor storage scheme defined by per-rank
 //     sparse/dense annnotations to be used by generated MLIR code.
 //
 // The following external formats are supported:
@@ -71,7 +71,7 @@ struct Element {
 template <typename V>
 struct SparseTensor {
 public:
-  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity = 0)
+  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
       : sizes(szs), pos(0) {
     if (capacity)
       elements.reserve(capacity);
@@ -94,6 +94,16 @@ struct SparseTensor {
   /// Getter for elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
+  /// Factory method.
+  static SparseTensor<V> *newSparseTensor(uint64_t size, uint64_t *sizes,
+                                          uint64_t *perm,
+                                          uint64_t capacity = 0) {
+    std::vector<uint64_t> indices(size);
+    for (uint64_t r = 0; r < size; r++)
+      indices[perm[r]] = sizes[r];
+    return new SparseTensor<V>(indices, capacity);
+  }
+
 private:
   /// Returns true if indices of e1 < indices of e2.
   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
@@ -155,8 +165,9 @@ class SparseTensorStorageBase {
 template <typename P, typename I, typename V>
 class SparseTensorStorage : public SparseTensorStorageBase {
 public:
-  /// Constructs sparse tensor storage scheme following the given
-  /// per-rank dimension dense/sparse annotations.
+  /// Constructs a sparse tensor storage scheme from the given sparse
+  /// tensor in coordinate scheme following the given per-rank dimension
+  /// dense/sparse annotations.
   SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
       : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
     // Provide hints on capacity.
@@ -192,7 +203,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   }
   void getValues(std::vector<V> **out) override { *out = &values; }
 
-  // Factory method.
+  /// Factory method.
   static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensor<V> *t,
                                                        uint8_t *s) {
     t->sort(); // sort lexicographically
@@ -202,10 +213,9 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   }
 
 private:
-  /// Initializes sparse tensor storage scheme from a memory-resident
-  /// representation of an external sparse tensor. This method prepares
-  /// the pointers and indices arrays under the given per-rank dimension
-  /// dense/sparse annotations.
+  /// Initializes sparse tensor storage scheme from a memory-resident sparse
+  /// tensor in coordinate scheme. This method prepares the pointers and indices
+  /// arrays under the given per-rank dimension dense/sparse annotations.
   void traverse(SparseTensor<V> *tensor, uint8_t *sparsity, uint64_t lo,
                 uint64_t hi, uint64_t d) {
     const std::vector<Element<V>> &elements = tensor->getElements();
@@ -355,14 +365,13 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t size,
   // and the number of nonzeros as initial capacity.
   assert(size == idata[0] && "rank mismatch");
   uint64_t nnz = idata[1];
+  for (uint64_t r = 0; r < size; r++)
+    assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
+           "dimension size mismatch");
+  SparseTensor<V> *tensor =
+      SparseTensor<V>::newSparseTensor(size, idata + 2, perm, nnz);
+  //  Read all nonzero elements.
   std::vector<uint64_t> indices(size);
-  for (uint64_t r = 0; r < size; r++) {
-    uint64_t sz = idata[2 + r];
-    assert((sizes[r] == 0 || sizes[r] == sz) && "dimension size mismatch");
-    indices[perm[r]] = sz;
-  }
-  SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
-  // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
     uint64_t idx = -1;
     for (uint64_t r = 0; r < size; r++) {
@@ -387,39 +396,6 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t size,
   return tensor;
 }
 
-/// Helper to copy a linearized dense tensor.
-template <typename V>
-static V *copyTensorTraverse(SparseTensor<V> *tensor,
-                             std::vector<uint64_t> &indices, uint64_t r,
-                             uint64_t rank, uint64_t *sizes, uint64_t *perm,
-                             V *data) {
-  for (uint64_t i = 0, sz = sizes[r]; i < sz; i++) {
-    indices[perm[r]] = i;
-    if (r + 1 == rank) {
-      V d = *data++;
-      if (d)
-        tensor->add(indices, d);
-    } else {
-      data =
-          copyTensorTraverse(tensor, indices, r + 1, rank, sizes, perm, data);
-    }
-  }
-  return data;
-}
-
-/// Copies the nonzeros of a linearized dense tensor into a memory-resident
-/// sparse tensor in coordinate scheme.
-template <typename V>
-static SparseTensor<V> *copyTensor(uint64_t size, uint64_t *sizes,
-                                   uint64_t *perm, V *data) {
-  std::vector<uint64_t> indices(size);
-  for (uint64_t r = 0; r < size; r++)
-    indices[perm[r]] = sizes[r];
-  SparseTensor<V> *tensor = new SparseTensor<V>(indices);
-  copyTensorTraverse<V>(tensor, indices, 0, size, sizes, perm, data);
-  return tensor;
-}
-
 } // anonymous namespace
 
 extern "C" {
@@ -445,11 +421,6 @@ char *getTensorFilename(uint64_t id) {
 //
 //===----------------------------------------------------------------------===//
 
-struct UnrankedMemRef {
-  uint64_t rank;
-  void *descriptor;
-};
-
 #define TEMPLATE(NAME, TYPE)                                                   \
   struct NAME {                                                                \
     const TYPE *base;                                                          \
@@ -464,8 +435,10 @@ struct UnrankedMemRef {
     SparseTensor<V> *tensor;                                                   \
     if (action == 0)                                                           \
       tensor = openTensor<V>(static_cast<char *>(ptr), asize, sizes, perm);    \
+    else if (action == 1)                                                      \
+      tensor = static_cast<SparseTensor<V> *>(ptr);                            \
     else                                                                       \
-      tensor = copyTensor<V>(asize, sizes, perm, static_cast<V *>(ptr));       \
+      return SparseTensor<V>::newSparseTensor(asize, sizes, perm);             \
     return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity);    \
   }
 
@@ -483,8 +456,22 @@ struct UnrankedMemRef {
     return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
   }
 
-#define PTR(NAME)                                                              \
-  const void *NAME(int64_t sz, UnrankedMemRef *m) { return m->descriptor; }
+#define IMPL3(NAME, TYPE)                                                      \
+  void *NAME(void *tensor, TYPE value, uint64_t *ibase, uint64_t *idata,       \
+             uint64_t ioff, uint64_t isize, uint64_t istride, uint64_t *pbase, \
+             uint64_t *pdata, uint64_t poff, uint64_t psize,                   \
+             uint64_t pstride) {                                               \
+    assert(istride == 1 && pstride == 1 && isize == psize);                    \
+    uint64_t *indx = idata + ioff;                                             \
+    if (!value)                                                                \
+      return tensor;                                                           \
+    uint64_t *perm = pdata + poff;                                             \
+    std::vector<uint64_t> indices(isize);                                      \
+    for (uint64_t r = 0; r < isize; r++)                                       \
+      indices[perm[r]] = indx[r];                                              \
+    static_cast<SparseTensor<TYPE> *>(tensor)->add(indices, value);            \
+    return tensor;                                                             \
+  }
 
 TEMPLATE(MemRef1DU64, uint64_t);
 TEMPLATE(MemRef1DU32, uint32_t);
@@ -510,6 +497,10 @@ enum PrimaryTypeEnum : uint64_t {
 
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
+///  action
+///    0 : ptr contains filename to read into storage
+///    1 : ptr contains coordinate scheme to assign to storage
+///    2 : returns coordinate scheme to fill (call back later with 1)
 void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
                       uint64_t asize, uint64_t astride, uint64_t *sbase,
                       uint64_t *sdata, uint64_t soff, uint64_t ssize,
@@ -518,6 +509,7 @@ void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
                       uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
                       uint32_t action, void *ptr) {
   assert(astride == 1 && sstride == 1 && pstride == 1);
+  assert(asize == ssize && ssize == psize);
   uint8_t *sparsity = adata + aoff;
   uint64_t *sizes = sdata + soff;
   uint64_t *perm = pdata + poff;
@@ -606,18 +598,19 @@ void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
-/// Helper to get pointer, one per value type.
-PTR(getPtrF64)
-PTR(getPtrF32)
-PTR(getPtrI64)
-PTR(getPtrI32)
-PTR(getPtrI16)
-PTR(getPtrI8)
+/// Helper to add value to coordinate scheme, one per value type.
+IMPL3(addEltF64, double)
+IMPL3(addEltF32, float)
+IMPL3(addEltI64, int64_t)
+IMPL3(addEltI32, int32_t)
+IMPL3(addEltI16, int16_t)
+IMPL3(addEltI8, int8_t)
 
 #undef TEMPLATE
 #undef CASE
 #undef IMPL1
 #undef IMPL2
+#undef IMPL3
 
 } // extern "C"
 

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 5a2e3b1356720..33ddfa67543a3 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -112,24 +112,93 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   return %0 : tensor<?x?x?xf32, #SparseTensor>
 }
 
-// CHECK-LABEL: func @sparse_convert(
+// CHECK-LABEL: func @sparse_convert_1d(
+//  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG: %[[D0:.*]] = constant dense<0> : tensor<1xi64>
+//   CHECK-DAG: %[[D1:.*]] = constant dense<1> : tensor<1xi8>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[D1]] : tensor<1xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[D0]] : tensor<1xi64> to tensor<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
+//       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
+//       CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
+//       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
+//       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
+//       CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
+//       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Y]])
+//       CHECK: }
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
+  return %0 : tensor<?xi32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_convert_2d(
 //  CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = constant 1 : index
 //   CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
 //   CHECK-DAG: %[[V:.*]] = constant dense<[2, 4]> : tensor<2xi64>
 //   CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
-//       CHECK: %[[C:.*]] = memref.buffer_cast %arg0 : memref<2x4xf64>
-//       CHECK: %[[M:.*]] = memref.cast %[[C]] : memref<2x4xf64> to memref<*xf64>
-//       CHECK: %[[C:.*]] = call @getPtrF64(%[[M]]) : (memref<*xf64>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
 //   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
+//       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+//       CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+//       CHECK:   scf.for %[[J:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+//       CHECK:     %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]]] : tensor<2x4xf64>
+//       CHECK:     memref.store %[[I]], %[[M]][%[[C0]]] : memref<2xindex>
+//       CHECK:     memref.store %[[J]], %[[M]][%[[C1]]] : memref<2xindex>
+//       CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:   }
+//       CHECK: }
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_convert(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
+func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
   return %0 : tensor<2x4xf64, #SparseMatrix>
 }
 
+// CHECK-LABEL: func @sparse_convert_3d(
+//  CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = constant 2 : index
+//   CHECK-DAG: %[[U:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+//   CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<3xi64>
+//   CHECK-DAG: %[[W:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
+//       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
+//       CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
+//       CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
+//       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
+//       CHECK:   scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
+//       CHECK:     scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
+//       CHECK:       %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]], %[[K]]] : tensor<?x?x?xf64>
+//       CHECK:       memref.store %[[I]], %[[M]][%[[C0]]] : memref<3xindex>
+//       CHECK:       memref.store %[[J]], %[[M]][%[[C1]]] : memref<3xindex>
+//       CHECK:       memref.store %[[K]], %[[M]][%[[C2]]] : memref<3xindex>
+//       CHECK:       call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK: }
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
+  return %0 : tensor<?x?x?xf64, #SparseTensor>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index


        

