[Mlir-commits] [mlir] d4db528 - [mlir][sparse] extend unpack operation to support unpacking a batched COO type
Peiming Liu
llvmlistbot at llvm.org
Mon May 1 11:17:36 PDT 2023
Author: Peiming Liu
Date: 2023-05-01T18:17:29Z
New Revision: d4db52893857a836940e0951daa205de1bb1d201
URL: https://github.com/llvm/llvm-project/commit/d4db52893857a836940e0951daa205de1bb1d201
DIFF: https://github.com/llvm/llvm-project/commit/d4db52893857a836940e0951daa205de1bb1d201.diff
LOG: [mlir][sparse] extend unpack operation to support unpacking a batched COO type
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D149103
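In brief, `sparse_tensor.unpack` gains an optional `batched_lvls` attribute that marks the leading levels as batches: each batch is unpacked separately and the returned `nse` is the maximum over all batches. A minimal sketch of the new syntax, mirroring the roundtrip test added below (the `#BatchedSparseVector` encoding and shapes are taken from that test):

```mlir
#BatchedSparseVector = #sparse_tensor.encoding<{
  dimLevelType = ["dense", "compressed-hi"], crdWidth = 32
}>

// Unpack the 2 batches separately; %nse is the maximum nse over all batches.
%data, %indices, %nse = sparse_tensor.unpack %sp batched_lvls=1
    : tensor<2x100xf64, #BatchedSparseVector>
    to tensor<2x6xf64>, tensor<2x6x1xi32>, i32
```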
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_2d.mlir
mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index eea58f91b583c..f29ea600e3347 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -124,9 +124,10 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
}
def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
- Arguments<(ins AnySparseTensor:$tensor)>,
- Results<(outs 1DTensorOf<[AnyType]>:$values,
- 2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+ Arguments<(ins AnySparseTensor:$tensor,
+ OptionalAttr<IndexAttr>:$batched_lvls)>,
+ Results<(outs TensorOf<[AnyType]>:$values,
+ TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
AnySignlessIntegerOrIndex:$nse)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
@@ -159,11 +160,44 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
// %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
// %nse = 3
```
+
+ If `batched_lvls` is provided, the operation unpacks each batch of the tensor
+ separately. The returned `nse` is the maximum nse across all batches. For a batch
+ with a smaller nse, trailing zeros are appended to the result.
+ Example:
+
+ ```mlir
+ // input BCOO format |1.1, 2.2, 3.3, 0.0|
+ // of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
+ %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
+ : tensor<2x4xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x1xindex>, index
+ // %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ],
+ // [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64>
+ // %coordinates = arith.constant dense<[[ [0], [1], [2] ],
+ // [ [1], [2], [0] ]]> : tensor<2x3x1xindex>
+ // %nse = 3
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns the number of leading levels that are batched.
+ unsigned getNumBatchedLvls();
}];
+ let builders = [
+ OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value": $tensor),
+ [{
+ build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value": $tensor),
+ [{
+ build($_builder, $_state, resultTypes, tensor, nullptr);
+ }]>
+ ];
+
+
let assemblyFormat =
- "$tensor attr-dict `:` type($tensor)"
- "`to` type($values) `,` type($coordinates) `,` type($nse)";
+ "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
+ "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b2353016079b7..42776c7d80a32 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -719,7 +719,11 @@ LogicalResult UnpackOp::verify() {
const auto coordinatesTp = getRankedTensorType(getCoordinates());
const auto srcTp = getSparseTensorType(getTensor());
return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
- nullptr);
+ getBatchedLvlsAttr());
+}
+
+unsigned UnpackOp::getNumBatchedLvls() {
+ return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
}
LogicalResult ConvertOp::verify() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 8a8b2eda5175d..f17c001308bf0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -153,9 +153,12 @@ struct UnpackOpInterface
: public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
sparse_tensor::UnpackOp> {
bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
- // Similar to InsertOp, reallocation is not considered to allocate a new
- // piece of memory.
- return false;
+ // We allocate and return unpacked memory if this is a batched unpack.
+ // When the number of batched levels is zero, we reuse the
+ // coordinates/values memrefs (and reallocate them if the requested output
+ // size is larger than the actual size). As with InsertOp, reallocation is
+ // not considered to allocate a new piece of memory.
+ return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 3a488b311b95a..9aae52db873a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -213,6 +213,18 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
+Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
+ Value s) {
+ Value load = builder.create<memref::LoadOp>(loc, mem, s);
+ if (!load.getType().isa<IndexType>()) {
+ if (load.getType().getIntOrFloatBitWidth() < 64)
+ load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+ load =
+ builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+ }
+ return load;
+}
+
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
if (tp.isa<FloatType>())
return builder.getFloatAttr(tp, 1.0);
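For reference, the hoisted `genIndexLoad` emits a plain `memref.load`, zero-extending and index-casting only when the stored type is narrower than `index`. A sketch of the IR it produces for an `i32` positions buffer (SSA names are illustrative):

```mlir
%p   = memref.load %positions[%s] : memref<?xi32>
%p64 = arith.extui %p : i32 to i64
%idx = arith.index_cast %p64 : i64 to index
```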
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b6e6def4e5860..3e1d0b00f825b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -75,6 +75,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index afa4828bf170a..ba6b4641408a5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -41,25 +41,6 @@ using namespace mlir::sparse_tensor;
// File local helper functions.
//===----------------------------------------------------------------------===//
-/// Generates a pointer/index load from the sparse storage scheme. Narrower
-/// data types need to be zero extended before casting the value into the
-/// index type used for looping and indexing.
-static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
- Value s) {
- // For the scalar case, we simply zero extend narrower indices into 64-bit
- // values before casting to index without a performance penalty. Here too,
- // however, indices that already are 64-bit, in theory, cannot express the
- // full range as explained above.
- Value load = builder.create<memref::LoadOp>(loc, mem, s);
- if (!load.getType().isa<IndexType>()) {
- if (load.getType().getIntOrFloatBitWidth() < 64)
- load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
- load =
- builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
- }
- return load;
-}
-
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
@@ -707,7 +688,8 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
continue;
}
- bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType);
+ bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+ isCompressedWithHiDLT(lvlType);
// We can at most have one sparse input, otherwise, a while loop is
// required to co-iterate multiple sparse tensors.
assert(!isSparseCond || !isSparse);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index c1cb0926622f6..4b94392b3d19c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -602,6 +602,25 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
return ifOp.getResult(0);
}
+static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
+ ValueRange bounds) {
+ assert(ivs.size() == bounds.size());
+ Value crd = constantIndex(builder, loc, 0);
+ for (unsigned i = 0, e = ivs.size(); i < e; i++) {
+ crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
+ if (i != ivs.size() - 1)
+ crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
+ }
+ return crd;
+}
+
+ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+ ReassociationIndices reassociation;
+ for (int i = 0, e = srcTp.getRank(); i < e; i++)
+ reassociation.push_back(i);
+ return reassociation;
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -1252,12 +1271,7 @@ static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
[&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
ValueRange ivs) {
// Linearize index variables
- Value crd = constantIndex(builder, loc, 0);
- for (unsigned i = 0, e = ivs.size(); i < e; i++) {
- crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
- if (i != ivs.size() - 1)
- crd = builder.create<arith::MulIOp>(loc, crd, ubs[i + 1]);
- }
+ Value crd = linearize(builder, loc, ivs, ubs);
Value len = constantIndex(builder, loc, nse);
Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
SmallVector<Value> indices(ivs.begin(), ivs.end());
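The loop body above now calls the new `linearize` helper, which folds the batch induction variables into a single row-major index; for two batched levels with bounds `%ub0 x %ub1` it effectively computes `%i0 * %ub1 + %i1`. The emitted IR would look roughly like this (names are illustrative):

```mlir
%zero = arith.constant 0 : index
%t0   = arith.addi %zero, %i0 : index
%t1   = arith.muli %t0, %ub1 : index
%crd  = arith.addi %t1, %i1 : index
```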
@@ -1420,6 +1434,166 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
};
+static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
+ SparseTensorDescriptor desc,
+ ConversionPatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+ const auto srcTp = getSparseTensorType(op.getTensor());
+ const Level lvlRank = srcTp.getLvlRank();
+ Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
+ : desc.getAOSMemRef();
+ Value valuesBuf = desc.getValMemRef();
+
+ // If frontend requests a static buffer, we reallocate the
+ // values/coordinates to ensure that we meet their need.
+ const auto valuesTp = getRankedTensorType(op.getValues());
+ if (valuesTp.hasStaticShape()) {
+ // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+ // tensor that is packed from constants.
+ valuesBuf =
+ reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
+ }
+
+ const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
+ if (coordinatesTp.hasStaticShape()) {
+ // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+ // tensor that is packed from constants.
+ auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
+ flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
+ }
+
+ Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
+ loc,
+ MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
+ flatBuf, ArrayRef{ReassociationIndices{0, 1}});
+
+ // Converts MemRefs back to Tensors.
+ Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
+ Value coordinates =
+ rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
+ Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+ op.getNse().getType());
+
+ rewriter.replaceOp(op, {values, coordinates, nse});
+ return success();
+}
+
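The un-batched path above is essentially the previous lowering factored into a helper: it optionally reallocates the values/AOS-coordinate buffers to the requested static size, expands the flat coordinates back into rank-many columns, and converts the memrefs to tensors. A rough sketch of the tail of that sequence for a 6x2 coordinate result (the `%flat` and `%valbuf` bindings and the shapes are illustrative):

```mlir
%crds2d = memref.expand_shape %flat [[0, 1]] : memref<12xi32> into memref<6x2xi32>
%values = bufferization.to_tensor %valbuf : memref<6xf64>
%coords = bufferization.to_tensor %crds2d : memref<6x2xi32>
```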
+static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
+ SparseTensorDescriptor desc,
+ ConversionPatternRewriter &rewriter) {
+ assert(nBatched != 0);
+ Location loc = op.getLoc();
+ Value c0 = constantIndex(rewriter, loc, 0);
+ Value c1 = constantIndex(rewriter, loc, 1);
+ Value c2 = constantIndex(rewriter, loc, 2);
+
+ auto genZeroedAlloc = [loc,
+ &rewriter](TensorType tt) -> TypedValue<MemRefType> {
+ auto mem = rewriter
+ .create<memref::AllocOp>(
+ loc, MemRefType::get(tt.getShape(), tt.getElementType()))
+ .getMemref();
+ // TODO: Instead of filling the entire buffer, we could fill only the
+ // trailing zeros.
+ rewriter.create<linalg::FillOp>(
+ loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
+ return mem;
+ };
+ SparseTensorType stt = getSparseTensorType(op.getTensor());
+ TensorType valTensorTp = op.getValues().getType();
+ TensorType crdTensorTp = op.getCoordinates().getType();
+ TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
+ TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
+ assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
+
+ SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
+ SmallVector<Value> ubs;
+ for (unsigned i = 0; i < nBatched; i++) {
+ assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+ ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
+ }
+
+ DimLevelType dlt = stt.getLvlType(nBatched);
+ assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
+ Value posStep = isCompressedDLT(dlt) ? c1 // forward position index by 1
+ : c2; // forward position index by 2
+ auto loopNest = scf::buildLoopNest(
+ rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
+ [&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
+ &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
+ ValueRange args) -> scf::ValueVector {
+ // crdMemref has shape: <... x nse x rank>
+ unsigned unBatchedRank = crdMemref.getType().getShape().back();
+ Value values = desc.getValMemRef();
+ Value flatCrds = unBatchedRank == 1
+ ? desc.getCrdMemRefOrView(builder, loc, 0)
+ : desc.getAOSMemRef();
+
+ Value positions = desc.getPosMemRef(nBatched);
+ Value positLo = builder.create<arith::MulIOp>(
+ loc, linearize(builder, loc, ivs, ubs), posStep);
+ Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
+
+ Value pLo = genIndexLoad(builder, loc, positions, positLo);
+ Value pHi = genIndexLoad(builder, loc, positions, positHi);
+ Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
+
+ Value crdLo = builder.create<arith::MulIOp>(
+ loc, pLo, constantIndex(builder, loc, unBatchedRank));
+ Value nCrd = builder.create<arith::MulIOp>(
+ loc, nse, constantIndex(builder, loc, unBatchedRank));
+
+ SmallVector<Value> offsets, sizes, strides;
+ for (unsigned i = 0; i < nBatched; i++) {
+ offsets.push_back(ivs[i]);
+ sizes.push_back(c1);
+ strides.push_back(c1);
+ }
+ // [0, nse, 1].
+ offsets.push_back(c0);
+ sizes.push_back(nse);
+ strides.push_back(c1);
+
+ auto valView = builder.create<memref::SubViewOp>(
+ loc, valMemref, offsets, sizes, strides);
+ auto valReass = getReassociationForFlattening(valView.getType());
+ Value valDst =
+ builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
+ Value valSrc =
+ builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
+ builder.create<memref::CopyOp>(loc, valSrc, valDst);
+
+ // [0, rank, 1].
+ offsets.push_back(c0);
+ sizes.push_back(constantIndex(builder, loc, unBatchedRank));
+ strides.push_back(c1);
+
+ auto crdView = builder.create<memref::SubViewOp>(
+ loc, crdMemref, offsets, sizes, strides);
+ auto crdReass = getReassociationForFlattening(crdView.getType());
+ Value crdDst =
+ builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
+ Value crdSrc =
+ builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
+ builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
+
+ Value pred = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ugt, nse, args[0]);
+ // Choose the larger NSE
+ return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
+ });
+
+ // Converts MemRefs back to Tensors.
+ Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
+ Value coordinates =
+ rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
+ Value nse =
+ genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
+
+ rewriter.replaceOp(op, {values, coordinates, nse});
+ return success();
+}
+
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
using OpConversionPattern::OpConversionPattern;
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
@@ -1431,52 +1605,26 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- Location loc = op.getLoc();
const auto srcTp = getSparseTensorType(op.getTensor());
- const Level lvlRank = srcTp.getLvlRank();
+ const unsigned nBatched = op.getNumBatchedLvls();
+ assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
+ desc.getFields().size() == 4); // specifier + pos + crds + values
+ auto logicRes = nBatched == 0
+ ? genUnBatchedUnpackOp(op, desc, rewriter)
+ : genBatchedUnpackOp(op, nBatched, desc, rewriter);
+ Value posBuf = desc.getPosMemRef(nBatched);
- assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4);
-
- Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
- : desc.getAOSMemRef();
- Value valuesBuf = desc.getValMemRef();
- Value posBuf = desc.getPosMemRef(0);
if (createDeallocs) {
// Unpack ends the lifetime of the sparse tensor. While the value array
// and coordinate array are unpacked and returned, the position array
// becomes useless and needs to be freed (if the user requests deallocation).
- rewriter.create<memref::DeallocOp>(loc, posBuf);
- }
-
- // If frontend requests a static buffer, we reallocate the
- // values/coordinates to ensure that we meet their need.
- const auto valuesTp = getRankedTensorType(op.getValues());
- if (valuesTp.hasStaticShape()) {
- valuesBuf =
- reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
- }
-
- const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
- if (coordinatesTp.hasStaticShape()) {
- auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
- flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
+ // FIXME: Depending on whether the tensor being unpacked is created by
+ // PackOp or not, we may or may not need to free other memref fields of
+ // the sparse tensor too (PackOp borrows value/coordinate buffer).
+ rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
}
- Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
- loc,
- MemRefType::get(coordinatesTp.getShape(),
- coordinatesTp.getElementType()),
- flatBuf, ArrayRef{ReassociationIndices{0, 1}});
-
- // Converts MemRefs back to Tensors.
- Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
- Value coordinates =
- rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
- Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
- op.getNse().getType());
-
- rewriter.replaceOp(op, {values, coordinates, nse});
- return success();
+ return logicRes;
}
private:
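To summarize the batched path, `genBatchedUnpackOp` builds an `scf` loop nest over the batch dimensions, reads the per-batch position range, copies that batch's slice of the values and flattened coordinates into zero-initialized result buffers, and keeps a running maximum of the per-batch nse. A heavily simplified sketch for a single batched `compressed-hi` level, assuming `%positions : memref<?xindex>` and index constants `%c0`/`%c1`/`%c2` are in scope (all names and shapes are illustrative; the subview/copy of values and coordinates is elided into comments):

```mlir
%maxNse = scf.for %b = %c0 to %c2 step %c1 iter_args(%acc = %c0) -> (index) {
  // compressed-hi stores a [lo, hi) position pair per batch, so step by 2.
  %posLo = arith.muli %b, %c2 : index
  %posHi = arith.addi %posLo, %c1 : index
  %pLo = memref.load %positions[%posLo] : memref<?xindex>
  %pHi = memref.load %positions[%posHi] : memref<?xindex>
  %nse = arith.subi %pHi, %pLo : index
  // memref.subview + memref.collapse_shape + memref.copy move this batch's
  // values/coordinates into the pre-zeroed result buffers (elided).
  %gt  = arith.cmpi ugt, %nse, %acc : index
  %new = arith.select %gt, %nse, %acc : index
  scf.yield %new : index
}
```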
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b6f43adaf399c..0766e906c7216 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -128,6 +128,18 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
// -----
+#BCOO = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+
+func.func @invalid_unpack_type(%sp: tensor<2x100xf32, #BCOO>)
+ -> (tensor<2x6xf32>, tensor<3x6x2xi32>, i32) {
+ // expected-error at +1 {{values/coordinates batched level sizes don't match statically}}
+ %values, %coordinates, %nse = sparse_tensor.unpack %sp batched_lvls=1
+ : tensor<2x100xf32, #BCOO> to tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+ return %values, %coordinates, %nse : tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+}
+
+// -----
+
func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
// expected-error at +1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e3e548c993714..3bfa7c2164494 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -59,6 +59,21 @@ func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>)
// -----
+#BatchedSparseVector = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+
+// CHECK-LABEL: func @sparse_unpack(
+// CHECK-SAME: %[[T:.*]]: tensor<2x100xf64, #
+// CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]] batched_lvls = 1
+// CHECK: return %[[D]], %[[I]], %[[N]]
+func.func @sparse_unpack(%sp : tensor<2x100xf64, #BatchedSparseVector>)
+ -> (tensor<2x6xf64>, tensor<2x6x1xi32>, i32) {
+ %data, %indices, %nnz = sparse_tensor.unpack %sp batched_lvls=1
+ : tensor<2x100xf64, #BatchedSparseVector> to tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+ return %data, %indices, %nnz : tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+}
+
+// -----
+
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_dealloc(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 42f2f1c35c5b4..58dc1e49dcf98 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -603,19 +603,19 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
dimLevelType = [ "dense", "compressed-hi" ],
}>
// CHECK-LABEL: func.func @sub_ss_batched(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf64, #{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xf64, #{{.*}}>>) -> tensor<2x3xf64, #{{.*}}>> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #{{.*}}>>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #{{.*}}>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #{{.*}}>> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #{{.*}}>>) {
// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_4]] : index
@@ -628,9 +628,9 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_18]] : index
// CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_22]] : index
// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #{{.*}}>>
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>):
+// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #{{.*}}>>):
// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_30]]] : memref<?xindex>
// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_33]] : index
@@ -638,31 +638,31 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1
-// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #{{.*}}>>) {
// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
// CHECK: %[[VAL_43:.*]] = arith.subf %[[VAL_41]], %[[VAL_42]] : f64
-// CHECK: %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_44]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK: scf.yield %[[VAL_44]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } else {
// CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
-// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #{{.*}}>>) {
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
-// CHECK: %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_48]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK: scf.yield %[[VAL_48]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } else {
// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
-// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #{{.*}}>>) {
// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
// CHECK: %[[VAL_52:.*]] = arith.negf %[[VAL_51]] : f64
-// CHECK: %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_53]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK: scf.yield %[[VAL_53]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_32]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_32]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: }
-// CHECK: scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: }
-// CHECK: scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: }
// CHECK: %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
// CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_30]], %[[VAL_4]] : index
@@ -670,25 +670,25 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] : index
// CHECK: %[[VAL_61:.*]] = arith.select %[[VAL_59]], %[[VAL_60]], %[[VAL_31]] : index
-// CHECK: scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #{{.*}}>>
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_3]] to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_65:.*]] = %[[VAL_66:.*]]#2)
+// CHECK: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_65:.*]]#0 to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_66:.*]] = %[[VAL_65]]#2)
// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_64]]] : memref<?xindex>
// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_64]]] : memref<?xf64>
-// CHECK: %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_65]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_69]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_66]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK: scf.yield %[[VAL_69]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_3]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_72:.*]] = %[[VAL_73:.*]])
-// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref<?xindex>
-// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref<?xf64>
-// CHECK: %[[VAL_76:.*]] = arith.negf %[[VAL_75]] : f64
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.insert %[[VAL_76]] into %[[VAL_72]]{{\[}}%[[VAL_13]], %[[VAL_74]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_77]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_72:.*]]#1 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_73:.*]] = %[[VAL_74:.*]])
+// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref<?xindex>
+// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref<?xf64>
+// CHECK: %[[VAL_77:.*]] = arith.negf %[[VAL_76]] : f64
+// CHECK: %[[VAL_78:.*]] = sparse_tensor.insert %[[VAL_77]] into %[[VAL_73]]{{\[}}%[[VAL_13]], %[[VAL_75]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK: scf.yield %[[VAL_78]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_78:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_79:.*]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: return %[[VAL_79]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_80:.*]] = sparse_tensor.load %[[VAL_81:.*]] hasInserts : tensor<2x3xf64, #{{.*}}>>
+// CHECK: return %[[VAL_80]] : tensor<2x3xf64, #{{.*}}>>
// CHECK: }
func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf64, #BatchedVector>)
-> tensor<2x3xf64, #BatchedVector> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index 57013e7715c43..3d95c86f4aa12 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -145,23 +145,25 @@ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
}>
// CHECK-LABEL: func.func @foreach_bcoo(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #{{.*}}>>) {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #{{.*}}>> to memref<?xf64>
// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK: "test.use"(%[[VAL_12]]) : (f64) -> ()
-// CHECK: }
-// CHECK: }
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// CHECK: "test.use"(%[[VAL_13]]) : (f64) -> ()
+// CHECK: } {"Emitted from" = "sparse_tensor.foreach"}
+// CHECK: } {"Emitted from" = "sparse_tensor.foreach"}
// CHECK: return
+// CHECK: }
func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) {
sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do {
^bb0(%1: index, %2: index, %3: index, %v: f64) :
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4648cb3bf2983..fb0d4a73068d9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -45,7 +45,6 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
// CHECK-SAME: %[[VAL_3:.*]]
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: memref.dealloc %[[VAL_0]] : memref<?xindex>
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
@@ -69,6 +68,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier
+// CHECK: memref.dealloc %[[VAL_0]] : memref<?xindex>
// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
// CHECK: }
func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 2b86d566ec4fd..34f0188a92720 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -31,6 +31,10 @@
crdWidth = 32
}>
+#BCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+}>
+
module {
//
// Main driver.
@@ -60,6 +64,25 @@ module {
%s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
to tensor<10x10xf64, #SortedCOO>
+ %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
+ to tensor<10x10xf64, #SortedCOOI32>
+
+ %bdata = arith.constant dense<
+ [[ 1.0, 2.0, 3.0],
+ [ 4.0, 5.0, 0.0]]
+ > : tensor<2x3xf64>
+
+ %bindex = arith.constant dense<
+ [[[ 1, 2],
+ [ 5, 6],
+ [ 7, 8]],
+ [[ 2, 3],
+ [ 4, 2],
+ [ 10, 10]]]
+ > : tensor<2x3x2xindex>
+ %bs = sparse_tensor.pack %bdata, %bindex batched_lvls = 1 :
+ tensor<2x3xf64>, tensor<2x3x2xindex> to tensor<2x10x10xf64, #BCOO>
+
// CHECK:1
// CHECK-NEXT:2
// CHECK-NEXT:1
@@ -78,8 +101,6 @@ module {
vector.print %v: f64
}
- %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
- to tensor<10x10xf64, #SortedCOOI32>
// CHECK-NEXT:1
// CHECK-NEXT:2
// CHECK-NEXT:1
@@ -98,11 +119,23 @@ module {
vector.print %v: f64
}
+ // CHECK-NEXT:1
+ // CHECK-NEXT:2
+ // CHECK-NEXT:3
+ //
+ // CHECK-NEXT:4
+ // CHECK-NEXT:5
+ //
+ // Make sure the trailing zeros are not traversed.
+ // CHECK-NOT: 0
+ sparse_tensor.foreach in %bs : tensor<2x10x10xf64, #BCOO> do {
+ ^bb0(%0: index, %1: index, %2: index, %v: f64) :
+ vector.print %v: f64
+ }
+
%d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
to tensor<3xf64>, tensor<3x2xi32>, i32
-
-
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
vector.print %vd : vector<3xf64>
@@ -114,8 +147,26 @@ module {
// CHECK-NEXT: 3
vector.print %n : i32
+
+ %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls=1 :
+ tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32
+
+ // CHECK-NEXT: ( ( 1, 2, 3 ), ( 4, 5, 0 ) )
+ %vbd = vector.transfer_read %bd[%c0, %c0], %f0 : tensor<2x3xf64>, vector<2x3xf64>
+ vector.print %vbd : vector<2x3xf64>
+
+ // CHECK-NEXT: ( ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) ), ( ( 2, 3 ), ( 4, 2 ), ( 0, 0 ) ) )
+ %vbi = vector.transfer_read %bi[%c0, %c0, %c0], %c0 : tensor<2x3x2xindex>, vector<2x3x2xindex>
+ vector.print %vbi : vector<2x3x2xindex>
+
+ // CHECK-NEXT: 3
+ vector.print %bn : i32
+
%d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
to tensor<3xf64>, tensor<3x2xindex>, index
+ // FIXME: This should be freed by one-shot-bufferization.
+ bufferization.dealloc_tensor %bd : tensor<2x3xf64>
+ bufferization.dealloc_tensor %bi : tensor<2x3x2xindex>
return
}
}