[Mlir-commits] [mlir] 63d31a4 - [mlir][sparse] Move some member functions from SparseTensorDescriptorImpl to MutSparseTensorDescriptor.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 4 13:05:49 PST 2023
Author: bixia1
Date: 2023-01-04T13:05:43-08:00
New Revision: 63d31a4d15aa5f7fe4a6c504bba9944e5b0afc6a
URL: https://github.com/llvm/llvm-project/commit/63d31a4d15aa5f7fe4a6c504bba9944e5b0afc6a
DIFF: https://github.com/llvm/llvm-project/commit/63d31a4d15aa5f7fe4a6c504bba9944e5b0afc6a.diff
LOG: [mlir][sparse] Move some member functions from SparseTensorDescriptorImpl to MutSparseTensorDescriptor.
This is to prepare for implementing AOS optimization.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D141002
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 7228a235324f5..ff53640bb8d3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -551,4 +552,11 @@ Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Type valTp = get1DMemRefType(srcTp.getElementType(),
/*withLayout=*/false);
return builder.create<ToValuesOp>(loc, valTp, tensor);
+}
+
+Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
+ Value tensor) {
+ SmallVector<Value> fields;
+ auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
+ return desc.getValMemSize(builder, loc);
}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 4ec9c25db176e..1c8cad5399d27 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -336,6 +336,9 @@ Value genToIndices(OpBuilder &builder, Location loc, Value tensor, uint64_t d,
/// Infers the result type and generates ToValuesOp.
Value genToValues(OpBuilder &builder, Location loc, Value tensor);
+/// Generates code to retrieve the size of the values actually stored in the
+/// sparse tensor (the values memory size kept in its storage specifier).
+Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d138b6de7f94e..bb5128bf14d54 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -967,9 +967,9 @@ class SparseNumberOfEntriesConverter
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Query memSizes for the actually stored values size.
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
+ // Query memSizes for the actually stored values.
+ rewriter.replaceOp(
+ op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 45aac6d83a2a7..61244452bd5b7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -154,7 +154,7 @@ class SparseTensorSpecifier {
/// instead relies on this class to access the right value for the right field.
template <bool mut>
class SparseTensorDescriptorImpl {
-private:
+protected:
// Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
// for mutable descriptors.
// Using SmallVector for mutable descriptor allows users to reuse it as a tmp
@@ -220,21 +220,6 @@ class SparseTensorDescriptorImpl {
return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
}
- Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
- dim);
- }
-
- Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
- dim);
- }
-
- Value getValMemSize(OpBuilder &builder, Location loc) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
- std::nullopt);
- }
-
Value getPtrMemRef(unsigned ptrDim) const {
return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
}
@@ -262,25 +247,70 @@ class SparseTensorDescriptorImpl {
return fields[fidx];
}
+ ValueRange getMemRefFields() const {
+ ValueRange ret = fields;
+ // Drop the last metadata fields.
+ return ret.slice(0, fields.size() - 1);
+ }
+
+ Type getMemRefElementType(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ return getMemRefField(kind, dim)
+ .getType()
+ .template cast<MemRefType>()
+ .getElementType();
+ }
+
+ RankedTensorType getTensorType() const { return rType; }
+ ValueArrayRef getFields() const { return fields; }
+
+protected:
+ RankedTensorType rType;
+ ValueArrayRef fields;
+};
+
+class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+public:
+ MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+ : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+
+ ///
+ /// Getters: get the value for required field.
+ ///
+
+ Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+ dim);
+ }
+
+ Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+ dim);
+ }
+
+ Value getValMemSize(OpBuilder &builder, Location loc) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+ std::nullopt);
+ }
+
///
/// Setters: update the value for required field (only enabled for
/// MutSparseTensorDescriptor).
///
template <typename T = Value>
- void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
- std::enable_if_t<mut, T> v) {
+ void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim, T v) {
fields[getMemRefFieldIndex(kind, dim)] = v;
}
template <typename T = Value>
- void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ void setMemRefField(unsigned fidx, T v) {
assert(fidx < fields.size() - 1);
fields[fidx] = v;
}
template <typename T = Value>
- void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ void setField(unsigned fidx, T v) {
assert(fidx < fields.size());
fields[fidx] = v;
}
@@ -288,42 +318,19 @@ class SparseTensorDescriptorImpl {
template <typename T = Value>
void setSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, Optional<unsigned> dim,
- std::enable_if_t<mut, T> v) {
+ T v) {
SparseTensorSpecifier md(fields.back());
md.setSpecifierField(builder, loc, v, kind, dim);
fields.back() = md;
}
template <typename T = Value>
- void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
- std::enable_if_t<mut, T> v) {
+ void setDimSize(OpBuilder &builder, Location loc, unsigned dim, T v) {
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
}
-
- ValueRange getMemRefFields() const {
- ValueRange ret = fields;
- // drop the last metadata fields
- return ret.slice(0, fields.size() - 1);
- }
-
- Type getMemRefElementType(SparseTensorFieldKind kind,
- Optional<unsigned> dim) const {
- return getMemRefField(kind, dim)
- .getType()
- .template cast<MemRefType>()
- .getElementType();
- }
-
- RankedTensorType getTensorType() const { return rType; }
- ValueArrayRef getFields() const { return fields; }
-
-private:
- RankedTensorType rType;
- ValueArrayRef fields;
};
using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
/// Returns the "tuple" value of the adapted tensor.
inline UnrealizedConversionCastOp getTuple(Value tensor) {
More information about the Mlir-commits
mailing list