[Mlir-commits] [mlir] b2e6b73 - [mlir][sparse] extend unpack operation to unpack arbitrary encodings.
Peiming Liu
Tue May 23 15:34:06 PDT 2023
Author: Peiming Liu
Date: 2023-05-23T22:34:01Z
New Revision: b2e6b7354452c10ed6f38958253fd76aca0877de
URL: https://github.com/llvm/llvm-project/commit/b2e6b7354452c10ed6f38958253fd76aca0877de
DIFF: https://github.com/llvm/llvm-project/commit/b2e6b7354452c10ed6f38958253fd76aca0877de.diff
LOG: [mlir][sparse] extend unpack operation to unpack arbitrary encodings.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D151174
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 865c1aa38f61f..e37062f5f8104 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -73,7 +73,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
This operation can be used to materialize a sparse tensor from external
sources; e.g., when passing two numpy arrays from Python.
- Disclaimer: This is users' responsibility to provide input that can be
+    Disclaimer: It is the user's responsibility to provide input that can be
correctly interpreted by the sparse compiler, which does not perform
any sanity test during runtime to verify data integrity.
@@ -102,29 +102,25 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
let hasVerifier = 1;
}
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
Arguments<(ins AnySparseTensor:$tensor,
- OptionalAttr<IndexAttr>:$batched_lvls)>,
- Results<(outs TensorOf<[AnyType]>:$values,
- TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
- AnySignlessIntegerOrIndex:$nse)> {
+ TensorOf<[AnyType]>:$out_values,
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
+ Results<(outs TensorOf<[AnyType]>:$ret_values,
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
let description = [{
The unpack operation is the inverse of `sparse_tensor::pack`. It returns
- the values, level-coordinates, and number-of-stored-entries extracted
- from the sparse tensor. The source tensor is allowed (in principle)
- to have non-identity dimOrdering/higherOrdering mappings. Regardless
- of the mappings, the returned `coordinates` are always level-coordinates,
- because this is what we mean by "unpacking" as opposed to other forms
- of exposing sparse tensors to external clients. This operation can be
- used for returning an unpacked MLIR sparse tensor to frontend; e.g.,
- returning two numpy arrays to Python.
+    the values and per-level position and coordinate arrays to the user
+    from the sparse tensor. This operation can be used for returning an
+    unpacked MLIR sparse tensor to a frontend; e.g., returning two numpy arrays to Python.
- TODO: the current implementation does not yet support non-identity mappings.
+    Disclaimer: It is the user's responsibility to allocate buffers large enough
+    to hold the sparse tensor. The sparse compiler simply copies each field
+    of the sparse tensor into the user-supplied buffers without bounds checking.
- This operation ends the lifetime of the sparse tensor, and using
- the tensor after the unpack is undefined behavior.
+ TODO: the current implementation does not yet support non-identity mappings.
Example:
@@ -132,51 +128,18 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
// input COO format |1.1, 0.0, 0.0, 0.0|
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
// |0.0, 0.0, 0.0, 0.0|
- %values, %coordinates, %nse
- = sparse_tensor.unpack %st
- : tensor<3x4xf64, #COO> to tensor<2xf64>, tensor<2x2xindex>, index
+    %values, %pos, %coords = sparse_tensor.unpack %sp : tensor<3x4xf64, #COO>
+ outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
+ -> tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
// %values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
+ // %pos = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
// %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
- // %nse = 3
```
-
- If `batched_lvls` is provided, the operation unpacks each batch of the tensors
- separately. The returned `nse` is the maximum nse of all batches. For a batch with
- a smaller nse, trailing zeros are appended in the result.
- Example:
-
- ```mlir
- // input BCOO format |1.1, 2.2, 3.3, 0.0|
- // of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
- %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
- : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
- // %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ],
- // [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64>
- // %coordinates = arith.constant dense<[[ [0], [1], [2] ],
- // [ [1], [2], [0] ]> : tensor<2x3x1xindex>
- ```
- }];
-
- let extraClassDeclaration = [{
- /// Returns the number of leading levels that are batched.
- unsigned getNumBatchedLvls();
}];
- let builders = [
- OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value": $tensor),
- [{
- build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
- }]>,
- OpBuilder<(ins "TypeRange":$resultTypes, "Value": $tensor),
- [{
- build($_builder, $_state, resultTypes, tensor, nullptr);
- }]>
- ];
-
-
let assemblyFormat =
- "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
- "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";
+ "$tensor `:` type($tensor) `outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)`"
+ "attr-dict `->` type($ret_values) `,` type($ret_levels)";
let hasVerifier = 1;
}
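The outs-based form means a caller must supply one output tensor per storage field of the encoding: the values buffer plus, per level, a positions buffer for compressed levels and a coordinates buffer for levels that store coordinates, with the trailing COO region sharing a single AOS coordinate buffer. As a rough illustration, here is a minimal C++ sketch (not part of this patch; the helper name and includes are assumptions) that estimates that count using only helpers already used elsewhere in this change:

```c++
// Hypothetical helper: estimate how many tensors a caller should pass in
// `out_levels` for a given sparse tensor type. Includes are indicative.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static unsigned countExpectedLevelBuffers(SparseTensorType stt) {
  const Level cooStart = getCOOStart(stt.getEncoding());
  unsigned count = 0;
  for (Level l = 0, rank = stt.getLvlRank(); l < rank; l++) {
    const DimLevelType dlt = stt.getLvlType(l);
    // Compressed levels carry a positions array.
    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt))
      count++;
    // Levels that store coordinates carry a coordinates array, except that
    // the trailing COO levels share one AOS buffer rooted at `cooStart`.
    if (isDLTWithCrd(dlt) && l <= cooStart)
      count++;
  }
  return count;
}
```

For the 2-D COO example in the description this yields two level buffers (one positions array and one AOS coordinates array), matching the `outs` clause shown there.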
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 41805107588a4..0ecc77f228e42 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -786,65 +786,6 @@ static LogicalResult verifySparsifierGetterSetter(
return success();
}
-// DEPRECATED: This function is deprecated! Remove it after unpack supports
-// arbitrary sparse encoding.
-static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
- SparseTensorType tensorTp,
- RankedTensorType valuesTp,
- RankedTensorType coordinatesTp,
- IntegerAttr batchedLvls) {
- unsigned nBatched = batchedLvls ? batchedLvls.getValue().getZExtValue() : 0;
- if (requiresStaticShape && !tensorTp.hasStaticDimShape())
- return op->emitError("the sparse-tensor must have static shape");
- if (!tensorTp.hasEncoding())
- return op->emitError("the sparse-tensor must have an encoding attribute");
- if (!tensorTp.isIdentity())
- return op->emitError("the sparse-tensor must have the identity mapping");
- if (!isCOOType(tensorTp.getEncoding(), nBatched, true))
- return op->emitError("the sparse-tensor must have a COO type");
-
- if (coordinatesTp.getRank() != 2 + nBatched)
- return op->emitError("coordinates must have rank 2 + batched_lvls");
- if (requiresStaticShape && !coordinatesTp.hasStaticShape())
- return op->emitError("coordinates must have static shape");
- if (coordinatesTp.getElementType() != tensorTp.getCrdType())
- return op->emitError("input/output coordinate-types don't match");
-
- if (valuesTp.getRank() != 1 + nBatched)
- return op->emitError("values must have rank 1 + batched_lvls");
- if (requiresStaticShape && !valuesTp.hasStaticShape())
- return op->emitError("values must have static shape");
- if (valuesTp.getElementType() != tensorTp.getElementType())
- return op->emitError("input/output element-types don't match");
-
- for (unsigned i = 0; i < nBatched; i++) {
- const auto valBatch = valuesTp.getShape()[i];
- const auto crdBatch = coordinatesTp.getShape()[i];
- if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) ||
- crdBatch != valBatch) {
- return op->emitError(
- "values/coordinates batched level sizes don't match statically");
- }
- }
-
- const auto valuesNSE = valuesTp.getShape()[nBatched];
- const auto coordsNSE = coordinatesTp.getShape()[nBatched];
- if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) &&
- valuesNSE != coordsNSE)
- return op->emitError("values/coordinates number-of-elements don't match");
-
- // NOTE: We use `getLvlRank` because the `coordinatesTp` is for
- // level-coordinates (cf., the op documentation).
- const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched];
- const Level tensorRank = tensorTp.getLvlRank();
- // FIXME: replace the `operator!=` with our backported `safelyNE`.
- if (!ShapedType::isDynamic(coordsRank) &&
- coordsRank != static_cast<DynSize>(tensorRank) - nBatched)
- return op->emitError("input/output level-ranks don't match");
-
- return success();
-}
-
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
switch (kind) {
case SparseTensorFieldKind::CrdMemRef:
@@ -925,15 +866,17 @@ LogicalResult PackOp::verify() {
}
LogicalResult UnpackOp::verify() {
- const auto valuesTp = getRankedTensorType(getValues());
- const auto coordinatesTp = getRankedTensorType(getCoordinates());
- const auto srcTp = getSparseTensorType(getTensor());
- return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
- getBatchedLvlsAttr());
-}
+ if (getOutValues().getType() != getRetValues().getType())
+ return emitError("output values and return value type mismatch");
-unsigned UnpackOp::getNumBatchedLvls() {
- return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
+ for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
+ if (ot.getType() != rt.getType())
+ return emitError("output levels and return levels type mismatch");
+
+ const auto valuesTp = getRankedTensorType(getRetValues());
+ const auto lvlsTp = getRetLevels().getTypes();
+ const auto srcTp = getSparseTensorType(getTensor());
+ return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}
LogicalResult ConvertOp::verify() {
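The rewritten `UnpackOp::verify` enforces a destination-passing-style contract before delegating to the shared `verifyPackUnPack` check: each `out_*` operand type must equal the corresponding `ret_*` result type, pairwise and in order. A standalone sketch of that pairwise check follows; the helper name is hypothetical and only `llvm::zip_equal` and `mlir::TypeRange` are assumed:

```c++
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/TypeRange.h"

// Hypothetical predicate mirroring the verifier above: every output operand
// type must match the corresponding result type.
static bool typesPairwiseMatch(mlir::TypeRange outs, mlir::TypeRange rets) {
  if (outs.size() != rets.size())
    return false;
  for (auto [o, r] : llvm::zip_equal(outs, rets))
    if (o != r)
      return false;
  return true;
}
```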
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f17c001308bf0..e712c9396466b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -153,28 +153,32 @@ struct UnpackOpInterface
: public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
sparse_tensor::UnpackOp> {
bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
- // We allocate and return unpacked memory if this is a batched unpack.
- // When the number of batched levels equals to zero, we reuse the
- // coordinates/values memref (and reallocation if the requested output size
- // is larger than the actual size). Similar to InsertOp, reallocation is
- // not considered to allocate a new piece of memory.
- return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
+ // The output buffer is pre-allocated by the user.
+ return false;
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- return true;
+ // The first operand is the sparse tensor that we are unpacking.
+ return opOperand.getOperandNumber() == 0;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- return false;
+ // We write into the output operand.
+ assert(op->getNumOperands() == op->getNumResults() + 1);
+ return opOperand.getOperandNumber() > 0;
}
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- // Conceptually, UnpackOp equals to a list of toCoordinates/toValueOp
- return {};
+ assert(op->getNumOperands() == op->getNumResults() + 1);
+
+ if (opOperand.getOperandNumber() == 0)
+ return {};
+    // We write directly into the output tensors and return them.
+ return {{op->getResult(opOperand.getOperandNumber() - 1),
+ BufferRelation::Equivalent}};
}
};
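With user-supplied output buffers, the bufferization model reduces to a destination-passing-style rule: operand 0 (the sparse tensor) is only read, every later operand is written in place, and output operand k aliases result k - 1 as `BufferRelation::Equivalent`. A minimal sketch of that operand-to-result mapping (hypothetical helper, not part of the patch):

```c++
#include <optional>

// Hypothetical helper expressing the aliasing rule above: the sparse tensor
// input (operand 0) has no tied result; output operand k (k >= 1) is written
// in place and is equivalent to result k - 1.
static std::optional<unsigned> getTiedResultIndex(unsigned operandNumber) {
  if (operandNumber == 0)
    return std::nullopt;
  return operandNumber - 1;
}
```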
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4f2e18f43c117..7d4efa8961eb5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -539,48 +539,27 @@ static void genEndInsert(OpBuilder &builder, Location loc,
}
}
-/// Returns a memref that fits the requested length (reallocates if requested
-/// length is larger, or creates a subview if it is smaller).
-static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
- Value buffer) {
- MemRefType memTp = getMemRefType(buffer);
- auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
-
- Value targetLen = constantIndex(builder, loc, len);
- Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
- // Reallocates if target length is greater than the actual buffer len.
- Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
- targetLen, bufferLen);
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
- // If targetLen > bufferLen, reallocate to get enough sparse to return.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
- builder.create<scf::YieldOp>(loc, reallocBuf);
- // Else, return a subview to fit the size.
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value subViewBuf = builder.create<memref::SubViewOp>(
- loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
- /*size=*/ArrayRef<int64_t>{len},
- /*stride=*/ArrayRef<int64_t>{1});
- builder.create<scf::YieldOp>(loc, subViewBuf);
- // Resets insertion point.
- builder.setInsertionPointAfter(ifOp);
- return ifOp.getResult(0);
+static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
+ Value tensor) {
+ auto tTp = tensor.getType().cast<TensorType>();
+ auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
+ return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
+ .getResult();
}
-static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
- ValueRange bounds) {
- assert(ivs.size() == bounds.size());
- Value crd = constantIndex(builder, loc, 0);
- for (unsigned i = 0, e = ivs.size(); i < e; i++) {
- crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
- if (i != ivs.size() - 1)
- crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
- }
- return crd;
+Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
+ auto elemTp = mem.getType().cast<MemRefType>().getElementType();
+ return builder
+ .create<memref::SubViewOp>(
+ loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
+ ValueRange{}, ValueRange{sz}, ValueRange{},
+ ArrayRef<int64_t>{0}, // static offset
+ ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
+ ArrayRef<int64_t>{1}) // static stride
+ .getResult();
}
-ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
ReassociationIndices reassociation;
for (int i = 0, e = srcTp.getRank(); i < e; i++)
reassociation.push_back(i);
@@ -1243,23 +1222,21 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
} else {
// Else simply takes the inputs.
- Value field = fKind == SparseTensorFieldKind::ValMemRef
- ? op.getValues()
- : op.getLevels()[fIdx];
-
- auto tensorType = field.getType().cast<RankedTensorType>();
- auto memrefType = MemRefType::get(tensorType.getShape(),
- tensorType.getElementType());
- field = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, field);
- if (memrefType.getRank() > 1) {
+ Value tensor = fKind == SparseTensorFieldKind::ValMemRef
+ ? op.getValues()
+ : op.getLevels()[fIdx];
+
+ TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
+ if (mem.getType().getRank() > 1) {
// Flattens the buffer to rank 1.
- auto reassoc = getReassociationForFlattening(memrefType);
- field =
- rewriter.create<memref::CollapseShapeOp>(loc, field, reassoc);
+ auto reassoc = getReassociationForFlattening(mem.getType());
+ mem = rewriter.create<memref::CastOp>(
+ loc, fType,
+ rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
+ } else {
+ mem = rewriter.create<memref::CastOp>(loc, fType, mem);
}
- field = rewriter.create<memref::CastOp>(loc, fType, field);
- fields.push_back(field);
+ fields.push_back(mem);
}
return true;
});
@@ -1269,6 +1246,9 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
Value c2 = constantIndex(rewriter, loc, 2);
Value posBack = c1; // index to the last value in the position array
Value memSize = c2; // memory size for current array
+
+ Level trailCOOStart = getCOOStart(stt.getEncoding());
+ Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Sets up SparseTensorSpecifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
@@ -1277,6 +1257,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
// Sets up the level size.
auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
desc.setLvlSize(rewriter, loc, lvl, lvlSize);
+ // We use a single AOS array to store the trailing COO, so there is only
+ // one memory size to set for the entire COO section.
+ if (lvl > trailCOOStart)
+ continue;
// Sets up the memory size by reading the last value in position array.
DimLevelType dlt = stt.getLvlType(lvl);
@@ -1298,8 +1282,15 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
}
- assert(isDLTWithCrd(dlt));
- desc.setCrdMemSize(rewriter, loc, lvl, memSize);
+ assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
+ // FIXME: This seems to be unnecessarily complex, can we simplify it?
+ if (lvl == trailCOOStart) {
+ Value cooSz = rewriter.create<arith::MulIOp>(
+ loc, memSize, constantIndex(rewriter, loc, trailCOORank));
+ desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
+ } else {
+ desc.setCrdMemSize(rewriter, loc, lvl, memSize);
+ }
}
desc.setValMemSize(rewriter, loc, memSize);
@@ -1308,166 +1299,6 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
};
-static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
- SparseTensorDescriptor desc,
- ConversionPatternRewriter &rewriter) {
- Location loc = op.getLoc();
- const auto srcTp = getSparseTensorType(op.getTensor());
- const Level lvlRank = srcTp.getLvlRank();
- Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
- : desc.getAOSMemRef();
- Value valuesBuf = desc.getValMemRef();
-
- // If frontend requests a static buffer, we reallocate the
- // values/coordinates to ensure that we meet their need.
- const auto valuesTp = getRankedTensorType(op.getValues());
- if (valuesTp.hasStaticShape()) {
- // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
- // tensor that is packed from constants.
- valuesBuf =
- reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
- }
-
- const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
- if (coordinatesTp.hasStaticShape()) {
- // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
- // tensor that is packed from constants.
- auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
- flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
- }
-
- Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
- loc,
- MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
- flatBuf, ArrayRef{ReassociationIndices{0, 1}});
-
- // Converts MemRefs back to Tensors.
- Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
- Value coordinates =
- rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
- Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
- op.getNse().getType());
-
- rewriter.replaceOp(op, {values, coordinates, nse});
- return success();
-}
-
-static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
- SparseTensorDescriptor desc,
- ConversionPatternRewriter &rewriter) {
- assert(nBatched != 0);
- Location loc = op.getLoc();
- Value c0 = constantIndex(rewriter, loc, 0);
- Value c1 = constantIndex(rewriter, loc, 1);
- Value c2 = constantIndex(rewriter, loc, 2);
-
- auto genZeroedAlloc = [loc,
- &rewriter](TensorType tt) -> TypedValue<MemRefType> {
- auto mem = rewriter
- .create<memref::AllocOp>(
- loc, MemRefType::get(tt.getShape(), tt.getElementType()))
- .getMemref();
- // TODO: Instead of filling the entire buffer, we can only fill the
- // trailing zeros.
- rewriter.create<linalg::FillOp>(
- loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
- return mem;
- };
- SparseTensorType stt = getSparseTensorType(op.getTensor());
- TensorType valTensorTp = op.getValues().getType();
- TensorType crdTensorTp = op.getCoordinates().getType();
- TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
- TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
- assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
-
- SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
- SmallVector<Value> ubs;
- for (unsigned i = 0; i < nBatched; i++) {
- assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
- ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
- }
-
- DimLevelType dlt = stt.getLvlType(nBatched);
- assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
- Value posStep = isCompressedDLT(dlt) ? c1 // forward position index by 1
- : c2; // forward position index by 2
- auto loopNest = scf::buildLoopNest(
- rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
- [&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
- &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
- ValueRange args) -> scf::ValueVector {
- // crdMemref has shape: <... x nse x rank>
- unsigned unBatchedRank = crdMemref.getType().getShape().back();
- Value values = desc.getValMemRef();
- Value flatCrds = unBatchedRank == 1
- ? desc.getCrdMemRefOrView(builder, loc, 0)
- : desc.getAOSMemRef();
-
- Value positions = desc.getPosMemRef(nBatched);
- Value positLo = builder.create<arith::MulIOp>(
- loc, linearize(builder, loc, ivs, ubs), posStep);
- Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
-
- Value pLo = genIndexLoad(builder, loc, positions, positLo);
- Value pHi = genIndexLoad(builder, loc, positions, positHi);
- Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
-
- Value crdLo = builder.create<arith::MulIOp>(
- loc, pLo, constantIndex(builder, loc, unBatchedRank));
- Value nCrd = builder.create<arith::MulIOp>(
- loc, nse, constantIndex(builder, loc, unBatchedRank));
-
- SmallVector<Value> offsets, sizes, strides;
- for (unsigned i = 0; i < nBatched; i++) {
- offsets.push_back(ivs[i]);
- sizes.push_back(c1);
- strides.push_back(c1);
- }
- // [0, nse, 1].
- offsets.push_back(c0);
- sizes.push_back(nse);
- strides.push_back(c1);
-
- auto valView = builder.create<memref::SubViewOp>(
- loc, valMemref, offsets, sizes, strides);
- auto valReass = getReassociationForFlattening(valView.getType());
- Value valDst =
- builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
- Value valSrc =
- builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
- builder.create<memref::CopyOp>(loc, valSrc, valDst);
-
- // [0, rank, 1].
- offsets.push_back(c0);
- sizes.push_back(constantIndex(builder, loc, unBatchedRank));
- strides.push_back(c1);
-
- auto crdView = builder.create<memref::SubViewOp>(
- loc, crdMemref, offsets, sizes, strides);
- auto crdReass = getReassociationForFlattening(crdView.getType());
- Value crdDst =
- builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
- Value crdSrc =
- builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
- builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
-
- Value pred = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ugt, nse, args[0]);
- // Choose the larger NSE
- return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
- });
-
- // Converts MemRefs back to Tensors.
- Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
- Value coordinates =
- rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
- Value nse =
- genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
-
- rewriter.replaceOp(op, {values, coordinates, nse});
- return success();
-}
-
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
using OpConversionPattern::OpConversionPattern;
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context)
@@ -1477,13 +1308,56 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- const auto srcTp = getSparseTensorType(op.getTensor());
- const unsigned nBatched = op.getNumBatchedLvls();
- assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
- desc.getFields().size() == 4); // specifier + pos + crds + values
- (void)srcTp;
- return nBatched == 0 ? genUnBatchedUnpackOp(op, desc, rewriter)
- : genBatchedUnpackOp(op, nBatched, desc, rewriter);
+ Location loc = op.getLoc();
+ SmallVector<Value> retMem;
+ desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
+ FieldIndex fid,
+ SparseTensorFieldKind fKind, Level lvl,
+ DimLevelType dlt) -> bool {
+ if (fKind == SparseTensorFieldKind::StorageSpec)
+ return true;
+ SparseTensorType stt(desc.getRankedTensorType());
+ Value sz, src;
+ TypedValue<BaseMemRefType> dst;
+ if (fKind == SparseTensorFieldKind::ValMemRef) {
+ sz = desc.getValMemSize(rewriter, loc);
+ src = desc.getValMemRef();
+ dst = genToMemref(rewriter, loc, op.getOutValues());
+            // Values is the last field in the descriptor, but it is the first
+            // operand of the unpack operation.
+            // TODO: maybe change the unpack/pack operations instead to be
+            // consistent.
+ retMem.insert(retMem.begin(), dst);
+ } else {
+ assert(fKind == SparseTensorFieldKind::PosMemRef ||
+ fKind == SparseTensorFieldKind::CrdMemRef);
+
+ sz = fKind == SparseTensorFieldKind::PosMemRef
+ ? desc.getPosMemSize(rewriter, loc, lvl)
+ : desc.getCrdMemSize(rewriter, loc, lvl);
+ src = desc.getMemRefField(fid);
+ dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
+ retMem.push_back(dst);
+ }
+ Value flatOut = dst;
+ if (dst.getType().getRank() != 1) {
+ auto reassoc = getReassociationForFlattening(dst.getType());
+ flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
+ }
+ Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
+ Value srcMem = genSliceToSize(rewriter, loc, src, sz);
+ rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
+ return true;
+ });
+
+ // Converts MemRefs back to Tensors.
+ SmallVector<Value> retTensor = llvm::to_vector(
+ llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
+ return rewriter.create<bufferization::ToTensorOp>(loc, v);
+ }));
+
+ rewriter.replaceOp(op, retTensor);
+ return success();
}
};
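The converter above lowers `unpack` into one copy per storage field: read the actual size from the storage specifier, flatten the user buffer if it has rank > 1, shrink both source and destination to that size with `memref.subview`, and emit `memref.copy`. The fragment below restates that per-field copy in isolation; it is a sketch only, assuming `builder`, `loc`, a destination tensor `dstTensor`, a source memref `src`, and a runtime element count `sz` are in scope, and it reuses only helpers defined in this patch:

```c++
// Per-field copy, as emitted by SparseUnpackOpConverter (sketch).
TypedValue<BaseMemRefType> dst = genToMemref(builder, loc, dstTensor);
Value flat = dst;
if (dst.getType().getRank() != 1) {
  // E.g., flatten a tensor<6x2xi32> coordinate buffer into memref<12xi32>.
  auto reassoc = getReassociationForFlattening(dst.getType());
  flat = builder.create<memref::CollapseShapeOp>(loc, dst, reassoc);
}
// Restrict both sides to the number of stored entries before copying.
Value dstSlice = genSliceToSize(builder, loc, flat, sz);
Value srcSlice = genSliceToSize(builder, loc, src, sz);
builder.create<memref::CopyOp>(loc, srcSlice, dstSlice);
```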
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index 9b18394dee7e2..cf7532a522fa7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -156,6 +156,7 @@ class SparseTensorDescriptorImpl {
RankedTensorType getRankedTensorType() const { return rType; }
ValueArrayRef getFields() const { return fields; }
+ StorageLayout getLayout() const { return layout; }
protected:
SparseTensorType rType;
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3531400f75e81..c1e8afd9206ba 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -56,50 +56,38 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
// -----
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], posWidth=32, crdWidth=32}>
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
- -> (tensor<6xf64>, tensor<6x1xi32>, i32) {
+func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
// expected-error at +1 {{input/output element-types don't match}}
- %values, %coordinates, %nse = sparse_tensor.unpack %sp
- : tensor<100xf32, #SparseVector> to tensor<6xf64>, tensor<6x1xi32>, i32
- return %values, %coordinates, %nse : tensor<6xf64>, tensor<6x1xi32>, i32
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
-
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
- -> (tensor<5xf32>, tensor<6x1xi32>, i32) {
- // expected-error at +1 {{values/coordinates number-of-elements don't match}}
- %values, %coordinates, %nse = sparse_tensor.unpack %sp
- : tensor<100xf32, #SparseVector> to tensor<5xf32>, tensor<6x1xi32>, i32
- return %values, %coordinates, %nse : tensor<5xf32>, tensor<6x1xi32>, i32
+ %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
+ outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
+ -> tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+ return
}
// -----
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed-nu", "singleton"], posWidth=32, crdWidth=32}>
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
- -> (tensor<6xf32>, tensor<6x2xi32>, i32) {
- // expected-error at +1 {{input/output level-ranks don't match}}
- %values, %coordinates, %nse = sparse_tensor.unpack %sp
- : tensor<100xf32, #SparseVector> to tensor<6xf32>, tensor<6x2xi32>, i32
- return %values, %coordinates, %nse : tensor<6xf32>, tensor<6x2xi32>, i32
+func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
+ // expected-error at +1 {{input/output trailing COO level-ranks don't match}}
+ %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
+ outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
+ -> tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>
+ return
}
// -----
-#BCOO = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"], posWidth=32, crdWidth=32}>
-func.func @invalid_unpack_type(%sp: tensor<2x100xf32, #BCOO>)
- -> (tensor<2x6xf32>, tensor<3x6x2xi32>, i32) {
- // expected-error at +1 {{values/coordinates batched level sizes don't match statically}}
- %values, %coordinates, %nse = sparse_tensor.unpack %sp batched_lvls=1
- : tensor<2x100xf32, #BCOO> to tensor<2x6xf32>, tensor<3x6x2xi32>, i32
- return %values, %coordinates, %nse : tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
+ // expected-error at +1 {{inconsistent number of fields between input/output}}
+ %rv, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
+ outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
+ -> tensor<6xf64>, tensor<6xi32>
+ return
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 41cc5e775c98c..57dff1e53edc3 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -33,28 +33,20 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
// CHECK-LABEL: func @sparse_unpack(
// CHECK-SAME: %[[T:.*]]: tensor<100xf64, #
-// CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]]
-// CHECK: return %[[D]], %[[I]], %[[N]]
-func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>)
- -> (tensor<6xf64>, tensor<6x1xi32>, i32) {
- %data, %indices, %nnz = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
- to tensor<6xf64>, tensor<6x1xi32>, i32
- return %data, %indices, %nnz : tensor<6xf64>, tensor<6x1xi32>, i32
-}
-
-// -----
-
-#BatchedSparseVector = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
-
-// CHECK-LABEL: func @sparse_unpack(
-// CHECK-SAME: %[[T:.*]]: tensor<2x100xf64, #
-// CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]] batched_lvls = 1
-// CHECK: return %[[D]], %[[I]], %[[N]]
-func.func @sparse_unpack(%sp : tensor<2x100xf64, #BatchedSparseVector>)
- -> (tensor<2x6xf64>, tensor<2x6x1xi32>, i32) {
- %data, %indices, %nnz = sparse_tensor.unpack %sp batched_lvls=1
- : tensor<2x100xf64, #BatchedSparseVector> to tensor<2x6xf64>, tensor<2x6x1xi32>, i32
- return %data, %indices, %nnz : tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
+// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
+// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
+// CHECK: %[[D:.*]], %[[P:.*]]:2 = sparse_tensor.unpack %[[T]]
+// CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
+func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
+ %od : tensor<6xf64>,
+ %op : tensor<2xindex>,
+ %oi : tensor<6x1xi32>)
+ -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
+ %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
+ outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
+ -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+ return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 1d948cbd604fd..09ba910fc3cfc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -23,9 +23,9 @@
// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_13]]
// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_12]]
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_16]]
-// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] lvl_sz at 1 with %[[VAL_13]]
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] crd_mem_sz at 1 with %[[VAL_16]]
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_13]]
// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_16]]
// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
// CHECK: }
@@ -40,36 +40,38 @@ func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinate
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_3:.*]]
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 6 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
-// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
-// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
-// CHECK: %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK: scf.yield %[[VAL_9]] : memref<6xf64>
-// CHECK: } else {
-// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
-// CHECK: scf.yield %[[VAL_10]] : memref<6xf64>
-// CHECK: }
-// CHECK: %[[VAL_11:.*]] = arith.constant 12 : index
-// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
-// CHECK: %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
-// CHECK: %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
-// CHECK: %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK: scf.yield %[[VAL_15]] : memref<12xi32>
-// CHECK: } else {
-// CHECK: %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
-// CHECK: scf.yield %[[VAL_16]] : memref<12xi32>
-// CHECK: }
-// CHECK: %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
-// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier
-// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ lvlTypes = [ "compressed", "singleton" ] }>>,
+// CHECK-SAME: %[[VAL_4:.*]]: tensor<6xf64>,
+// CHECK-SAME: %[[VAL_5:.*]]: tensor<2xindex>,
+// CHECK-SAME: %[[VAL_6:.*]]: tensor<6x2xi32>) -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] pos_mem_sz at 0
+// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_5]] : memref<2xindex>
+// CHECK: %[[VAL_9:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_7]]] [1] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_7]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK: memref.copy %[[VAL_10]], %[[VAL_9]] : memref<?xindex> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] crd_mem_sz at 0
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_6]] : memref<6x2xi32>
+// CHECK: %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_13]][0] {{\[}}%[[VAL_11]]] [1] : memref<12xi32> to memref<?xi32>
+// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xi32> to memref<?xi32>
+// CHECK: memref.copy %[[VAL_15]], %[[VAL_14]] : memref<?xi32> to memref<?xi32>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
+// CHECK: %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_4]] : memref<6xf64>
+// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_17]][0] {{\[}}%[[VAL_16]]] [1] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_16]]] [1] : memref<?xf64> to memref<?xf64>
+// CHECK: memref.copy %[[VAL_19]], %[[VAL_18]] : memref<?xf64> to memref<?xf64>
+// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
+// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
+// CHECK: %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
+// CHECK: return %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
// CHECK: }
-func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
- %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
- to tensor<6xf64>, tensor<6x2xi32>, index
- return %d, %i, %nnz : tensor<6xf64>, tensor<6x2xi32>, index
+func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
+ %od : tensor<6xf64>,
+ %op : tensor<2xindex>,
+ %oi : tensor<6x2xi32>)
+ -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
+ %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+ outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
+ -> tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
+ return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 4b44436b6da54..4c541a6b61a0f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -179,8 +179,12 @@ module {
vector.print %v: f64
}
- %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
- to tensor<3xf64>, tensor<3x2xi32>, i32
+ %od = tensor.empty() : tensor<3xf64>
+ %op = tensor.empty() : tensor<2xi32>
+ %oi = tensor.empty() : tensor<3x2xi32>
+ %d, %p, %i = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+ outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
+ -> tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -190,30 +194,22 @@ module {
%vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
vector.print %vi : vector<3x2xi32>
- // CHECK-NEXT: 3
- vector.print %n : i32
+ %bod = tensor.empty() : tensor<6xf64>
+ %bop = tensor.empty() : tensor<4xindex>
+ %boi = tensor.empty() : tensor<6x2xindex>
+ %bd, %bp, %bi = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
+ outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
+ -> tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>
- %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls=1 :
- tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32
+ // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
+ %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
+ vector.print %vbd : vector<6xf64>
- // CHECK-NEXT: ( ( 1, 2, 3 ), ( 4, 5, 0 ) )
- %vbd = vector.transfer_read %bd[%c0, %c0], %f0 : tensor<2x3xf64>, vector<2x3xf64>
- vector.print %vbd : vector<2x3xf64>
+ // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
+ %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
+ vector.print %vbi : vector<6x2xindex>
- // CHECK-NEXT: ( ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) ), ( ( 2, 3 ), ( 4, 2 ), ( 0, 0 ) ) )
- %vbi = vector.transfer_read %bi[%c0, %c0, %c0], %c0 : tensor<2x3x2xindex>, vector<2x3x2xindex>
- vector.print %vbi : vector<2x3x2xindex>
-
- // CHECK-NEXT: 3
- vector.print %bn : i32
-
- %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
- to tensor<3xf64>, tensor<3x2xindex>, index
-
- // FIXME: This should be freed by one-shot-bufferization.
- bufferization.dealloc_tensor %bd : tensor<2x3xf64>
- bufferization.dealloc_tensor %bi : tensor<2x3x2xindex>
return
}
}