[Mlir-commits] [mlir] d114dfb - [mlir][sparse] Refactor the code that reshapes the values buffer for annotated all dense tensors.
llvmlistbot at llvm.org
Wed Jan 11 16:02:51 PST 2023
Author: bixia1
Date: 2023-01-11T16:02:46-08:00
New Revision: d114dfba2dc4187bd9b98291963e64706acab6b5
URL: https://github.com/llvm/llvm-project/commit/d114dfba2dc4187bd9b98291963e64706acab6b5
DIFF: https://github.com/llvm/llvm-project/commit/d114dfba2dc4187bd9b98291963e64706acab6b5.diff
LOG: [mlir][sparse] Refactor the code that reshapes the values buffer for annotated all dense tensors.
Move the functionality to codegen utils for sharing with the codegen path.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D141514
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
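Note (not part of the commit): as a rough sketch of the intended reuse, a codegen-path lowering could now call the shared helper instead of re-emitting the cast-and-reshape sequence by hand. All names below are hypothetical, and the snippet assumes the usual MLIR sparse-tensor headers and namespaces:

    // Illustrative only: `enc` is the tensor's SparseTensorEncodingAttr,
    // `dimSizes` holds one size Value per dimension, `values` is the linear
    // values memref, and `idxBuf` is a small index memref reused as scratch
    // for the per-level sizes.
    static Value remapValues(OpBuilder &builder, Location loc,
                             SparseTensorEncodingAttr enc,
                             const SmallVectorImpl<Value> &dimSizes,
                             Value values, Value idxBuf) {
      // Delegates to the utility this change moves into CodegenUtils.
      return sparse_tensor::reshapeValuesToLevels(builder, loc, enc, dimSizes,
                                                  values, idxBuf);
    }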
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index ff53640bb8d3f..03d9e6b293d5b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -526,6 +526,38 @@ void sparse_tensor::foreachInSparseConstant(
}
}
+void sparse_tensor::storeIndices(OpBuilder &builder, Location loc,
+ unsigned rank, Value ind, ValueRange ivs,
+ unsigned offsetDim, Value offset) {
+ for (unsigned i = 0; i < rank; i++) {
+ Value idx = ivs[i];
+ if (offsetDim == i && offset)
+ idx = builder.create<arith::AddIOp>(loc, idx, offset);
+ builder.create<memref::StoreOp>(loc, idx, ind,
+ constantIndex(builder, loc, i));
+ }
+}
+
+Value sparse_tensor::reshapeValuesToLevels(
+ OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc,
+ const SmallVectorImpl<Value> &dimSizes, Value valuesBuffer,
+ Value idxBuffer) {
+ // Use the idxBuffer to store the level sizes.
+ unsigned rank = enc.getDimLevelType().size();
+ SmallVector<Value> lvlSizes;
+ for (unsigned i = 0; i < dimSizes.size(); i++)
+ lvlSizes.push_back(dimSizes[toOrigDim(enc, i)]);
+ storeIndices(builder, loc, rank, idxBuffer, lvlSizes);
+ // The memref ReshapeOp requires the sizes buffer to have a static
+ // shape.
+ idxBuffer = builder.create<memref::CastOp>(
+ loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer);
+ SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
+ Type elemTp = valuesBuffer.getType().cast<MemRefType>().getElementType();
+ return builder.create<memref::ReshapeOp>(loc, MemRefType::get(shape, elemTp),
+ valuesBuffer, idxBuffer);
+}
+
Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
Value tensor, uint64_t d) {
RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
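Note (not part of the commit): the optional offsetDim/offset pair of the new storeIndices helper lets a caller add an offset to a single coordinate while storing, e.g. to shift indices along a concatenation dimension. A minimal illustrative call, with assumed names:

    // Hypothetical call site: store the `rank` coordinates in `ivs` into the
    // index memref `ind`, adding `concatOffset` to the coordinate of
    // dimension `concatDim`.
    sparse_tensor::storeIndices(builder, loc, rank, ind, ivs,
                                /*offsetDim=*/concatDim,
                                /*offset=*/concatOffset);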
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1c8cad5399d27..2ff60ebdf4fcd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -217,6 +217,20 @@ void foreachInSparseConstant(
Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
function_ref<void(ArrayRef<Value>, Value)> callback);
+/// Stores the vector of indices `ivs` into the memory pointed to by `ind`,
+/// applying the (optional) `offset` to the index at `offsetDim`.
+void storeIndices(OpBuilder &builder, Location loc, unsigned rank, Value ind,
+ ValueRange ivs, unsigned offsetDim = 0,
+ Value offset = Value());
+
+/// Reshapes the linear values buffer for an annotated all dense sparse tensor
+/// to match the shape of the corresponding dense tensor to support direct
+/// access of the buffer through indices.
+Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr enc,
+ const SmallVectorImpl<Value> &dimSizes,
+ Value valuesBuffer, Value idxBuffer);
+
//===----------------------------------------------------------------------===//
// Inlined constant generators.
//
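Note (assumed example, not taken from the commit): for a rank-2 annotated all-dense tensor of f64 values, reshapeValuesToLevels casts the index buffer to a static memref<2xindex> and reshapes the linear memref<?xf64> values buffer into a fully dynamic memref<?x?xf64>, which can then be addressed directly by level coordinates. A sketch of the result type it builds:

    // Result memref type for rank == 2 and f64 elements: memref<?x?xf64>.
    SmallVector<int64_t> shape(/*rank=*/2, ShapedType::kDynamic);
    auto resultTp = MemRefType::get(shape, builder.getF64Type());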
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index bf61164682c88..aaeb041eb7bbc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -428,20 +428,6 @@ static SmallVector<Value> loadIndices(OpBuilder &builder, Location loc,
return ivs;
}
-/// Converts the vector indices and store it into the memory pointed by
-/// `ind`, apply (optional) `offset` on `offsetDim`.
-static void storeIndices(OpBuilder &builder, Location loc, unsigned rank,
- Value ind, ValueRange ivs, unsigned offsetDim = 0,
- Value offset = Value()) {
- for (unsigned i = 0; i < rank; i++) {
- Value idx = ivs[i];
- if (offsetDim == i && offset)
- idx = builder.create<arith::AddIOp>(loc, idx, offset);
- builder.create<memref::StoreOp>(loc, idx, ind,
- constantIndex(builder, loc, i));
- }
-}
-
/// Inserts a value stored in `elemPtr` into a dense tensor created by
/// allocDenseTensor().
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
@@ -1375,19 +1361,8 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
dst = genValuesCall(rewriter, loc,
MemRefType::get({ShapedType::kDynamic}, elemTp),
{dst});
-
// Use the dstIdx to store the level sizes.
- SmallVector<Value> lvlSizes;
- for (unsigned i = 0; i < sizes.size(); i++)
- lvlSizes.push_back(sizes[toOrigDim(encDst, i)]);
- storeIndices(rewriter, loc, rank, dstIdx, lvlSizes);
- // The memref ReshapeOp requires the sizes buffer to have a static
- // shape.
- Value typedBuffer = rewriter.create<memref::CastOp>(
- loc, MemRefType::get({rank}, rewriter.getIndexType()), dstIdx);
- SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
- dst = rewriter.create<memref::ReshapeOp>(
- loc, MemRefType::get(shape, elemTp), dst, typedBuffer);
+ dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, dst, dstIdx);
} else {
dstPerm = params.getDim2LvlMap();
elemPtr = genAllocaScalar(rewriter, loc, elemTp);