[Mlir-commits] [mlir] [mlir][sparse] support sparse constant to BSR conversion. (PR #71114)
Peiming Liu
llvmlistbot at llvm.org
Thu Nov 2 14:31:12 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/71114
>From 1e190f542db2a0c331516aed1f0695664ec1a47d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 2 Nov 2023 21:13:06 +0000
Subject: [PATCH 1/2] [mlir][sparse] support sparse constant to BSR conversion.
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 7 ++--
.../SparseTensor/Transforms/CodegenUtils.cpp | 34 +++++++++++--------
.../Transforms/SparseTensorRewriting.cpp | 11 +++---
.../CPU/sparse_conversion_block.mlir | 23 ++++++-------
4 files changed, 38 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 99214fadf4ba3db..76ecd4171d81aa8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1665,11 +1665,10 @@ LogicalResult ForeachOp::verify() {
const Dimension dimRank = t.getDimRank();
const auto args = getBody()->getArguments();
- if (getOrder().has_value() &&
- (t.getEncoding() || !getOrder()->isPermutation()))
- return emitError("Only support permuted order on non encoded dense tensor");
+ if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
+ return emitError("Level traverse order does not match tensor's level rank");
- if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
+ if (dimRank + 1 + getInitArgs().size() != args.size())
return emitError("Unmatched number of arguments in the block");
if (getNumResults() != getInitArgs().size())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index f6fb59fa2c3b84b..db969436a30712d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
void sparse_tensor::foreachInSparseConstant(
OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
function_ref<void(ArrayRef<Value>, Value)> callback) {
- const Dimension dimRank =
- SparseTensorType(getRankedTensorType(attr)).getDimRank();
+ if (!order)
+ order = builder.getMultiDimIdentityMap(attr.getType().getRank());
+
+ auto stt = SparseTensorType(getRankedTensorType(attr));
+ const Dimension dimRank = stt.getDimRank();
const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
const auto values = attr.getValues().getValues<Attribute>();
@@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(
// Sorts the sparse element attribute based on coordinates.
std::sort(elems.begin(), elems.end(),
- [order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
- const auto &lhsCoords = lhs.first;
- const auto &rhsCoords = rhs.first;
- for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: This only makes sense for permutations.
- // And since we don't check that `order` is a permutation,
- // it can also cause OOB errors when we use `l`.
- const Level l = order ? order.getDimPosition(d) : d;
- if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
- continue;
- return lhsCoords[l].getInt() < rhsCoords[l].getInt();
- }
+ [order](const ElementAttr &lhs, const ElementAttr &rhs) {
if (std::addressof(lhs) == std::addressof(rhs))
return false;
+
+ auto lhsCoords = llvm::map_to_vector(
+ lhs.first, [](IntegerAttr i) { return i.getInt(); });
+ auto rhsCoords = llvm::map_to_vector(
+ rhs.first, [](IntegerAttr i) { return i.getInt(); });
+
+ SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
+ SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
+ // Sort the elements based on the lvl coordinates.
+ for (Level l = 0; l < order.getNumResults(); l++) {
+ if (lhsLvlCrds[l] == rhsLvlCrds[l])
+ continue;
+ return lhsLvlCrds[l] < rhsLvlCrds[l];
+ }
llvm_unreachable("no equal coordinate in sparse element attr");
});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 13388dce6bbb5ec..7770bd857e88093 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
if (op.getOrder()) {
- // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
- const Dimension dimRank = stt.getDimRank();
- SmallVector<Value> dcvs = lcvs; // keep a copy
- for (Dimension d = 0; d < dimRank; d++) {
- auto l = op.getOrder()->getDimPosition(d);
- lcvs[l] = dcvs[d];
- }
+ // TODO: Support it so that we can do direct conversion from CSR->BSR.
+ llvm_unreachable(
+ "Level order not yet implemented on non-constant input tensors.");
}
+
Value vals = loopEmitter.getValBuffer()[0];
Value pos = loopEmitter.getPosits()[0].back();
// Loads the value from sparse tensor using position-index;
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index ec14492c5b44999..eb9e23f9b7dc846 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -74,22 +74,21 @@ module {
//
// Initialize a 2-dim dense tensor.
//
- %t = arith.constant dense<[
- [ 1.0, 2.0, 3.0, 4.0 ],
- [ 5.0, 6.0, 7.0, 8.0 ]
- ]> : tensor<2x4xf64>
+ %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
+ [1, 0], [1, 1], [1, 2], [1, 3]],
+ [ 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
+ %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+ %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+ %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
- %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
- %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
- %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
-
- %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
- %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+ %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
+ %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #CSR> to memref<?xf64>
%v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
- // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
- // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+ // CHECK: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+ // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
// CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
call @dumpf64(%v1) : (memref<?xf64>) -> ()
call @dumpf64(%v2) : (memref<?xf64>) -> ()
>From 4780f9badabf553648370608f412fe59166cf6d4 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 2 Nov 2023 21:30:43 +0000
Subject: [PATCH 2/2] add dense test
---
.../CPU/sparse_conversion_block.mlir | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index eb9e23f9b7dc846..ccd61aea0ba1684 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -79,20 +79,29 @@ module {
[ 1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
+ %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+ [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>
+
+ // constant -> BSR (either from a SparseElementsAttr or a DenseElementsAttr)
%1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
- %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
- %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+ %2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+ %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+ %4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
%v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
- %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #CSR> to memref<?xf64>
- %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+ %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+ %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
+ %v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>
+
// CHECK: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+ // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
// CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
// CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
call @dumpf64(%v1) : (memref<?xf64>) -> ()
call @dumpf64(%v2) : (memref<?xf64>) -> ()
call @dumpf64(%v3) : (memref<?xf64>) -> ()
+ call @dumpf64(%v4) : (memref<?xf64>) -> ()
return
}
More information about the Mlir-commits
mailing list