[Mlir-commits] [mlir] [mlir][sparse] assemble SoA COO correctly. (PR #82449)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 20 17:07:36 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/82449.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+37-40)
- (modified) mlir/test/Dialect/SparseTensor/external.mlir (+24)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 9414d81e6bf5c6..a107cd71959abd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -22,16 +22,13 @@ using namespace sparse_tensor;
// Helper methods.
//===----------------------------------------------------------------------===//
-// TODO: reuse StorageLayout::foreachField?
-
-// TODO: we need COO AoS and SoA
-
// Convert type range to new types range, with sparse tensors externalized.
-void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
- SmallVectorImpl<Type> *extraTypes = nullptr) {
+static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+ SmallVectorImpl<Type> *extraTypes = nullptr) {
for (auto type : types) {
+ auto enc = getSparseTensorEncoding(type);
// All "dense" data passes through unmodified.
- if (!getSparseTensorEncoding(type)) {
+ if (!enc) {
convTypes.push_back(type);
continue;
}
@@ -42,29 +39,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
convTypes.push_back(vtp);
if (extraTypes)
extraTypes->push_back(vtp);
- // Convert the external representations of the pos/crd arrays.
- for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
- const auto lt = stt.getLvlType(lvl);
- if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
- auto ptp = RankedTensorType::get(shape, stt.getPosType());
- auto ctp = RankedTensorType::get(shape, stt.getCrdType());
- convTypes.push_back(ptp);
- convTypes.push_back(ctp);
- if (extraTypes) {
- extraTypes->push_back(ptp);
- extraTypes->push_back(ctp);
- }
- } else {
- assert(isDenseLT(lt)); // TODO: handle other cases
+
+ // Convert the external representation of the position/coordinate array.
+ foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
+ Type t, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level, LevelType) {
+ if (kind == SparseTensorFieldKind::CrdMemRef ||
+ kind == SparseTensorFieldKind::PosMemRef) {
+ ShapedType st = t.cast<ShapedType>();
+ auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+ convTypes.push_back(rtp);
+ if (extraTypes)
+ extraTypes->push_back(rtp);
}
- }
+ return true;
+ });
}
}
// Convert input and output values to [dis]assemble ops for sparse tensors.
-void convVals(OpBuilder &builder, Location loc, TypeRange types,
- ValueRange fromVals, ValueRange extraVals,
- SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
+static void convVals(OpBuilder &builder, Location loc, TypeRange types,
+ ValueRange fromVals, ValueRange extraVals,
+ SmallVectorImpl<Value> &toVals, unsigned extra,
+ bool isIn) {
unsigned idx = 0;
for (auto type : types) {
// All "dense" data passes through unmodified.
@@ -85,29 +83,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
if (!isIn) {
inputs.push_back(extraVals[extra++]);
retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType());
+ cntTypes.push_back(builder.getIndexType()); // nnz
}
+
// Collect the external representations of the pos/crd arrays.
- for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
- const auto lt = stt.getLvlType(lvl);
- if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+ foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level, LevelType) {
+ if (kind == SparseTensorFieldKind::CrdMemRef ||
+ kind == SparseTensorFieldKind::PosMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
- inputs.push_back(fromVals[idx++]);
} else {
- Type pTp = stt.getPosType();
- Type cTp = stt.getCrdType();
- inputs.push_back(extraVals[extra++]);
+ ShapedType st = t.cast<ShapedType>();
+ auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, pTp));
- retTypes.push_back(RankedTensorType::get(shape, cTp));
- cntTypes.push_back(pTp);
- cntTypes.push_back(cTp);
+ retTypes.push_back(rtp);
+ cntTypes.push_back(rtp.getElementType());
}
- } else {
- assert(isDenseLT(lt)); // TODO: handle other cases
}
- }
+ return true;
+ });
+
if (isIn) {
// Assemble multiple inputs into a single sparse tensor.
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index c17ba13e86c926..b5701ad2024264 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -100,3 +100,27 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
return %arg0 : tensor<64x64xf32, #sparse>
}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_inout_coo_soa(
+// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
+// CHECK: sparse_tensor.disassemble %[[F]]
+// CHECK: return
+// CHECK: }
+// CHECK: func.func private @_internal_sparse_inout
+#sparse = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>
+func.func @sparse_inout_coo_soa(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
+ return %arg0 : tensor<64x64xf32, #sparse>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/82449
More information about the Mlir-commits mailing list