[Mlir-commits] [mlir] f40ee6e - [mlir][sparse] assemble SoA COO correctly. (#82449)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 20 18:46:39 PST 2024


Author: Peiming Liu
Date: 2024-02-20T18:46:34-08:00
New Revision: f40ee6e83f263fc4240c5b8d31a7e0e148a28cf6

URL: https://github.com/llvm/llvm-project/commit/f40ee6e83f263fc4240c5b8d31a7e0e148a28cf6
DIFF: https://github.com/llvm/llvm-project/commit/f40ee6e83f263fc4240c5b8d31a7e0e148a28cf6.diff

LOG: [mlir][sparse] assemble SoA COO correctly. (#82449)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
    mlir/test/Dialect/SparseTensor/external.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 9414d81e6bf5c6..cd6b9b49893731 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -22,13 +22,9 @@ using namespace sparse_tensor;
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: reuse StorageLayout::foreachField?
-
-// TODO: we need COO AoS and SoA
-
 // Convert type range to new types range, with sparse tensors externalized.
-void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-               SmallVectorImpl<Type> *extraTypes = nullptr) {
+static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+                      SmallVectorImpl<Type> *extraTypes = nullptr) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
@@ -42,29 +38,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     convTypes.push_back(vtp);
     if (extraTypes)
       extraTypes->push_back(vtp);
-    // Convert the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
-        auto ptp = RankedTensorType::get(shape, stt.getPosType());
-        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
-        convTypes.push_back(ptp);
-        convTypes.push_back(ctp);
-        if (extraTypes) {
-          extraTypes->push_back(ptp);
-          extraTypes->push_back(ctp);
-        }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
+
+    // Convert the external representation of the position/coordinate array.
+    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
+                                               Type t, FieldIndex,
+                                               SparseTensorFieldKind kind,
+                                               Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
+        ShapedType st = t.cast<ShapedType>();
+        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+        convTypes.push_back(rtp);
+        if (extraTypes)
+          extraTypes->push_back(rtp);
       }
-    }
+      return true;
+    });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
-void convVals(OpBuilder &builder, Location loc, TypeRange types,
-              ValueRange fromVals, ValueRange extraVals,
-              SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
+static void convVals(OpBuilder &builder, Location loc, TypeRange types,
+                     ValueRange fromVals, ValueRange extraVals,
+                     SmallVectorImpl<Value> &toVals, unsigned extra,
+                     bool isIn) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -85,29 +82,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn) {
       inputs.push_back(extraVals[extra++]);
       retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType());
+      cntTypes.push_back(builder.getIndexType()); // nnz
     }
+
     // Collect the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
-          inputs.push_back(fromVals[idx++]);
         } else {
-          Type pTp = stt.getPosType();
-          Type cTp = stt.getCrdType();
-          inputs.push_back(extraVals[extra++]);
+          ShapedType st = t.cast<ShapedType>();
+          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
-          retTypes.push_back(RankedTensorType::get(shape, pTp));
-          retTypes.push_back(RankedTensorType::get(shape, cTp));
-          cntTypes.push_back(pTp);
-          cntTypes.push_back(cTp);
+          retTypes.push_back(rtp);
+          cntTypes.push_back(rtp.getElementType());
         }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
       }
-    }
+      return true;
+    });
+
     if (isIn) {
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);

diff  --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index c17ba13e86c926..b5701ad2024264 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -100,3 +100,27 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_inout_coo_soa(
+// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK:         %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
+// CHECK:         sparse_tensor.disassemble %[[F]]
+// CHECK:         return
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_inout
+#sparse = #sparse_tensor.encoding<{
+   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>
+func.func @sparse_inout_coo_soa(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
+  return %arg0 : tensor<64x64xf32, #sparse>
+}


        


More information about the Mlir-commits mailing list