[Mlir-commits] [mlir] 236a908 - [mlir][sparse] replace support lib conversion with actual MLIR codegen
Aart Bik
llvmlistbot at llvm.org
Mon Aug 23 14:26:13 PDT 2021
Author: Aart Bik
Date: 2021-08-23T14:26:05-07:00
New Revision: 236a90802d5a7f6823685990fe76fd9beec9b4a5
URL: https://github.com/llvm/llvm-project/commit/236a90802d5a7f6823685990fe76fd9beec9b4a5
DIFF: https://github.com/llvm/llvm-project/commit/236a90802d5a7f6823685990fe76fd9beec9b4a5.diff
LOG: [mlir][sparse] replace support lib conversion with actual MLIR codegen
Rationale:
Passing in a pointer to the memref data in order to implement the
dense-to-sparse conversion was a bit too low-level. This revision
improves upon that approach with a cleaner solution: generating, in
MLIR code itself, a loop nest that prepares the COO object before
passing it to our "swiss army knife" setup. This is much more
intuitive *and* now also allows for dynamic shapes.
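To make this concrete, here is a simplified sketch of the IR the conversion
now emits for a 1-D dense-to-sparse conversion (condensed from the updated
conversion.mlir test below; SSA names are illustrative, the definitions of
the index buffer %ind, its cast %cast, and the permutation %perm are omitted,
and the remaining newSparseTensor operands are abbreviated):

  %coo = call @newSparseTensor(...)           // action 2: create an empty COO scheme
  %d   = tensor.dim %arg0, %c0 : tensor<?xf64>
  scf.for %i = %c0 to %d step %c1 {
    %v = tensor.extract %arg0[%i] : tensor<?xf64>
    memref.store %i, %ind[%c0] : memref<1xindex>
    call @addEltF64(%coo, %v, %cast, %perm)   // add one element to the COO scheme
  }
  %st = call @newSparseTensor(..., %coo)      // action 1: pack the COO scheme into storage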
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D108491
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/ExecutionEngine/SparseUtils.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 94d5328e76e8d..4987e5faf0e4c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -223,8 +223,6 @@ static LogicalResult verify(ConvertOp op) {
if (shape1[d] != shape2[d])
return op.emitError()
<< "unexpected conversion mismatch in dimension " << d;
- if (shape1[d] == MemRefType::kDynamicSize)
- return op.emitError("unexpected dynamic size");
}
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 01b180084c21a..9b55b777fbc82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -14,8 +14,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -110,10 +112,11 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation.
-static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
- SparseTensorEncodingAttr &enc, uint32_t action,
- Value ptr) {
+/// support library for materializing sparse tensors into the computation. The
+/// method returns the call value and assigns the permutation to 'perm'.
+static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+ SparseTensorEncodingAttr &enc, uint32_t action,
+ Value &perm, Value ptr = Value()) {
Location loc = op->getLoc();
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
SmallVector<Value, 8> params;
@@ -136,17 +139,16 @@ static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
// Dimension order permutation array. This is the "identity" permutation by
// default, or otherwise the "reverse" permutation of a given ordering, so
// that indices can be mapped quickly to the right position.
- SmallVector<APInt, 4> perm(sz);
- AffineMap p = enc.getDimOrdering();
- if (p) {
- assert(p.isPermutation() && p.getNumResults() == sz);
+ SmallVector<APInt, 4> rev(sz);
+ if (AffineMap p = enc.getDimOrdering()) {
for (unsigned i = 0; i < sz; i++)
- perm[p.getDimPosition(i)] = APInt(64, i);
+ rev[p.getDimPosition(i)] = APInt(64, i);
} else {
for (unsigned i = 0; i < sz; i++)
- perm[i] = APInt(64, i);
+ rev[i] = APInt(64, i);
}
- params.push_back(getTensor(rewriter, 64, loc, perm));
+ perm = getTensor(rewriter, 64, loc, rev);
+ params.push_back(perm);
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
@@ -159,53 +161,54 @@ static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
// User action and pointer.
+ Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+ if (!ptr)
+ ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
params.push_back(ptr);
// Generate the call to create new tensor.
- Type ptrType =
- LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
StringRef name = "newSparseTensor";
- rewriter.replaceOpWithNewOp<CallOp>(
- op, ptrType, getFunc(op, name, ptrType, params), params);
+ auto call =
+ rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
+ return call.getResult(0);
}
-/// Generates a call that exposes the data pointer as a void pointer.
-// TODO: probing the data pointer directly is a bit raw; we should replace
-// this with proper memref util calls once they become available.
-static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op,
- Value val, Value &ptr) {
+/// Generates a call that adds one element to a coordinate scheme.
+static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
+ Value ptr, Value tensor, Value ind, Value perm,
+ ValueRange ivs) {
Location loc = op->getLoc();
- ShapedType sType = op->getResult(0).getType().cast<ShapedType>();
- Type eltType = sType.getElementType();
- // Specialize name for the data type. Even though the final bufferized
- // version only operates on pointers, different names are required to
- // ensure type correctness for all intermediate states.
StringRef name;
+ Type eltType = tensor.getType().cast<ShapedType>().getElementType();
if (eltType.isF64())
- name = "getPtrF64";
+ name = "addEltF64";
else if (eltType.isF32())
- name = "getPtrF32";
+ name = "addEltF32";
else if (eltType.isInteger(64))
- name = "getPtrI64";
+ name = "addEltI64";
else if (eltType.isInteger(32))
- name = "getPtrI32";
+ name = "addEltI32";
else if (eltType.isInteger(16))
- name = "getPtrI16";
+ name = "addEltI16";
else if (eltType.isInteger(8))
- name = "getPtrI8";
+ name = "addEltI8";
else
- return false;
- auto memRefTp = MemRefType::get(sType.getShape(), eltType);
- auto unrankedTp = UnrankedMemRefType::get(eltType, 0);
- Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val);
- Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c);
- Type ptrType =
- LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
- auto call =
- rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d);
- ptr = call.getResult(0);
- return true;
+ llvm_unreachable("Unknown element type");
+ Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
+ // TODO: add if here?
+ unsigned i = 0;
+ for (auto iv : ivs) {
+ Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
+ rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
+ }
+ SmallVector<Value, 8> params;
+ params.push_back(ptr);
+ params.push_back(val);
+ params.push_back(ind);
+ params.push_back(perm);
+ Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+ rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
}
//===----------------------------------------------------------------------===//
@@ -273,7 +276,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
auto enc = getSparseTensorEncoding(resType);
if (!enc)
return failure();
- genNewCall(rewriter, op, enc, 0, operands[0]);
+ Value perm;
+ rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
return success();
}
};
@@ -291,11 +295,46 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// and sparse => dense
if (!encDst || encSrc)
return failure();
- // This is a dense => sparse conversion.
- Value ptr;
- if (!genPtrCall(rewriter, op, operands[0], ptr))
- return failure();
- genNewCall(rewriter, op, encDst, 1, ptr);
+ // This is a dense => sparse conversion, that is handled as follows:
+ // t = newSparseCOO()
+ // for i1 in dim1
+ // ..
+ // for ik in dimk
+ // val = a[i1,..,ik]
+ // if val != 0
+ // t->add(val, [i1,..,ik], [p1,..,pk])
+ // s = newSparseTensor(t)
+ // Note that the dense tensor traversal code is actually implemented
+ // using MLIR IR to avoid having to expose too much low-level
+ // memref traversal details to the runtime support library.
+ Location loc = op->getLoc();
+ ShapedType shape = resType.cast<ShapedType>();
+ auto memTp =
+ MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
+ Value perm;
+ Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
+ Value tensor = operands[0];
+ Value arg = rewriter.create<ConstantOp>(
+ loc, rewriter.getIndexAttr(shape.getRank()));
+ Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+ SmallVector<Value> lo;
+ SmallVector<Value> hi;
+ SmallVector<Value> st;
+ Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+ Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+ for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+ lo.push_back(zero);
+ hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+ st.push_back(one);
+ }
+ scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
+ [&](OpBuilder &builder, Location loc, ValueRange ivs,
+ ValueRange args) -> scf::ValueVector {
+ genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
+ ivs);
+ return {};
+ });
+ rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 379d0185fbf83..6fab920fbcc4c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -99,6 +99,9 @@ struct SparseTensorConversionPass
ConversionTarget target(*ctx);
target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
ToTensorOp>();
+ // All dynamic rules below accept new function, call, return, and dimop
+ // operations as legal output of the rewriting provided that all sparse
+ // tensor types have been fully rewritten.
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -106,8 +109,15 @@ struct SparseTensorConversionPass
});
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
- target.addLegalOp<ConstantOp, tensor::CastOp, memref::BufferCastOp,
- memref::CastOp>();
+ target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
+ return converter.isLegal(op.getOperandTypes());
+ });
+ // The following operations and dialects may be introduced by the
+ // rewriting rules, and are therefore marked as legal.
+ target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
+ target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
+ memref::MemRefDialect>();
+ // Populate with rules and apply rewriting rules.
populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateSparseTensorConversionPatterns(converter, patterns);
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index faa36391c5198..e8f4567f904b8 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -36,7 +36,7 @@
// (a) A coordinate scheme for temporarily storing and lexicographically
// sorting a sparse tensor by index.
//
-// (b) A "one-size-fits-all" sparse storage scheme defined by per-rank
+// (b) A "one-size-fits-all" sparse tensor storage scheme defined by per-rank
// sparse/dense annotations to be used by generated MLIR code.
//
// The following external formats are supported:
@@ -71,7 +71,7 @@ struct Element {
template <typename V>
struct SparseTensor {
public:
- SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity = 0)
+ SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
: sizes(szs), pos(0) {
if (capacity)
elements.reserve(capacity);
@@ -94,6 +94,16 @@ struct SparseTensor {
/// Getter for elements array.
const std::vector<Element<V>> &getElements() const { return elements; }
+ /// Factory method.
+ static SparseTensor<V> *newSparseTensor(uint64_t size, uint64_t *sizes,
+ uint64_t *perm,
+ uint64_t capacity = 0) {
+ std::vector<uint64_t> indices(size);
+ for (uint64_t r = 0; r < size; r++)
+ indices[perm[r]] = sizes[r];
+ return new SparseTensor<V>(indices, capacity);
+ }
+
private:
/// Returns true if indices of e1 < indices of e2.
static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
@@ -155,8 +165,9 @@ class SparseTensorStorageBase {
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
- /// Constructs sparse tensor storage scheme following the given
- /// per-rank dimension dense/sparse annotations.
+ /// Constructs a sparse tensor storage scheme from the given sparse
+ /// tensor in coordinate scheme following the given per-rank dimension
+ /// dense/sparse annotations.
SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
: sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
// Provide hints on capacity.
@@ -192,7 +203,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
}
void getValues(std::vector<V> **out) override { *out = &values; }
- // Factory method.
+ /// Factory method.
static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensor<V> *t,
uint8_t *s) {
t->sort(); // sort lexicographically
@@ -202,10 +213,9 @@ class SparseTensorStorage : public SparseTensorStorageBase {
}
private:
- /// Initializes sparse tensor storage scheme from a memory-resident
- /// representation of an external sparse tensor. This method prepares
- /// the pointers and indices arrays under the given per-rank dimension
- /// dense/sparse annotations.
+ /// Initializes sparse tensor storage scheme from a memory-resident sparse
+ /// tensor in coordinate scheme. This method prepares the pointers and indices
+ /// arrays under the given per-rank dimension dense/sparse annotations.
void traverse(SparseTensor<V> *tensor, uint8_t *sparsity, uint64_t lo,
uint64_t hi, uint64_t d) {
const std::vector<Element<V>> &elements = tensor->getElements();
@@ -355,14 +365,13 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t size,
// and the number of nonzeros as initial capacity.
assert(size == idata[0] && "rank mismatch");
uint64_t nnz = idata[1];
+ for (uint64_t r = 0; r < size; r++)
+ assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
+ "dimension size mismatch");
+ SparseTensor<V> *tensor =
+ SparseTensor<V>::newSparseTensor(size, idata + 2, perm, nnz);
+ // Read all nonzero elements.
std::vector<uint64_t> indices(size);
- for (uint64_t r = 0; r < size; r++) {
- uint64_t sz = idata[2 + r];
- assert((sizes[r] == 0 || sizes[r] == sz) && "dimension size mismatch");
- indices[perm[r]] = sz;
- }
- SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
- // Read all nonzero elements.
for (uint64_t k = 0; k < nnz; k++) {
uint64_t idx = -1;
for (uint64_t r = 0; r < size; r++) {
@@ -387,39 +396,6 @@ static SparseTensor<V> *openTensor(char *filename, uint64_t size,
return tensor;
}
-/// Helper to copy a linearized dense tensor.
-template <typename V>
-static V *copyTensorTraverse(SparseTensor<V> *tensor,
- std::vector<uint64_t> &indices, uint64_t r,
- uint64_t rank, uint64_t *sizes, uint64_t *perm,
- V *data) {
- for (uint64_t i = 0, sz = sizes[r]; i < sz; i++) {
- indices[perm[r]] = i;
- if (r + 1 == rank) {
- V d = *data++;
- if (d)
- tensor->add(indices, d);
- } else {
- data =
- copyTensorTraverse(tensor, indices, r + 1, rank, sizes, perm, data);
- }
- }
- return data;
-}
-
-/// Copies the nonzeros of a linearized dense tensor into a memory-resident
-/// sparse tensor in coordinate scheme.
-template <typename V>
-static SparseTensor<V> *copyTensor(uint64_t size, uint64_t *sizes,
- uint64_t *perm, V *data) {
- std::vector<uint64_t> indices(size);
- for (uint64_t r = 0; r < size; r++)
- indices[perm[r]] = sizes[r];
- SparseTensor<V> *tensor = new SparseTensor<V>(indices);
- copyTensorTraverse<V>(tensor, indices, 0, size, sizes, perm, data);
- return tensor;
-}
-
} // anonymous namespace
extern "C" {
@@ -445,11 +421,6 @@ char *getTensorFilename(uint64_t id) {
//
//===----------------------------------------------------------------------===//
-struct UnrankedMemRef {
- uint64_t rank;
- void *descriptor;
-};
-
#define TEMPLATE(NAME, TYPE) \
struct NAME { \
const TYPE *base; \
@@ -464,8 +435,10 @@ struct UnrankedMemRef {
SparseTensor<V> *tensor; \
if (action == 0) \
tensor = openTensor<V>(static_cast<char *>(ptr), asize, sizes, perm); \
+ else if (action == 1) \
+ tensor = static_cast<SparseTensor<V> *>(ptr); \
else \
- tensor = copyTensor<V>(asize, sizes, perm, static_cast<V *>(ptr)); \
+ return SparseTensor<V>::newSparseTensor(asize, sizes, perm); \
return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity); \
}
@@ -483,8 +456,22 @@ struct UnrankedMemRef {
return {v->data(), v->data(), 0, {v->size()}, {1}}; \
}
-#define PTR(NAME) \
- const void *NAME(int64_t sz, UnrankedMemRef *m) { return m->descriptor; }
+#define IMPL3(NAME, TYPE) \
+ void *NAME(void *tensor, TYPE value, uint64_t *ibase, uint64_t *idata, \
+ uint64_t ioff, uint64_t isize, uint64_t istride, uint64_t *pbase, \
+ uint64_t *pdata, uint64_t poff, uint64_t psize, \
+ uint64_t pstride) { \
+ assert(istride == 1 && pstride == 1 && isize == psize); \
+ uint64_t *indx = idata + ioff; \
+ if (!value) \
+ return tensor; \
+ uint64_t *perm = pdata + poff; \
+ std::vector<uint64_t> indices(isize); \
+ for (uint64_t r = 0; r < isize; r++) \
+ indices[perm[r]] = indx[r]; \
+ static_cast<SparseTensor<TYPE> *>(tensor)->add(indices, value); \
+ return tensor; \
+ }
TEMPLATE(MemRef1DU64, uint64_t);
TEMPLATE(MemRef1DU32, uint32_t);
@@ -510,6 +497,10 @@ enum PrimaryTypeEnum : uint64_t {
/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
+/// action
+/// 0 : ptr contains filename to read into storage
+/// 1 : ptr contains coordinate scheme to assign to storage
+/// 2 : returns coordinate scheme to fill (call back later with 1)
void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
uint64_t asize, uint64_t astride, uint64_t *sbase,
uint64_t *sdata, uint64_t soff, uint64_t ssize,
@@ -518,6 +509,7 @@ void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
uint32_t action, void *ptr) {
assert(astride == 1 && sstride == 1 && pstride == 1);
+ assert(asize == ssize && ssize == psize);
uint8_t *sparsity = adata + aoff;
uint64_t *sizes = sdata + soff;
uint64_t *perm = pdata + poff;
@@ -606,18 +598,19 @@ void delSparseTensor(void *tensor) {
delete static_cast<SparseTensorStorageBase *>(tensor);
}
-/// Helper to get pointer, one per value type.
-PTR(getPtrF64)
-PTR(getPtrF32)
-PTR(getPtrI64)
-PTR(getPtrI32)
-PTR(getPtrI16)
-PTR(getPtrI8)
+/// Helper to add value to coordinate scheme, one per value type.
+IMPL3(addEltF64, double)
+IMPL3(addEltF32, float)
+IMPL3(addEltI64, int64_t)
+IMPL3(addEltI32, int32_t)
+IMPL3(addEltI16, int16_t)
+IMPL3(addEltI8, int8_t)
#undef TEMPLATE
#undef CASE
#undef IMPL1
#undef IMPL2
+#undef IMPL3
} // extern "C"
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 5a2e3b1356720..33ddfa67543a3 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -112,24 +112,93 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
return %0 : tensor<?x?x?xf32, #SparseTensor>
}
-// CHECK-LABEL: func @sparse_convert(
+// CHECK-LABEL: func @sparse_convert_1d(
+// CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[D0:.*]] = constant dense<0> : tensor<1xi64>
+// CHECK-DAG: %[[D1:.*]] = constant dense<1> : tensor<1xi8>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[D1]] : tensor<1xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[D0]] : tensor<1xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
+// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
+// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
+// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Y]])
+// CHECK: }
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
+ %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
+ return %0 : tensor<?xi32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_convert_2d(
// CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
// CHECK-DAG: %[[V:.*]] = constant dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
-// CHECK: %[[C:.*]] = memref.buffer_cast %arg0 : memref<2x4xf64>
-// CHECK: %[[M:.*]] = memref.cast %[[C]] : memref<2x4xf64> to memref<*xf64>
-// CHECK: %[[C:.*]] = call @getPtrF64(%[[M]]) : (memref<*xf64>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]]] : tensor<2x4xf64>
+// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<2xindex>
+// CHECK: memref.store %[[J]], %[[M]][%[[C1]]] : memref<2xindex>
+// CHECK: call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK: }
+// CHECK: }
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_convert(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
+func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
return %0 : tensor<2x4xf64, #SparseMatrix>
}
+// CHECK-LABEL: func @sparse_convert_3d(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<3xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
+// CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
+// CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
+// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]], %[[K]]] : tensor<?x?x?xf64>
+// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<3xindex>
+// CHECK: memref.store %[[J]], %[[M]][%[[C1]]] : memref<3xindex>
+// CHECK: memref.store %[[K]], %[[M]][%[[C2]]] : memref<3xindex>
+// CHECK: call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
+ %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
+ return %0 : tensor<?x?x?xf64, #SparseTensor>
+}
+
// CHECK-LABEL: func @sparse_pointers(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index