[Mlir-commits] [mlir] 83a5083 - [mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)
Peiming Liu
llvmlistbot at llvm.org
Tue Jan 17 12:54:33 PST 2023
Author: Peiming Liu
Date: 2023-01-17T20:54:27Z
New Revision: 83a50839b7ba3cb60cd46403cc517237a73d5276
URL: https://github.com/llvm/llvm-project/commit/83a50839b7ba3cb60cd46403cc517237a73d5276
DIFF: https://github.com/llvm/llvm-project/commit/83a50839b7ba3cb60cd46403cc517237a73d5276.diff
LOG: [mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)
Use SparseTensorDescriptor whenever no setters are called, to avoid creating a temporary buffer for simple queries.
Reviewed By: bixia, wrengr
Differential Revision: https://reviews.llvm.org/D141953
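For illustration, here is a minimal sketch of the query pattern this change adopts. It uses only the helpers visible in the diff below (getMutDescriptorFromTensorTuple, getDescriptorFromTensorTuple, getValMemSize); the surrounding variables (builder, loc, tensor) are assumed from the usual conversion-pattern context, and the two snippets are before/after alternatives, not one function body.

  // Before: even a read-only query materialized a temporary SmallVector,
  // because the mutable descriptor stores a SmallVectorImpl<Value> &
  // referencing the caller's buffer.
  SmallVector<Value> fields;
  auto mutDesc = getMutDescriptorFromTensorTuple(tensor, fields);
  Value szBefore = mutDesc.getValMemSize(builder, loc);

  // After: the immutable descriptor wraps a cheap, non-owning ValueRange,
  // so query-only callers need no buffer at all.
  Value szAfter =
      getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);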
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 9466f6f147538..f47d3046f6bae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -593,7 +593,5 @@ Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
- return desc.getValMemSize(builder, loc);
-}
\ No newline at end of file
+ return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 975403ee739e4..4a1a0c9258610 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -102,11 +102,9 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
}
/// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
-/// attached to the given tensor type.
-static std::optional<Value>
-sizeFromTensorAtDim(OpBuilder &builder, Location loc,
- const SparseTensorDescriptor &desc, unsigned dim) {
+/// original dimension 'dim'.
+static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
+ SparseTensorDescriptor desc, unsigned dim) {
RankedTensorType rtp = desc.getTensorType();
// An access into a static dimension can query the original type directly.
// Note that this is typically already done by DimOp's folding.
@@ -119,17 +117,12 @@ sizeFromTensorAtDim(OpBuilder &builder, Location loc,
return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
}
-// Gets the dimension size at the given stored dimension 'd', either as a
+// Gets the dimension size at the given stored level 'lvl', either as a
// constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc, unsigned d) {
- RankedTensorType rtp = desc.getTensorType();
- unsigned dim = toOrigDim(rtp, d);
- auto shape = rtp.getShape();
- if (!ShapedType::isDynamic(shape[dim]))
- return constantIndex(builder, loc, shape[dim]);
-
- return desc.getDimSize(builder, loc, d);
+static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
+ SparseTensorDescriptor desc, unsigned lvl) {
+ return sizeFromTensorAtDim(builder, loc, desc,
+ toOrigDim(desc.getTensorType(), lvl));
}
static void createPushback(OpBuilder &builder, Location loc,
@@ -174,7 +167,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// at this level. We will eventually reach a compressed level or
// otherwise the values array for the from-here "all-dense" case.
assert(isDenseDim(rtp, r));
- Value size = sizeAtStoredDim(builder, loc, desc, r);
+ Value size = sizeFromTensorAtLvl(builder, loc, desc, r);
linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
@@ -436,7 +429,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
// Construct the new position as:
// pos[d] = size * pos[d-1] + i[d]
// <insert @ pos[d] at next dimension d + 1>
- Value size = sizeAtStoredDim(builder, loc, desc, d);
+ Value size = sizeFromTensorAtLvl(builder, loc, desc, d);
Value mult = builder.create<arith::MulIOp>(loc, size, pos);
pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
}
@@ -517,7 +510,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc) {
+ SparseTensorDescriptor desc) {
RankedTensorType rtp = desc.getTensorType();
unsigned rank = rtp.getShape().size();
for (unsigned d = 0; d < rank; d++) {
@@ -654,10 +647,7 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
- if (!sz)
- return failure();
-
- rewriter.replaceOp(op, *sz);
+ rewriter.replaceOp(op, sz);
return success();
}
};
@@ -727,8 +717,7 @@ class SparseTensorDeallocConverter
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +735,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
// Generate optional insertion finalization code.
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
@@ -780,11 +768,10 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
// recursively rewrite the new DimOp on the **original** tensor.
unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
- assert(sz); // This for sure is a sparse tensor
// Generate a memref for `sz` elements of type `t`.
auto genAlloc = [&](Type t) {
auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
- return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
+ return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
};
// Allocate temporary buffers for values/filled-switch and added.
// We do not use stack buffers for this, since the expanded size may
@@ -957,8 +944,7 @@ class SparseToIndicesBufferConverter
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getAOSMemRef());
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 9ca7149056ddd..7ffb2c8fbbf85 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -202,20 +202,9 @@ class SparseTensorSpecifier {
/// field in a consistent way.
/// Users should not make assumptions about how a sparse tensor is laid out
/// but instead rely on this class to access the right value for the right field.
-template <bool mut>
+template <typename ValueArrayRef>
class SparseTensorDescriptorImpl {
protected:
- // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
- // for mutable descriptors.
- // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
- // buffers to append value for some special cases, though users should be
- // responsible to restore the buffer to legal states after their use. It is
- // probably not a clean way, but it is the most efficient way to avoid copying
- // the fields into another SmallVector. If a more clear way is wanted, we
- // should change it to MutableArrayRef instead.
- using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
- ValueRange>::type;
-
SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
: rType(tp.cast<RankedTensorType>()), fields(fields) {
assert(getSparseTensorEncoding(tp) &&
@@ -223,8 +212,8 @@ class SparseTensorDescriptorImpl {
fields.size());
// We should make sure the class is trivially copyable (and should be small
// enough) such that we can pass it by value.
- static_assert(
- std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+ static_assert(std::is_trivially_copyable_v<
+ SparseTensorDescriptorImpl<ValueArrayRef>>);
}
public:
@@ -262,12 +251,12 @@ class SparseTensorDescriptorImpl {
Value getMemRefField(SparseTensorFieldKind kind,
std::optional<unsigned> dim) const {
- return fields[getMemRefFieldIndex(kind, dim)];
+ return getField(getMemRefFieldIndex(kind, dim));
}
Value getMemRefField(unsigned fidx) const {
assert(fidx < fields.size() - 1);
- return fields[fidx];
+ return getField(fidx);
}
Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
@@ -293,6 +282,31 @@ class SparseTensorDescriptorImpl {
.getElementType();
}
+ Value getField(unsigned fidx) const {
+ assert(fidx < fields.size());
+ return fields[fidx];
+ }
+
+ ValueRange getMemRefFields() const {
+ ValueRange ret = fields;
+ // Drop the last metadata fields.
+ return ret.slice(0, fields.size() - 1);
+ }
+
+ std::pair<unsigned, unsigned>
+ getIdxMemRefIndexAndStride(unsigned idxDim) const {
+ StorageLayout layout(getSparseTensorEncoding(rType));
+ return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
+ idxDim);
+ }
+
+ Value getAOSMemRef() const {
+ auto enc = getSparseTensorEncoding(rType);
+ unsigned cooStart = getCOOStart(enc);
+ assert(cooStart < enc.getDimLevelType().size());
+ return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
+ }
+
RankedTensorType getTensorType() const { return rType; }
ValueArrayRef getFields() const { return fields; }
@@ -301,25 +315,38 @@ class SparseTensorDescriptorImpl {
ValueArrayRef fields;
};
-class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+/// Uses ValueRange for immutable descriptors.
+class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
public:
- MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
- : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+ SparseTensorDescriptor(Type tp, ValueRange buffers)
+ : SparseTensorDescriptorImpl<ValueRange>(tp, buffers) {}
- Value getField(unsigned fidx) const {
- assert(fidx < fields.size());
- return fields[fidx];
- }
+ Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+ unsigned idxDim) const;
+};
- ValueRange getMemRefFields() const {
- ValueRange ret = fields;
- // Drop the last metadata fields.
- return ret.slice(0, fields.size() - 1);
+/// Uses SmallVectorImpl<Value> & for mutable descriptors.
+/// Using SmallVector for the mutable descriptor allows users to reuse it
+/// as a tmp buffer to append values in some special cases, though users
+/// are responsible for restoring the buffer to a legal state after use.
+/// It is probably not a clean approach, but it is the most efficient way
+/// to avoid copying the fields into another SmallVector. If a cleaner way
+/// is wanted, we should change it to MutableArrayRef instead.
+class MutSparseTensorDescriptor
+ : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
+public:
+ MutSparseTensorDescriptor(Type tp, SmallVectorImpl<Value> &buffers)
+ : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(tp, buffers) {}
+
+ // Allow implicit type conversion from mutable descriptors to immutable ones
+ // (but not vice versa).
+ /*implicit*/ operator SparseTensorDescriptor() const {
+ return SparseTensorDescriptor(rType, fields);
}
///
- /// Setters: update the value for required field (only enabled for
- /// MutSparseTensorDescriptor).
+  /// Additional setters for the mutable descriptor; they update the value
+  /// of the required field.
///
void setMemRefField(SparseTensorFieldKind kind, std::optional<unsigned> dim,
@@ -348,29 +375,6 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
}
-
- std::pair<unsigned, unsigned>
- getIdxMemRefIndexAndStride(unsigned idxDim) const {
- StorageLayout layout(getSparseTensorEncoding(rType));
- return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
- idxDim);
- }
-
- Value getAOSMemRef() const {
- auto enc = getSparseTensorEncoding(rType);
- unsigned cooStart = getCOOStart(enc);
- assert(cooStart < enc.getDimLevelType().size());
- return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
- }
-};
-
-class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
-public:
- SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
- : SparseTensorDescriptorImpl<false>(tp, buffers) {}
-
- Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
- unsigned idxDim) const;
};
/// Returns the "tuple" value of the adapted tensor.
@@ -386,7 +390,7 @@ inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
}
inline Value genTuple(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc) {
+ SparseTensorDescriptor desc) {
return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
}
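As a design note, the patch replaces the bool template parameter with a template over the field-storage type and adds a one-way implicit conversion from mutable to immutable descriptors, which is what lets functions such as genEndInsert and sizeFromTensorAtDim now take a SparseTensorDescriptor by value. The following self-contained sketch (plain standard C++; int stands in for mlir::Value, a const std::vector reference stands in for ValueRange, and all names are illustrative rather than taken from the patch) shows the conversion pattern in isolation:

#include <cassert>
#include <vector>

using Value = int; // stand-in for mlir::Value

template <typename ValueArrayRef>
class DescriptorImpl {
protected:
  explicit DescriptorImpl(ValueArrayRef fields) : fields(fields) {}
  ValueArrayRef fields;

public:
  Value getField(unsigned i) const { return fields[i]; }
};

// Immutable view over the fields (stands in for the ValueRange-based
// SparseTensorDescriptor).
class Descriptor : public DescriptorImpl<const std::vector<Value> &> {
public:
  explicit Descriptor(const std::vector<Value> &b) : DescriptorImpl(b) {}
};

// Mutable variant: references the caller's buffer and adds a setter.
class MutDescriptor : public DescriptorImpl<std::vector<Value> &> {
public:
  explicit MutDescriptor(std::vector<Value> &b) : DescriptorImpl(b) {}
  void setField(unsigned i, Value v) { fields[i] = v; }
  // One-way implicit conversion: mutable -> immutable, never the reverse.
  operator Descriptor() const { return Descriptor(fields); }
};

// A read-only consumer takes the immutable type by value, cheaply.
Value query(Descriptor d) { return d.getField(0); }

int main() {
  std::vector<Value> buf{7, 8};
  MutDescriptor mut(buf);
  mut.setField(0, 42);
  assert(query(mut) == 42); // converts implicitly at the call site
  return 0;
}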