[Mlir-commits] [mlir] 9d1db3d - [mlir][sparse] generalize sparse_tensor.convert on static/dynamic dimension sizes
Aart Bik
llvmlistbot at llvm.org
Mon Oct 18 13:54:11 PDT 2021
Author: Aart Bik
Date: 2021-10-18T13:54:03-07:00
New Revision: 9d1db3d4a1970ebb88803fdd862ce1d633b46bdc
URL: https://github.com/llvm/llvm-project/commit/9d1db3d4a1970ebb88803fdd862ce1d633b46bdc
DIFF: https://github.com/llvm/llvm-project/commit/9d1db3d4a1970ebb88803fdd862ce1d633b46bdc.diff
LOG: [mlir][sparse] generalize sparse_tensor.convert on static/dynamic dimension sizes
This revision lifts the artificial restriction on having exact matches between
source and destination type shapes. A static size may become dynamic. We still
reject changing a dynamic size into a static size to avoid the need for a
runtime "assert" on the conversion. This revision also refactors some of the
conversion code to share same-content buffers.
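As a quick illustration of the new rule (a sketch only; `#CSR` and `#SV` stand in for arbitrary sparse encodings, mirroring the updated op documentation below):
```mlir
// Accepted: a static source size may become dynamic in the destination type.
%0 = sparse_tensor.convert %a : tensor<8x8xf64> to tensor<8x?xf64, #CSR>

// Rejected: a dynamic source size may not become static, since checking that
// the actual size is really 100 would require a runtime assert.
// %1 = sparse_tensor.convert %b : tensor<?xf64> to tensor<100xf64, #SV>
```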
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D111915
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c7e6e0a408f2c..d1724b4c6f5cc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -80,9 +80,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
string summary = "Converts between
diff erent tensor types";
string description = [{
Converts one sparse or dense tensor type to another tensor type. The rank
- and dimensions of the source and destination types must match exactly,
- only the sparse encoding of these types may be different. The name `convert`
- was preferred over `cast`, since the operation may incur a non-trivial cost.
+ and dimensions of the source and destination types must match, but the sparse
+ encoding of these types can obviously be different. The name `convert` was
+ preferred over `cast`, since the operation may incur a non-trivial cost.
When converting between two different sparse tensor types, only explicitly
stored values are moved from one underlying sparse storage format to
@@ -97,9 +97,14 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
Examples:
```mlir
- %0 = sparse_tensor.convert %1 : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
-
- %2 = sparse_tensor.convert %3 : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
+ %0 = sparse_tensor.convert %a : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
+ %1 = sparse_tensor.convert %a : tensor<32x32xf32> to tensor<?x?xf32, #CSR>
+ %2 = sparse_tensor.convert %b : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
+ %3 = sparse_tensor.convert %c : tensor<4x8xf64, #CSR> to tensor<4x?xf64, #CSC>
+
+ // The following conversion is not allowed (since it would require a
+ // runtime assertion that the source's dimension size is actually 100).
+ %4 = sparse_tensor.convert %d : tensor<?xf64> to tensor<100xf64, #SV>
```
}];
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8a0e4677d5056..bb499be5052a3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -240,8 +240,11 @@ static LogicalResult verify(ConvertOp op) {
assert(tp1.getRank() == tp2.getRank());
auto shape1 = tp1.getShape();
auto shape2 = tp2.getShape();
+ // Accept size matches between the source and the destination type
+ // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+ // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
- if (shape1[d] != shape2[d])
+ if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
return op.emitError("unexpected conversion mismatch in dimension ")
<< d;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index e98b1fa261731..ffd852fef7333 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -99,7 +99,7 @@ inline static Value constantZero(ConversionPatternRewriter &rewriter,
/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
- Location loc, unsigned i) {
+ Location loc, int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i);
}
@@ -144,6 +144,70 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
return result;
}
+/// Generates dimension size call.
+static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
+ SparseTensorEncodingAttr &enc, Value src,
+ int64_t idx) {
+ // Permute the index according to an optional dimension ordering.
+ if (AffineMap p = enc.getDimOrdering())
+ idx = p.getPermutedPosition(idx);
+ // Generate the call.
+ Location loc = op->getLoc();
+ StringRef name = "sparseDimSize";
+ SmallVector<Value, 2> params;
+ params.push_back(src);
+ params.push_back(constantIndex(rewriter, loc, idx));
+ Type iTp = rewriter.getIndexType();
+ auto fn = getFunc(op, name, iTp, params);
+ return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+}
+
+/// Generates a call into the "swiss army knife" method of the sparse runtime
+/// support library for materializing sparse tensors into the computation.
+static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+ ArrayRef<Value> params) {
+ Location loc = op->getLoc();
+ StringRef name = "newSparseTensor";
+ Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
+ auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
+ auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
+ return call.getResult(0);
+}
+
+/// Populates given sizes array from type.
+static void sizesFromType(ConversionPatternRewriter &rewriter,
+ SmallVector<Value, 4> &sizes, Location loc,
+ ShapedType stp) {
+ auto shape = stp.getShape();
+ for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
+ uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
+ sizes.push_back(constantIndex(rewriter, loc, s));
+ }
+}
+
+/// Populates given sizes array from source.
+static void sizesFromSrc(ConversionPatternRewriter &rewriter,
+ SmallVector<Value, 4> &sizes, Location loc,
+ Value src) {
+ ShapedType stp = src.getType().cast<ShapedType>();
+ for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+ sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
+}
+
+/// Populates given sizes array from type (for static sizes) and from
+/// a source already converted into an opaque pointer (for dynamic sizes).
+static void sizesFromPtr(ConversionPatternRewriter &rewriter,
+ SmallVector<Value, 4> &sizes, Operation *op,
+ SparseTensorEncodingAttr &enc, ShapedType stp,
+ Value src) {
+ auto shape = stp.getShape();
+ for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+ if (shape[i] == ShapedType::kDynamicSize)
+ sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
+ else
+ sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+}
+
/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
unsigned sz, Type tp) {
@@ -152,7 +216,7 @@ static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}
-/// Fills a temporary buffer of the given type with arguments.
+/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> values) {
unsigned sz = values.size();
@@ -165,36 +229,28 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
return buffer;
}
-/// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation. The
-/// method returns the call value and assigns the permutation to 'perm'.
-static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
- SparseTensorEncodingAttr &enc, uint32_t action,
- Value &perm, ValueRange szs, Value ptr = Value()) {
+/// Populates parameters required to call the "swiss army knife" method of the
+/// sparse runtime support library for materializing sparse tensors into the
+/// computation.
+static void newParams(ConversionPatternRewriter &rewriter,
+ SmallVector<Value, 8> &params, Operation *op,
+ SparseTensorEncodingAttr &enc, uint32_t action,
+ ValueRange szs, Value ptr = Value()) {
Location loc = op->getLoc();
- ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
- SmallVector<Value, 8> params;
- // Sparsity annotations in tensor constant form.
- SmallVector<Value, 4> attrs;
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
unsigned sz = dlt.size();
+ // Sparsity annotations.
+ SmallVector<Value, 4> attrs;
for (unsigned i = 0; i < sz; i++)
attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
params.push_back(genBuffer(rewriter, loc, attrs));
- // Dimension sizes array of the enveloping *dense* tensor. Useful for either
+ // Dimension sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
- auto shape = resType.getShape();
+ // The index type is cast to I64 for API consistency.
+ Type iTp = rewriter.getI64Type();
SmallVector<Value, 4> sizes;
- if (szs.size() > 0) {
- for (Value s : szs)
- sizes.push_back(
- rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
- } else {
- for (unsigned i = 0; i < sz; i++) {
- uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
- sizes.push_back(constantI64(rewriter, loc, s));
- }
- }
+ for (Value s : szs)
+ sizes.push_back(rewriter.create<arith::IndexCastOp>(loc, s, iTp));
params.push_back(genBuffer(rewriter, loc, sizes));
// Dimension order permutation array. This is the "identity" permutation by
// default, or otherwise the "reverse" permutation of a given ordering, so
@@ -207,9 +263,9 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
for (unsigned i = 0; i < sz; i++)
rev[i] = constantI64(rewriter, loc, i);
}
- perm = genBuffer(rewriter, loc, rev);
- params.push_back(perm);
+ params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
+ ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
@@ -223,12 +279,6 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
params.push_back(constantI32(rewriter, loc, action));
params.push_back(ptr);
- // Generate the call to create new tensor.
- StringRef name = "newSparseTensor";
- auto call = rewriter.create<CallOp>(
- loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
- params);
- return call.getResult(0);
}
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
@@ -299,9 +349,8 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
params.push_back(ind);
params.push_back(perm);
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- rewriter.create<CallOp>(
- loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
- params);
+ auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
+ rewriter.create<CallOp>(loc, pTp, fn, params);
}
/// If the tensor is a sparse constant, generates and returns the pair of
@@ -362,24 +411,17 @@ class SparseTensorToDimSizeConverter
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resType = op.getType();
+ // Only rewrite annotated DimOp with constant index.
auto enc = getSparseTensorEncoding(op.source().getType());
if (!enc)
return failure();
- // Permute the dim index.
Optional<int64_t> index = op.getConstantIndex();
if (!index.hasValue())
return failure();
- int64_t idx = index.getValue();
- if (AffineMap p = enc.getDimOrdering())
- idx = p.getPermutedPosition(idx);
// Generate the call.
- StringRef name = "sparseDimSize";
- SmallVector<Value, 2> params;
- params.push_back(adaptor.getOperands()[0]);
- params.push_back(constantIndex(rewriter, op.getLoc(), idx));
- rewriter.replaceOpWithNewOp<CallOp>(
- op, resType, getFunc(op, name, resType, params), params);
+ Value src = adaptor.getOperands()[0];
+ int64_t idx = index.getValue();
+ rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
return success();
}
};
@@ -394,9 +436,14 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
auto enc = getSparseTensorEncoding(resType);
if (!enc)
return failure();
- Value perm;
- rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
- adaptor.getOperands()[0]));
+ // Generate the call to construct tensor from ptr. The sizes are
+ // inferred from the result type of the new operator.
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 8> params;
+ sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
+ Value ptr = adaptor.getOperands()[0];
+ newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
+ rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
};
@@ -411,9 +458,11 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
auto enc = getSparseTensorEncoding(resType);
if (!enc)
return failure();
- Value perm;
- rewriter.replaceOp(
- op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
+ // Generate the call to construct empty tensor. The sizes are
+ // explicitly defined by the arguments to the init operator.
+ SmallVector<Value, 8> params;
+ newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
+ rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
};
@@ -424,10 +473,12 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
Type resType = op.getType();
+ Type srcType = op.source().getType();
auto encDst = getSparseTensorEncoding(resType);
- auto encSrc = getSparseTensorEncoding(op.source().getType());
- auto src = adaptor.getOperands()[0];
+ auto encSrc = getSparseTensorEncoding(srcType);
+ Value src = adaptor.getOperands()[0];
if (encDst && encSrc) {
// This is a sparse => sparse conversion, which is handled as follows:
// t = src->toCOO(); ; src to COO in dst order
@@ -435,10 +486,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// Using the coordinate scheme as an intermediate does not always
// yield the fastest conversion but avoids the need for a full
// O(N^2) conversion matrix.
- Value perm;
- Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
- rewriter.replaceOp(
- op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 8> params;
+ sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
+ src);
+ newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
+ Value coo = genNewCall(rewriter, op, params);
+ params[6] = constantI32(rewriter, loc, kFromCOO);
+ params[7] = coo;
+ rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
if (!encDst || encSrc) {
@@ -471,12 +527,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// Also note that the code below only generates the "new" ops and
// the loop-nest per se; whereas the entire body of the innermost
// loop is generated by genAddElt().
- Location loc = op->getLoc();
- ShapedType shape = resType.cast<ShapedType>();
- Value perm;
- Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
- Value ind =
- genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
+ ShapedType stp = resType.cast<ShapedType>();
+ unsigned rank = stp.getRank();
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 8> params;
+ sizesFromSrc(rewriter, sizes, loc, src);
+ newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
+ Value ptr = genNewCall(rewriter, op, params);
+ Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+ Value perm = params[2];
SmallVector<Value> lo;
SmallVector<Value> hi;
SmallVector<Value> st;
@@ -493,14 +552,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
st.push_back(one);
} else {
- for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+ for (unsigned i = 0; i < rank; i++) {
lo.push_back(zero);
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
st.push_back(one);
}
}
- Type eltType = shape.getElementType();
- unsigned rank = shape.getRank();
+ Type eltType = stp.getElementType();
scf::buildLoopNest(
rewriter, op.getLoc(), lo, hi, st, {},
[&](OpBuilder &builder, Location loc, ValueRange ivs,
@@ -514,8 +572,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
return {};
});
- rewriter.replaceOp(
- op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
+ // Final call to construct sparse tensor storage.
+ params[6] = constantI32(rewriter, loc, kFromCOO);
+ params[7] = ptr;
+ rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
};
@@ -529,9 +589,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef name = "delSparseTensor";
TypeRange none;
- rewriter.create<CallOp>(op.getLoc(), none,
- getFunc(op, name, none, adaptor.getOperands()),
- adaptor.getOperands());
+ auto fn = getFunc(op, name, none, adaptor.getOperands());
+ rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
rewriter.eraseOp(op);
return success();
}
@@ -560,11 +619,9 @@ class SparseTensorToPointersConverter
name = "sparsePointers8";
else
return failure();
- rewriter.replaceOpWithNewOp<CallOp>(op, resType,
- getFunc(op, name, resType,
- adaptor.getOperands(),
- /*emitCInterface=*/true),
- adaptor.getOperands());
+ auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
+ rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
return success();
}
};
@@ -591,11 +648,9 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
name = "sparseIndices8";
else
return failure();
- rewriter.replaceOpWithNewOp<CallOp>(op, resType,
- getFunc(op, name, resType,
- adaptor.getOperands(),
- /*emitCInterface=*/true),
- adaptor.getOperands());
+ auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
+ rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
return success();
}
};
@@ -624,11 +679,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
name = "sparseValuesI8";
else
return failure();
- rewriter.replaceOpWithNewOp<CallOp>(op, resType,
- getFunc(op, name, resType,
- adaptor.getOperands(),
- /*emitCInterface=*/true),
- adaptor.getOperands());
+ auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
+ rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
return success();
}
};
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index d6e43079d8c0b..577b79c6e9b0c 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -127,8 +127,8 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
// CHECK-DAG: %[[JJ:.*]] = arith.index_cast %[[J]] : index to i64
// CHECK-DAG: memref.store %[[II]], %[[Q]][%[[C0]]] : memref<2xi64>
// CHECK-DAG: memref.store %[[JJ]], %[[Q]][%[[C1]]] : memref<2xi64>
-// CHECK: %[[A:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+// CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
%0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf64, #SparseMatrix>
@@ -156,22 +156,23 @@ func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32,
// CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
// CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
// CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
// CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
// CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
// CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
-// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
// CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
// CHECK: }
-// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
@@ -180,8 +181,14 @@ func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
// CHECK-LABEL: func @sparse_convert_1d_ss(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-// CHECK: %[[C:.*]] = call @newSparseTensor(%{{.}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
-// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
+// CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
+// CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
+// CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
@@ -198,7 +205,8 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
// CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
-// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
// CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
@@ -209,7 +217,7 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
// CHECK: call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
// CHECK: }
// CHECK: }
-// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
@@ -226,7 +234,8 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
// CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
-// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
// CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -235,7 +244,7 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
// CHECK: %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32>
// CHECK: call @addEltF32(%{{.*}}, %[[V]], %[[N]], %{{.*}})
// CHECK: }
-// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
// Initialize a tensor.
@@ -250,18 +259,19 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
+// CHECK-DAG: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
+// CHECK-DAG: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
// CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<3xi8>
// CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<3xi64>
// CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<3xi64>
// CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<3xi8> to memref<?xi8>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xi64> to memref<?xi64>
// CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xi64> to memref<?xi64>
-// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
// CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
// CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
-// CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
-// CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
@@ -273,7 +283,7 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 03abcb5e82d0b..89553595c8142 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -162,8 +162,8 @@ func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
-func @sparse_convert_mismatch(%arg0: tensor<10x10xf32>) -> tensor<10x?xf32, #CSR> {
+func @sparse_convert_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10xf32, #CSR> {
// expected-error at +1 {{unexpected conversion mismatch in dimension 1}}
- %0 = sparse_tensor.convert %arg0 : tensor<10x10xf32> to tensor<10x?xf32, #CSR>
- return %0 : tensor<10x?xf32, #CSR>
+ %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR>
+ return %0 : tensor<10x10xf32, #CSR>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir
new file mode 100644
index 0000000000000..3bfd0df72d6f2
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s \
+// RUN: --sparsification --sparse-tensor-conversion \
+// RUN: --linalg-bufferize --convert-linalg-to-loops \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#DCSC = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Integration test for conversions between sparse tensors,
+// where the dynamic sizes of the shape of the enveloping tensor
+// may change (the actual underlying sizes obviously never change).
+//
+module {
+
+ //
+ // Helper method to print the values array. The transfer actually
+ // reads more than required in order to verify the buffer size as well.
+ //
+ func @dump(%arg0: memref<?xf64>) {
+ %c = arith.constant 0 : index
+ %d = arith.constant -1.0 : f64
+ %0 = vector.transfer_read %arg0[%c], %d: memref<?xf64>, vector<8xf64>
+ vector.print %0 : vector<8xf64>
+ return
+ }
+
+ func @entry() {
+ %t1 = arith.constant sparse<
+ [ [0,0], [0,1], [0,63], [1,0], [1,1], [31,0], [31,63] ],
+ [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 ]> : tensor<32x64xf64>
+ %t2 = tensor.cast %t1 : tensor<32x64xf64> to tensor<?x?xf64>
+
+ // Four dense to sparse conversions.
+ %1 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<?x?xf64, #DCSR>
+ %2 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<?x?xf64, #DCSC>
+ %3 = sparse_tensor.convert %t2 : tensor<?x?xf64> to tensor<?x?xf64, #DCSR>
+ %4 = sparse_tensor.convert %t2 : tensor<?x?xf64> to tensor<?x?xf64, #DCSC>
+
+ // Two cross conversions.
+ %5 = sparse_tensor.convert %3 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64, #DCSC>
+ %6 = sparse_tensor.convert %4 : tensor<?x?xf64, #DCSC> to tensor<?x?xf64, #DCSR>
+
+ //
+ // All proper row-/column-wise?
+ //
+ // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+ // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+ // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+ // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+ // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+ // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+ //
+ %m1 = sparse_tensor.values %1 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+ %m2 = sparse_tensor.values %2 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+ %m3 = sparse_tensor.values %3 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+ %m4 = sparse_tensor.values %4 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+ %m5 = sparse_tensor.values %5 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+ %m6 = sparse_tensor.values %6 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+ call @dump(%m1) : (memref<?xf64>) -> ()
+ call @dump(%m2) : (memref<?xf64>) -> ()
+ call @dump(%m3) : (memref<?xf64>) -> ()
+ call @dump(%m4) : (memref<?xf64>) -> ()
+ call @dump(%m5) : (memref<?xf64>) -> ()
+ call @dump(%m6) : (memref<?xf64>) -> ()
+
+ // Release the resources.
+ sparse_tensor.release %1 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %2 : tensor<?x?xf64, #DCSC>
+ sparse_tensor.release %3 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %4 : tensor<?x?xf64, #DCSC>
+ sparse_tensor.release %5 : tensor<?x?xf64, #DCSC>
+ sparse_tensor.release %6 : tensor<?x?xf64, #DCSR>
+
+ return
+ }
+}