[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)
llvmlistbot at llvm.org
Thu Nov 9 16:01:01 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
Changes:
Note that the (dis)assemble operations still make some simplifying assumptions (e.g., trailing 2-D COO in AoS format), but now at least the direct IR and support library paths behave exactly the same.
Generalizing the ops is still TBD.
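To make the AoS assumption concrete, here is a minimal standalone C++ sketch (toy data, not part of the patch) of how a trailing 2-D COO coordinate buffer passed to `assemble` in AoS form relates to the SoA vectors that `SparseTensorStorage` keeps internally:

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Three stored entries at coordinates (0,1), (2,3), (2,7).
  // AoS: one interleaved buffer, as the (dis)assemble ops currently expect
  // for the trailing 2-D COO region.
  std::vector<uint64_t> aosCrd = {0, 1, 2, 3, 2, 7};

  // SoA: one coordinate vector per level, as the runtime storage keeps them.
  std::vector<uint64_t> crdLvl0, crdLvl1;
  const uint64_t nse = aosCrd.size() / 2; // number of stored entries
  for (uint64_t n = 0; n < nse; n++) {
    crdLvl0.push_back(aosCrd[2 * n + 0]);
    crdLvl1.push_back(aosCrd[2 * n + 1]);
  }
  // crdLvl0 == {0, 2, 2} and crdLvl1 == {1, 3, 7}.
  return 0;
}
```

The library-path changes below perform this expansion when assembling and the inverse interleaving when disassembling.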
---
Patch is 37.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71880.diff
8 Files Affected:
- (modified) mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h (+43-22)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (+11)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h (+4)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (-13)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+165-57)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+2-1)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+79-37)
- (removed) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir (-165)
``````````diff
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 460549726356370..3382e293d123746 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -301,8 +301,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t lvlRank = getLvlRank();
uint64_t valIdx = 0;
// Linearize the address.
- for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
- valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+ for (uint64_t l = 0; l < lvlRank; l++)
+ valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
values[valIdx] = val;
return;
}
@@ -472,9 +472,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
if (isCompressedLvl(l))
return positions[l][parentSz];
- if (isSingletonLvl(l))
- return parentSz; // New size is same as the parent.
- // TODO: support levels assignment for loose/2:4?
+ if (isLooseCompressedLvl(l))
+ return positions[l][2 * parentSz - 1];
+ if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
}
@@ -766,40 +767,59 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
dim2lvl, lvl2dim) {
+  // Note that none of the buffers can be reused because ownership
+ // of the memory passed from clients is not necessarily transferred.
+ // Therefore, all data is copied over into a new SparseTensorStorage.
+ //
+ // TODO: this needs to be generalized to all formats AND
+ // we need a proper audit of e.g. double compressed
+ // levels where some are not filled
+ //
uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
for (uint64_t l = 0; l < lvlRank; l++) {
- if (!isUniqueLvl(l) && isCompressedLvl(l)) {
- // A `compressed_nu` level marks the start of trailing COO start level.
- // Since the coordinate buffer used for trailing COO are passed in as AoS
- // scheme, and SparseTensorStorage uses a SoA scheme, we can not simply
- // copy the value from the provided buffers.
+ if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
+      // A `(loose)compressed_nu` level marks the start of the trailing
+      // COO levels. Since the coordinate buffer used for trailing COO
+ // is passed in as AoS scheme and SparseTensorStorage uses a SoA
+ // scheme, we cannot simply copy the value from the provided buffers.
trailCOOLen = lvlRank - l;
break;
}
- assert(!isSingletonLvl(l) &&
- "Singleton level not following a compressed_nu level");
- if (isCompressedLvl(l)) {
+ if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- // Copies the lvlBuf into the vectors. The buffer can not be simply reused
- // because the memory passed from users is not necessarily allocated on
- // heap.
- positions[l].assign(posPtr, posPtr + parentSz + 1);
- coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ if (!isLooseCompressedLvl(l)) {
+ positions[l].assign(posPtr, posPtr + parentSz + 1);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ } else {
+ positions[l].assign(posPtr, posPtr + 2 * parentSz);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
+ }
+ } else if (isSingletonLvl(l)) {
+ assert(0 && "general singleton not supported yet");
+ } else if (is2OutOf4Lvl(l)) {
+ assert(0 && "2Out4 not supported yet");
} else {
- // TODO: support levels assignment for loose/2:4?
assert(isDenseLvl(l));
}
parentSz = assembledSize(parentSz, l);
}
+  // Handle AoS vs. SoA mismatch for COO.
if (trailCOOLen != 0) {
uint64_t cooStartLvl = lvlRank - trailCOOLen;
- assert(!isUniqueLvl(cooStartLvl) && isCompressedLvl(cooStartLvl));
+ assert(!isUniqueLvl(cooStartLvl) &&
+ (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
- P crdLen = positions[cooStartLvl][parentSz];
+ P crdLen;
+ if (!isLooseCompressedLvl(cooStartLvl)) {
+ positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
+ crdLen = positions[cooStartLvl][parentSz];
+ } else {
+ positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
+ crdLen = positions[cooStartLvl][2 * parentSz - 1];
+ }
for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
coordinates[l].resize(crdLen);
for (uint64_t n = 0; n < crdLen; n++) {
@@ -809,6 +829,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
parentSz = assembledSize(parentSz, cooStartLvl);
}
+ // Copy the values buffer.
V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
values.assign(valPtr, valPtr + parentSz);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..8e2c2cd6dad7b19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -163,6 +163,17 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
+Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
+ Value elem, Type dstTp) {
+ if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+ // Scalars can only be converted to 0-ranked tensors.
+ assert(rtp.getRank() == 0);
+ elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
+ return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+ }
+ return sparse_tensor::genCast(builder, loc, elem, dstTp);
+}
+
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1f53f3525203c70..d3b0889b71b514c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -142,6 +142,10 @@ class FuncCallOrInlineGenerator {
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+/// Add conversion from scalar to given type (possibly a 0-rank tensor).
+Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+ Type dstTp);
+
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 08c38394a46343a..888f513be2e4dc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -435,19 +435,6 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
return reassociation;
}
-/// Generates scalar to tensor cast.
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
- Type dstTp) {
- if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
- // Scalars can only be converted to 0-ranked tensors.
- if (rtp.getRank() != 0)
- return nullptr;
- elem = genCast(builder, loc, elem, rtp.getElementType());
- return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
- }
- return genCast(builder, loc, elem, dstTp);
-}
-
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4fe9c59d8c320a7..e629133171e15dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
return std::nullopt;
}
-/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
-static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
- StringRef name, TypeRange resultType,
- ValueRange operands,
- EmitCInterface emitCInterface) {
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
- emitCInterface);
- return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
- operands);
-}
-
/// Generates call to lookup a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ class NewCallParams final {
};
/// Generates a call to obtain the values array.
-static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
- ValueRange ptr) {
- SmallString<15> name{"sparseValues",
- primaryTypeFunctionSuffix(tp.getElementType())};
- return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+static Value genValuesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr) {
+ auto eltTp = stt.getElementType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
+ SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the positions array.
+static Value genPositionsCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type posTp = stt.getPosType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the coordinates array.
+static Value genCoordinatesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type crdTp = stt.getCrdType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
.getResult(0);
}
@@ -391,7 +405,7 @@ class SparseTensorAllocConverter
SmallVector<Value> dimSizes;
dimSizes.reserve(dimRank);
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(
stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
dimSizes.reserve(dimRank);
auto shape = op.getType().getShape();
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
: constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ class SparseTensorToPositionsConverter
LogicalResult
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resTp = op.getType();
- Type posTp = cast<ShapedType>(resTp).getElementType();
- SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
- Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
- replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
- EmitCInterface::On);
+ auto stt = getSparseTensorType(op.getTensor());
+ auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
+ rewriter.replaceOp(op, poss);
return success();
}
};
@@ -505,29 +517,14 @@ class SparseTensorToCoordinatesConverter
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: use `SparseTensorType::getCrdType` instead.
- Type resType = op.getType();
- const Type crdTp = cast<ShapedType>(resType).getElementType();
- SmallString<19> name{"sparseCoordinates",
- overheadTypeFunctionSuffix(crdTp)};
- Location loc = op->getLoc();
- Value lvl = constantIndex(rewriter, loc, op.getLevel());
-
- // The function returns a MemRef without a layout.
- MemRefType callRetType = get1DMemRefType(crdTp, false);
- SmallVector<Value> operands{adaptor.getTensor(), lvl};
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
- operands, EmitCInterface::On);
- Value callRet =
- rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
- .getResult(0);
-
+ auto stt = getSparseTensorType(op.getTensor());
+ auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
- if (resType != callRetType)
- callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
- rewriter.replaceOp(op, callRet);
-
+ if (op.getType() != crds.getType())
+ crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+ rewriter.replaceOp(op, crds);
return success();
}
};
@@ -539,9 +536,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resType = cast<ShapedType>(op.getType());
- rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
- adaptor.getOperands()));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ rewriter.replaceOp(op, vals);
return success();
}
};
@@ -554,13 +551,11 @@ class SparseNumberOfEntriesConverter
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
// Query values array size for the actually stored values size.
- Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
- auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
- Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
- rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
- constantIndex(rewriter, loc, 0));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ auto zero = constantIndex(rewriter, op.getLoc(), 0);
+ rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
return success();
}
};
@@ -701,7 +696,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
}
};
-/// Sparse conversion rule for the sparse_tensor.pack operator.
+/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
- // AssembleOps always returns a static shaped tensor result.
assert(dstTp.hasStaticDimShape());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
+ // Use a library method to transfer the external buffers from
+ // clients to the internal SparseTensorStorage. Since we cannot
+ // assume clients transfer ownership of the buffers, this method
+ // will copy all data over into a new SparseTensorStorage.
Value dst =
NewCallParams(rewriter, loc)
.genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
}
};
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+ : public OpConversionPattern<DisassembleOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // We simply expose the buffers to the external client. This
+    // assumes the client only reads the buffers (usually copying them
+    // to external data structures, such as numpy arrays).
+ Location loc = op->getLoc();
+ auto stt = getSparseTensorType(op.getTensor());
+ SmallVector<Value> retVal;
+ SmallVector<Value> retLen;
+ // Get the values buffer first.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+ // Then get the positions and coordinates buffers.
+ const Level lvlRank = stt.getLvlRank();
+ Level trailCOOLen = 0;
+ for (Level l = 0; l < lvlRank; l++) {
+ if (!stt.isUniqueLvl(l) &&
+ (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+        // A `(loose)compressed_nu` level marks the start of the trailing
+        // COO levels. Since the target coordinate buffer used for trailing
+ // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
+ // scheme, we cannot simply use the internal buffers.
+ trailCOOLen = lvlRank - l;
+ break;
+ }
+ if (stt.isWithPos(l)) {
+ auto poss =
+ genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ }
+ if (stt.isWithCrd(l)) {
+ auto crds =
+ genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto crdLen = linalg::createOrFoldDimOp(rewriter, lo...
[truncated]
``````````
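For readers skimming the Storage.h hunks above, the following self-contained C++ sketch restates the per-level size computation and buffer copy that the new library-path assembly performs. All type and member names here are illustrative stand-ins rather than the actual `SparseTensorStorage` API, and singleton and 2:4 levels are omitted since the patch still rejects them:

```cpp
#include <cstdint>
#include <vector>

enum class LvlType { Dense, Compressed, LooseCompressed };

// Illustrative stand-in for one storage level.
struct Level {
  LvlType type;
  uint64_t size;             // level size (used for dense levels)
  std::vector<uint64_t> pos; // positions (compressed variants only)
  std::vector<uint64_t> crd; // coordinates (compressed variants only)
};

// Mirrors assembledSize(): how many entries level `lvl` contributes given
// `parentSz` entries at the previous level.
uint64_t assembledSize(const Level &lvl, uint64_t parentSz) {
  switch (lvl.type) {
  case LvlType::Compressed:
    return lvl.pos[parentSz];
  case LvlType::LooseCompressed:
    return lvl.pos[2 * parentSz - 1];
  case LvlType::Dense:
    return parentSz * lvl.size;
  }
  return 0;
}

// Copies one level's client buffers into owned storage; note the different
// number of position entries for compressed vs. loose compressed levels.
uint64_t assembleLevel(Level &lvl, const uint64_t *posBuf,
                       const uint64_t *crdBuf, uint64_t parentSz) {
  if (lvl.type == LvlType::Compressed) {
    lvl.pos.assign(posBuf, posBuf + parentSz + 1);
    lvl.crd.assign(crdBuf, crdBuf + lvl.pos[parentSz]);
  } else if (lvl.type == LvlType::LooseCompressed) {
    lvl.pos.assign(posBuf, posBuf + 2 * parentSz);
    lvl.crd.assign(crdBuf, crdBuf + lvl.pos[2 * parentSz - 1]);
  } // dense levels carry no buffers of their own
  return assembledSize(lvl, parentSz);
}
```

A trailing COO region is handled separately: its single AoS coordinate buffer is split into one SoA vector per level (as in the earlier sketch) before `parentSz` is advanced and the values buffer is copied last.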
https://github.com/llvm/llvm-project/pull/71880
More information about the Mlir-commits mailing list