[Mlir-commits] [mlir] [mlir][sparse] use a consistent order between [dis]assembleOp and storage (PR #84079)

Aart Bik llvmlistbot at llvm.org
Tue Mar 5 14:45:08 PST 2024


================
@@ -73,34 +67,27 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     // Convert the external representation of the values array.
     auto rtp = cast<RankedTensorType>(type);
     const SparseTensorType stt(rtp);
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
----------------
aartbik wrote:

Is this correct? Is it always an index type, or did we use the same type as the metadata?

https://github.com/llvm/llvm-project/pull/84079


More information about the Mlir-commits mailing list