[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)
Aart Bik
llvmlistbot at llvm.org
Thu Nov 9 16:05:51 PST 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/71880
>From b470eabac2a2aadf93da80a854adb6d810104fca Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 9 Nov 2023 15:55:49 -0800
Subject: [PATCH 1/2] [mlir][sparse] unify support of (dis)assemble between
direct IR/lib path
Note that the (dis)assemble operations still make some simplifying
assumptions (e.g. trailing 2-D COO in AoS format), but now at least
both the direct IR and support library paths behave exactly the same.
Generalizing the ops is still TBD.
---
.../ExecutionEngine/SparseTensor/Storage.h | 65 +++--
.../SparseTensor/Transforms/CodegenUtils.cpp | 11 +
.../SparseTensor/Transforms/CodegenUtils.h | 4 +
.../Transforms/SparseTensorCodegen.cpp | 13 -
.../Transforms/SparseTensorConversion.cpp | 222 +++++++++++++-----
.../Transforms/SparseTensorPasses.cpp | 3 +-
.../Dialect/SparseTensor/CPU/sparse_pack.mlir | 116 ++++++---
.../SparseTensor/CPU/sparse_pack_libgen.mlir | 165 -------------
8 files changed, 304 insertions(+), 295 deletions(-)
delete mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 460549726356370..3382e293d123746 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -301,8 +301,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t lvlRank = getLvlRank();
uint64_t valIdx = 0;
// Linearize the address.
- for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
- valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+ for (uint64_t l = 0; l < lvlRank; l++)
+ valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
values[valIdx] = val;
return;
}
@@ -472,9 +472,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
if (isCompressedLvl(l))
return positions[l][parentSz];
- if (isSingletonLvl(l))
- return parentSz; // New size is same as the parent.
- // TODO: support levels assignment for loose/2:4?
+ if (isLooseCompressedLvl(l))
+ return positions[l][2 * parentSz - 1];
+ if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
}
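
For illustration only, here is a minimal standalone C++ sketch (hypothetical
names, plain std::vector instead of the real storage members) of the two
position layouts that assembledSize() distinguishes above: a compressed level
carries parentSz + 1 positions whose last entry bounds the coordinates, while
a loose compressed level carries a (lo, hi) pair per parent, so the final hi
entry at index 2 * parentSz - 1 plays that role.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Compressed level: parentSz + 1 positions; the last entry is the number
  // of coordinates stored at this level.
  uint64_t compressedAssembledSize(const std::vector<uint64_t> &pos,
                                   uint64_t parentSz) {
    assert(pos.size() == parentSz + 1);
    return pos[parentSz];
  }

  // Loose compressed level: a (lo, hi) pair per parent, i.e. 2 * parentSz
  // positions; the final hi entry bounds the coordinate buffer.
  uint64_t looseCompressedAssembledSize(const std::vector<uint64_t> &pos,
                                        uint64_t parentSz) {
    assert(pos.size() == 2 * parentSz);
    return pos[2 * parentSz - 1];
  }
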
@@ -766,40 +767,59 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
dim2lvl, lvl2dim) {
+  // Note that none of the buffers can be reused because ownership
+ // of the memory passed from clients is not necessarily transferred.
+ // Therefore, all data is copied over into a new SparseTensorStorage.
+ //
+ // TODO: this needs to be generalized to all formats AND
+ // we need a proper audit of e.g. double compressed
+ // levels where some are not filled
+ //
uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
for (uint64_t l = 0; l < lvlRank; l++) {
- if (!isUniqueLvl(l) && isCompressedLvl(l)) {
- // A `compressed_nu` level marks the start of trailing COO start level.
- // Since the coordinate buffer used for trailing COO are passed in as AoS
- // scheme, and SparseTensorStorage uses a SoA scheme, we can not simply
- // copy the value from the provided buffers.
+ if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
+      // A `(loose)compressed_nu` level marks the start of the trailing
+      // COO levels. Since the coordinate buffer used for trailing COO
+      // is passed in as an AoS scheme while SparseTensorStorage uses a
+      // SoA scheme, we cannot simply copy the values from the provided
+      // buffers.
trailCOOLen = lvlRank - l;
break;
}
- assert(!isSingletonLvl(l) &&
- "Singleton level not following a compressed_nu level");
- if (isCompressedLvl(l)) {
+ if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- // Copies the lvlBuf into the vectors. The buffer can not be simply reused
- // because the memory passed from users is not necessarily allocated on
- // heap.
- positions[l].assign(posPtr, posPtr + parentSz + 1);
- coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ if (!isLooseCompressedLvl(l)) {
+ positions[l].assign(posPtr, posPtr + parentSz + 1);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ } else {
+ positions[l].assign(posPtr, posPtr + 2 * parentSz);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
+ }
+ } else if (isSingletonLvl(l)) {
+ assert(0 && "general singleton not supported yet");
+ } else if (is2OutOf4Lvl(l)) {
+ assert(0 && "2Out4 not supported yet");
} else {
- // TODO: support levels assignment for loose/2:4?
assert(isDenseLvl(l));
}
parentSz = assembledSize(parentSz, l);
}
+  // Handle AoS vs. SoA mismatch for COO.
if (trailCOOLen != 0) {
uint64_t cooStartLvl = lvlRank - trailCOOLen;
- assert(!isUniqueLvl(cooStartLvl) && isCompressedLvl(cooStartLvl));
+ assert(!isUniqueLvl(cooStartLvl) &&
+ (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
- P crdLen = positions[cooStartLvl][parentSz];
+ P crdLen;
+ if (!isLooseCompressedLvl(cooStartLvl)) {
+ positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
+ crdLen = positions[cooStartLvl][parentSz];
+ } else {
+ positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
+ crdLen = positions[cooStartLvl][2 * parentSz - 1];
+ }
for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
coordinates[l].resize(crdLen);
for (uint64_t n = 0; n < crdLen; n++) {
@@ -809,6 +829,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
parentSz = assembledSize(parentSz, cooStartLvl);
}
+ // Copy the values buffer.
V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
values.assign(valPtr, valPtr + parentSz);
}
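
As a side note (not part of the patch), the AoS-to-SoA copy performed for the
trailing COO region can be modeled by the following standalone C++ sketch with
hypothetical names: the client hands over the COO coordinates interleaved per
entry, while SparseTensorStorage keeps one coordinate vector per level.

  #include <cstdint>
  #include <vector>

  // aos holds crdLen entries of cooLvls coordinates each, interleaved as
  //   [x0, y0, x1, y1, ...] for a 2-D trailing COO region.
  void unpackAoSCoordinates(const std::vector<uint64_t> &aos, uint64_t crdLen,
                            uint64_t cooLvls,
                            std::vector<std::vector<uint64_t>> &soa) {
    soa.assign(cooLvls, std::vector<uint64_t>(crdLen));
    for (uint64_t n = 0; n < crdLen; n++)
      for (uint64_t l = 0; l < cooLvls; l++)
        soa[l][n] = aos[n * cooLvls + l]; // coordinate of level l for entry n
  }
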
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..8e2c2cd6dad7b19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -163,6 +163,17 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
+Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
+ Value elem, Type dstTp) {
+ if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+ // Scalars can only be converted to 0-ranked tensors.
+ assert(rtp.getRank() == 0);
+ elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
+ return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+ }
+ return sparse_tensor::genCast(builder, loc, elem, dstTp);
+}
+
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
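
A tiny standalone C++ analogue (toy types, not the MLIR API) of the branch
that genScalarToTensor() takes: when the destination type is a 0-ranked
tensor, the scalar is cast to the element type and wrapped in a one-element
value (the tensor.from_elements case); otherwise it is simply cast.

  #include <cstdint>
  #include <variant>

  struct Tensor0D { double elem; };        // stands in for tensor<f64>
  using Converted = std::variant<double, Tensor0D>;

  Converted scalarToTensor(int64_t value, bool dstIsZeroRankedTensor) {
    double casted = static_cast<double>(value); // the genCast step
    if (dstIsZeroRankedTensor)
      return Tensor0D{casted};                  // the tensor.from_elements step
    return casted;
  }
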
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1f53f3525203c70..d3b0889b71b514c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -142,6 +142,10 @@ class FuncCallOrInlineGenerator {
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+/// Add conversion from scalar to given type (possibly a 0-rank tensor).
+Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+ Type dstTp);
+
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 08c38394a46343a..888f513be2e4dc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -435,19 +435,6 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
return reassociation;
}
-/// Generates scalar to tensor cast.
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
- Type dstTp) {
- if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
- // Scalars can only be converted to 0-ranked tensors.
- if (rtp.getRank() != 0)
- return nullptr;
- elem = genCast(builder, loc, elem, rtp.getElementType());
- return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
- }
- return genCast(builder, loc, elem, dstTp);
-}
-
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4fe9c59d8c320a7..e629133171e15dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
return std::nullopt;
}
-/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
-static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
- StringRef name, TypeRange resultType,
- ValueRange operands,
- EmitCInterface emitCInterface) {
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
- emitCInterface);
- return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
- operands);
-}
-
/// Generates call to lookup a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ class NewCallParams final {
};
/// Generates a call to obtain the values array.
-static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
- ValueRange ptr) {
- SmallString<15> name{"sparseValues",
- primaryTypeFunctionSuffix(tp.getElementType())};
- return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+static Value genValuesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr) {
+ auto eltTp = stt.getElementType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
+ SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the positions array.
+static Value genPositionsCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type posTp = stt.getPosType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the coordinates array.
+static Value genCoordinatesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type crdTp = stt.getCrdType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
.getResult(0);
}
@@ -391,7 +405,7 @@ class SparseTensorAllocConverter
SmallVector<Value> dimSizes;
dimSizes.reserve(dimRank);
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(
stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
dimSizes.reserve(dimRank);
auto shape = op.getType().getShape();
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
: constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ class SparseTensorToPositionsConverter
LogicalResult
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resTp = op.getType();
- Type posTp = cast<ShapedType>(resTp).getElementType();
- SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
- Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
- replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
- EmitCInterface::On);
+ auto stt = getSparseTensorType(op.getTensor());
+ auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
+ rewriter.replaceOp(op, poss);
return success();
}
};
@@ -505,29 +517,14 @@ class SparseTensorToCoordinatesConverter
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: use `SparseTensorType::getCrdType` instead.
- Type resType = op.getType();
- const Type crdTp = cast<ShapedType>(resType).getElementType();
- SmallString<19> name{"sparseCoordinates",
- overheadTypeFunctionSuffix(crdTp)};
- Location loc = op->getLoc();
- Value lvl = constantIndex(rewriter, loc, op.getLevel());
-
- // The function returns a MemRef without a layout.
- MemRefType callRetType = get1DMemRefType(crdTp, false);
- SmallVector<Value> operands{adaptor.getTensor(), lvl};
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
- operands, EmitCInterface::On);
- Value callRet =
- rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
- .getResult(0);
-
+ auto stt = getSparseTensorType(op.getTensor());
+ auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
- if (resType != callRetType)
- callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
- rewriter.replaceOp(op, callRet);
-
+ if (op.getType() != crds.getType())
+ crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+ rewriter.replaceOp(op, crds);
return success();
}
};
@@ -539,9 +536,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resType = cast<ShapedType>(op.getType());
- rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
- adaptor.getOperands()));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ rewriter.replaceOp(op, vals);
return success();
}
};
@@ -554,13 +551,11 @@ class SparseNumberOfEntriesConverter
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
// Query values array size for the actually stored values size.
- Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
- auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
- Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
- rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
- constantIndex(rewriter, loc, 0));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ auto zero = constantIndex(rewriter, op.getLoc(), 0);
+ rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
return success();
}
};
@@ -701,7 +696,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
}
};
-/// Sparse conversion rule for the sparse_tensor.pack operator.
+/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
- // AssembleOps always returns a static shaped tensor result.
assert(dstTp.hasStaticDimShape());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
+ // Use a library method to transfer the external buffers from
+ // clients to the internal SparseTensorStorage. Since we cannot
+ // assume clients transfer ownership of the buffers, this method
+ // will copy all data over into a new SparseTensorStorage.
Value dst =
NewCallParams(rewriter, loc)
.genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
}
};
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+ : public OpConversionPattern<DisassembleOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+    // We simply expose the buffers to the external client. This
+    // assumes the client only reads the buffers (usually copying them
+    // to external data structures, such as numpy arrays).
+ Location loc = op->getLoc();
+ auto stt = getSparseTensorType(op.getTensor());
+ SmallVector<Value> retVal;
+ SmallVector<Value> retLen;
+ // Get the values buffer first.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+ // Then get the positions and coordinates buffers.
+ const Level lvlRank = stt.getLvlRank();
+ Level trailCOOLen = 0;
+ for (Level l = 0; l < lvlRank; l++) {
+ if (!stt.isUniqueLvl(l) &&
+ (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+        // A `(loose)compressed_nu` level marks the start of the trailing
+        // COO levels. Since the target coordinate buffer used for trailing
+        // COO is passed in as an AoS scheme while SparseTensorStorage uses
+        // a SoA scheme, we cannot simply use the internal buffers.
+ trailCOOLen = lvlRank - l;
+ break;
+ }
+ if (stt.isWithPos(l)) {
+ auto poss =
+ genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ }
+ if (stt.isWithCrd(l)) {
+ auto crds =
+ genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(crds);
+ retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
+ }
+ }
+ // Handle AoS vs. SoA mismatch for COO.
+ if (trailCOOLen != 0) {
+ uint64_t cooStartLvl = lvlRank - trailCOOLen;
+ assert(!stt.isUniqueLvl(cooStartLvl) &&
+ (stt.isCompressedLvl(cooStartLvl) ||
+ stt.isLooseCompressedLvl(cooStartLvl)));
+ // Positions.
+ auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
+ cooStartLvl);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ // Coordinates, copied over with:
+ // for (i = 0; i < crdLen; i++)
+ // buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
+ auto buf =
+ genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+ auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
+ cooStartLvl);
+ auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
+ cooStartLvl + 1);
+ auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
+ auto two = constantIndex(rewriter, loc, 2);
+ auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
+ Type indexType = rewriter.getIndexType();
+ auto zero = constantZero(rewriter, loc, indexType);
+ auto one = constantOne(rewriter, loc, indexType);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
+ auto idx = forOp.getInductionVar();
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
+ auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
+ SmallVector<Value> args;
+ args.push_back(idx);
+ args.push_back(zero);
+ rewriter.create<memref::StoreOp>(loc, c0, buf, args);
+ args[1] = one;
+ rewriter.create<memref::StoreOp>(loc, c1, buf, args);
+ rewriter.setInsertionPointAfter(forOp);
+ auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(buf);
+ retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
+ }
+ // Converts MemRefs back to Tensors.
+ assert(retVal.size() + retLen.size() == op.getNumResults());
+ for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
+ auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
+ retVal[i] =
+ rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
+ }
+ // Appends the actual memory length used in each buffer returned.
+ retVal.append(retLen.begin(), retLen.end());
+ rewriter.replaceOp(op, retVal);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -752,5 +859,6 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
+ SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
+ typeConverter, patterns.getContext());
}
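
For readers following the SoA-to-AoS repacking in
SparseTensorDisassembleConverter, here is a plain C++ model (hypothetical
names, for illustration only) of the scf.for loop the converter emits for the
trailing COO region: the two per-level coordinate arrays obtained from the
library are interleaved into the client's crdLen x 2 output buffer.

  #include <array>
  #include <cstdint>
  #include <vector>

  // Mirrors the generated loop:
  //   for (i = 0; i < crdLen; i++)
  //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
  void packTrailingCOO(const std::vector<uint64_t> &crds0,
                       const std::vector<uint64_t> &crds1,
                       std::vector<std::array<uint64_t, 2>> &buf) {
    const uint64_t crdLen = crds0.size();
    buf.resize(crdLen);
    for (uint64_t i = 0; i < crdLen; i++)
      buf[i] = {crds0[i], crds1[i]};
  }
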
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index e1cbf3482708ad0..10ebfa7922088a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -198,7 +198,8 @@ struct SparseTensorConversionPass
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
- linalg::YieldOp, tensor::ExtractOp>();
+ linalg::YieldOp, tensor::ExtractOp,
+ tensor::FromElementsOp>();
target.addLegalDialect<
arith::ArithDialect, bufferization::BufferizationDialect,
LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index a2f93614590f106..840e3c97ae28843 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -3,7 +3,7 @@
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
-// do not use these LIT config files. Hence why this is kept inline.
+// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
@@ -17,15 +17,19 @@
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s
//
-// Do the same run, but now with VLA vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=4
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-// TODO: support sparse_tensor.disassemble on libgen path.
-
#SortedCOO = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>
@@ -54,11 +58,13 @@ module {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f64
%i0 = arith.constant 0 : i32
+
//
- // Initialize a 3-dim dense tensor.
+ // Setup COO.
//
+
%data = arith.constant dense<
- [ 1.0, 2.0, 3.0]
+ [ 1.0, 2.0, 3.0 ]
> : tensor<3xf64>
%pos = arith.constant dense<
@@ -83,12 +89,16 @@ module {
%s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
to tensor<10x10xf64, #SortedCOO>
- %s5= sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
- to tensor<10x10xf64, #SortedCOOI32>
+ %s5 = sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+ to tensor<10x10xf64, #SortedCOOI32>
+
+ //
+ // Setup CSR.
+ //
%csr_data = arith.constant dense<
- [ 1.0, 2.0, 3.0, 4.0]
- > : tensor<4xf64>
+ [ 1.0, 2.0, 3.0 ]
+ > : tensor<3xf64>
%csr_pos32 = arith.constant dense<
[0, 1, 3]
@@ -97,12 +107,16 @@ module {
%csr_index32 = arith.constant dense<
[1, 0, 1]
> : tensor<3xi32>
- %csr= sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+ %csr = sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
to tensor<2x2xf64, #CSR>
+ //
+ // Setup BCOO.
+ //
+
%bdata = arith.constant dense<
- [ 1.0, 2.0, 3.0, 4.0, 5.0, 0.0]
- > : tensor<6xf64>
+ [ 1.0, 2.0, 3.0, 4.0, 5.0 ]
+ > : tensor<5xf64>
%bpos = arith.constant dense<
[0, 3, 3, 5]
@@ -116,10 +130,15 @@ module {
[ 4, 2],
[ 10, 10]]
> : tensor<6x2xindex>
+
%bs = sparse_tensor.assemble %bdata, %bpos, %bindex :
- tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
+ tensor<5xf64>, tensor<4xindex>, tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
- // CHECK:1
+ //
+ // Verify results.
+ //
+
+ // CHECK: 1
// CHECK-NEXT:2
// CHECK-NEXT:1
//
@@ -135,7 +154,7 @@ module {
vector.print %1: index
vector.print %2: index
vector.print %v: f64
- }
+ }
// CHECK-NEXT:1
// CHECK-NEXT:2
@@ -153,7 +172,7 @@ module {
vector.print %1: index
vector.print %2: index
vector.print %v: f64
- }
+ }
// CHECK-NEXT:0
// CHECK-NEXT:1
@@ -171,32 +190,43 @@ module {
vector.print %1: index
vector.print %2: index
vector.print %v: f64
- }
-
- %d_csr = tensor.empty() : tensor<4xf64>
- %p_csr = tensor.empty() : tensor<3xi32>
- %i_csr = tensor.empty() : tensor<3xi32>
- %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
- outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
- -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
-
- // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
- %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
- vector.print %vd_csr : vector<4xf64>
+ }
+ // CHECK-NEXT:0
+ // CHECK-NEXT:1
+ // CHECK-NEXT:2
+ // CHECK-NEXT:1
+ //
+ // CHECK-NEXT:0
+ // CHECK-NEXT:5
+ // CHECK-NEXT:6
+ // CHECK-NEXT:2
+ //
+ // CHECK-NEXT:0
+ // CHECK-NEXT:7
+ // CHECK-NEXT:8
+ // CHECK-NEXT:3
+ //
// CHECK-NEXT:1
// CHECK-NEXT:2
// CHECK-NEXT:3
+ // CHECK-NEXT:4
//
+ // CHECK-NEXT:1
// CHECK-NEXT:4
+ // CHECK-NEXT:2
// CHECK-NEXT:5
- //
- // Make sure the trailing zeros are not traversed.
- // CHECK-NOT: 0
sparse_tensor.foreach in %bs : tensor<2x10x10xf64, #BCOO> do {
^bb0(%0: index, %1: index, %2: index, %v: f64) :
+ vector.print %0: index
+ vector.print %1: index
+ vector.print %2: index
vector.print %v: f64
- }
+ }
+
+ //
+ // Verify disassemble operations.
+ //
%od = tensor.empty() : tensor<3xf64>
%op = tensor.empty() : tensor<2xi32>
@@ -213,6 +243,16 @@ module {
%vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
vector.print %vi : vector<3x2xi32>
+ %d_csr = tensor.empty() : tensor<4xf64>
+ %p_csr = tensor.empty() : tensor<3xi32>
+ %i_csr = tensor.empty() : tensor<3xi32>
+ %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
+ outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
+ -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
+
+ // CHECK-NEXT: ( 1, 2, 3 )
+ %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<3xf64>
+ vector.print %vd_csr : vector<3xf64>
%bod = tensor.empty() : tensor<6xf64>
%bop = tensor.empty() : tensor<4xindex>
@@ -221,15 +261,17 @@ module {
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
- // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
- %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
- vector.print %vbd : vector<6xf64>
+ // CHECK-NEXT: ( 1, 2, 3, 4, 5 )
+ %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<5xf64>
+ vector.print %vbd : vector<5xf64>
+
// CHECK-NEXT: 5
vector.print %ld : index
// CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
vector.print %vbi : vector<6x2xindex>
+
// CHECK-NEXT: 10
%si = tensor.extract %li[] : tensor<i64>
vector.print %si : i64
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir
deleted file mode 100644
index 6540c950ab675b0..000000000000000
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir
+++ /dev/null
@@ -1,165 +0,0 @@
-//--------------------------------------------------------------------------------------------------
-// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
-//
-// Set-up that's shared across all tests in this directory. In principle, this
-// config could be moved to lit.local.cfg. However, there are downstream users that
-// do not use these LIT config files. Hence why this is kept inline.
-//
-// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
-// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
-// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
-// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
-// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
-// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
-// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
-//
-// DEFINE: %{env} =
-//--------------------------------------------------------------------------------------------------
-
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with VLA vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=true vl=4
-// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-
-// TODO: This is considered to be a short-living tests and should be merged with sparse_pack.mlir
-// after sparse_tensor.disassemble is supported on libgen path.
-
-#SortedCOO = #sparse_tensor.encoding<{
- map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
-}>
-
-#SortedCOOI32 = #sparse_tensor.encoding<{
- map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
- posWidth = 32,
- crdWidth = 32
-}>
-
-#CSR = #sparse_tensor.encoding<{
- map = (d0, d1) -> (d0 : dense, d1 : compressed),
- posWidth = 32,
- crdWidth = 32
-}>
-
-// TODO: "loose_compressed" is not supported by libgen path.
-// #BCOO = #sparse_tensor.encoding<{
-// map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique, high), d2 : singleton)
-//}>
-
-module {
- //
- // Main driver.
- //
- func.func @entry() {
- %c0 = arith.constant 0 : index
- %f0 = arith.constant 0.0 : f64
- %i0 = arith.constant 0 : i32
- //
- // Initialize a 3-dim dense tensor.
- //
- %data = arith.constant dense<
- [ 1.0, 2.0, 3.0]
- > : tensor<3xf64>
-
- %pos = arith.constant dense<
- [0, 3]
- > : tensor<2xindex>
-
- %index = arith.constant dense<
- [[ 1, 2],
- [ 5, 6],
- [ 7, 8]]
- > : tensor<3x2xindex>
-
- %pos32 = arith.constant dense<
- [0, 3]
- > : tensor<2xi32>
-
- %index32 = arith.constant dense<
- [[ 1, 2],
- [ 5, 6],
- [ 7, 8]]
- > : tensor<3x2xi32>
-
- %s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
- to tensor<10x10xf64, #SortedCOO>
- %s5= sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
- to tensor<10x10xf64, #SortedCOOI32>
-
- %csr_data = arith.constant dense<
- [ 1.0, 2.0, 3.0, 4.0]
- > : tensor<4xf64>
-
- %csr_pos32 = arith.constant dense<
- [0, 1, 3]
- > : tensor<3xi32>
-
- %csr_index32 = arith.constant dense<
- [1, 0, 1]
- > : tensor<3xi32>
- %csr= sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
- to tensor<2x2xf64, #CSR>
-
- // CHECK:1
- // CHECK-NEXT:2
- // CHECK-NEXT:1
- //
- // CHECK-NEXT:5
- // CHECK-NEXT:6
- // CHECK-NEXT:2
- //
- // CHECK-NEXT:7
- // CHECK-NEXT:8
- // CHECK-NEXT:3
- sparse_tensor.foreach in %s4 : tensor<10x10xf64, #SortedCOO> do {
- ^bb0(%1: index, %2: index, %v: f64) :
- vector.print %1: index
- vector.print %2: index
- vector.print %v: f64
- }
-
- // CHECK-NEXT:1
- // CHECK-NEXT:2
- // CHECK-NEXT:1
- //
- // CHECK-NEXT:5
- // CHECK-NEXT:6
- // CHECK-NEXT:2
- //
- // CHECK-NEXT:7
- // CHECK-NEXT:8
- // CHECK-NEXT:3
- sparse_tensor.foreach in %s5 : tensor<10x10xf64, #SortedCOOI32> do {
- ^bb0(%1: index, %2: index, %v: f64) :
- vector.print %1: index
- vector.print %2: index
- vector.print %v: f64
- }
-
- // CHECK-NEXT:0
- // CHECK-NEXT:1
- // CHECK-NEXT:1
- //
- // CHECK-NEXT:1
- // CHECK-NEXT:0
- // CHECK-NEXT:2
- //
- // CHECK-NEXT:1
- // CHECK-NEXT:1
- // CHECK-NEXT:3
- sparse_tensor.foreach in %csr : tensor<2x2xf64, #CSR> do {
- ^bb0(%1: index, %2: index, %v: f64) :
- vector.print %1: index
- vector.print %2: index
- vector.print %v: f64
- }
-
-
- bufferization.dealloc_tensor %s4 : tensor<10x10xf64, #SortedCOO>
- bufferization.dealloc_tensor %s5 : tensor<10x10xf64, #SortedCOOI32>
- bufferization.dealloc_tensor %csr : tensor<2x2xf64, #CSR>
-
- return
- }
-}
>From 80a73a6a86b8afa4105b44ede2883ef80af81b7f Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 9 Nov 2023 16:05:12 -0800
Subject: [PATCH 2/2] add ST defs
---
.../include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index e808057cf6b0a67..a97c185c12e67d3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -295,9 +295,15 @@ class SparseTensorType {
// `getLvlType` method instead of STEA's.
bool isDenseLvl(Level l) const { return isDenseDLT(getLvlType(l)); }
bool isCompressedLvl(Level l) const { return isCompressedDLT(getLvlType(l)); }
+ bool isLooseCompressedLvl(Level l) const {
+ return isLooseCompressedDLT(getLvlType(l));
+ }
bool isSingletonLvl(Level l) const { return isSingletonDLT(getLvlType(l)); }
+ bool is2OutOf4Lvl(Level l) const { return is2OutOf4DLT(getLvlType(l)); }
bool isOrderedLvl(Level l) const { return isOrderedDLT(getLvlType(l)); }
bool isUniqueLvl(Level l) const { return isUniqueDLT(getLvlType(l)); }
+ bool isWithPos(Level l) const { return isDLTWithPos(getLvlType(l)); }
+ bool isWithCrd(Level l) const { return isDLTWithCrd(getLvlType(l)); }
/// Returns the coordinate-overhead bitwidth, defaulting to zero.
unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }
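
The following self-contained C++ sketch (a toy enum, not the real DLT
machinery) shows how per-level queries like the new isWithPos/isWithCrd
predicates decide which overhead buffers a level contributes when the
(dis)assemble lowerings walk the levels:

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  enum class LvlFmt { Dense, Compressed, LooseCompressed, Singleton, TwoOutOfFour };

  bool isWithPos(LvlFmt f) {
    return f == LvlFmt::Compressed || f == LvlFmt::LooseCompressed;
  }
  bool isWithCrd(LvlFmt f) {
    return f != LvlFmt::Dense; // all sparse level formats store coordinates
  }

  int main() {
    // CSR-like layout: dense outer level, compressed inner level.
    std::vector<LvlFmt> lvls = {LvlFmt::Dense, LvlFmt::Compressed};
    for (uint64_t l = 0; l < lvls.size(); l++)
      std::printf("level %llu: pos=%d crd=%d\n", (unsigned long long)l,
                  isWithPos(lvls[l]), isWithCrd(lvls[l]));
    return 0;
  }
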