[Mlir-commits] [mlir] fc9f1d4 - [mlir][sparse] use a consistent order between [dis]assembleOp and storage layout. (#84079)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 6 09:57:45 PST 2024


Author: Peiming Liu
Date: 2024-03-06T09:57:41-08:00
New Revision: fc9f1d49aae4328ef36e1d9ba606e93703fde970

URL: https://github.com/llvm/llvm-project/commit/fc9f1d49aae4328ef36e1d9ba606e93703fde970
DIFF: https://github.com/llvm/llvm-project/commit/fc9f1d49aae4328ef36e1d9ba606e93703fde970.diff

LOG: [mlir][sparse] use a consistent order between [dis]assembleOp and storage layout. (#84079)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
    mlir/test/Dialect/SparseTensor/external.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/pack_copy.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }
 
 def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
-    Arguments<(ins TensorOf<[AnyType]>:$values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+    Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                   TensorOf<[AnyType]>:$values)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor assembled from the given values and levels";
 
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
   }];
 
   let assemblyFormat =
-    "$values `,` $levels attr-dict"
-    "`:` type($values) `,` type($levels) `to` type($result)";
+    "` ` `(` $levels `)` `,` $values attr-dict"
+    " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
 
   let hasVerifier = 1;
 }
 
 def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   TensorOf<[AnyType]>:$out_values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
-    Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnyIndexingScalarLike:$val_len,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                   TensorOf<[AnyType]>:$out_values)>,
+    Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  TensorOf<[AnyType]>:$ret_values,
+                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                  AnyIndexingScalarLike:$val_len)> {
   let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
 
   let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
     //                  |0.0, 0.0, 0.0, 0.0|
     %v, %p, %c, %v_len, %p_len, %c_len =
         sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
-          outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-                            -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+          out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
+          out_vals(%od : tensor<3xf64>) ->
+          (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index
     // %v = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
     // %p = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
     // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
 
   let assemblyFormat =
     "$tensor `:` type($tensor) "
-    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
-    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+    "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+    "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+    "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+    "`(` type($lvl_lens) `)` `,` type($val_len)";
 
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..a91d32a23cac9f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -31,22 +31,16 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
       convTypes.push_back(type);
       continue;
     }
-    // Convert the external representation of the values array.
+
+    // Convert the external representation of the position/coordinate/values arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
-    auto vtp = RankedTensorType::get(shape, stt.getElementType());
-    convTypes.push_back(vtp);
-    if (extraTypes)
-      extraTypes->push_back(vtp);
-
-    // Convert the external representation of the position/coordinate array.
     foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                                Type t, FieldIndex,
                                                SparseTensorFieldKind kind,
                                                Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         ShapedType st = t.cast<ShapedType>();
         auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
         convTypes.push_back(rtp);
@@ -70,29 +64,22 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       toVals.push_back(fromVals[idx++]);
       continue;
     }
-    // Convert the external representation of the values array.
+    // Handle sparse data.
     auto rtp = cast<RankedTensorType>(type);
     const SparseTensorType stt(rtp);
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
@@ -100,7 +87,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
         }
       }
       return true;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
-  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
-                                          ValueRange{rt, ct});
+  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+                                          vt);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
         sz = desc.getValMemSize(rewriter, loc);
         src = desc.getValMemRef();
         dst = genToMemref(rewriter, loc, op.getOutValues());
-        // Values is the last field in descriptor, but it is the first
-        // operand in unpack operation.
-        // TODO: maybe change unpack/pack operation instead to be
-        // consistent.
-        retMem.insert(retMem.begin(), dst);
+
+        retMem.push_back(dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(),
-                      genScalarToTensor(rewriter, loc, sz, valLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..010c3aa58b72c7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,13 +738,7 @@ class SparseTensorDisassembleConverter
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
     SmallVector<Value> retLen;
-    // Get the values buffer first.
-    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
-    auto valLenTp = op.getValLen().getType();
-    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
-    retVal.push_back(vals);
-    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
-    // Then get the positions and coordinates buffers.
+    // Get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;
     for (Level l = 0; l < lvlRank; l++) {
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //    for (i = 0; i < crdLen; i++)
       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
     assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
+
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);

diff  --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
 // CHECK:           %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK:           %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK:           %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK:           return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK:         }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,

diff  --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----
 
 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME:    %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*5]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*7]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 395b812a7685b7..eac97f702f58bd 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -13,8 +13,8 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> {
 func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
                             -> tensor<?xf64, #SparseVector> {
   // expected-error at +1 {{the sparse-tensor must have static shape}}
-  %0 = sparse_tensor.assemble %values, %pos, %coordinates
-     : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<?xf64, #SparseVector>
+  %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+     : (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64> to tensor<?xf64, #SparseVector>
   return %0 : tensor<?xf64, #SparseVector>
 }
 
@@ -25,8 +25,8 @@ func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coo
 func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
                             -> tensor<100xf32, #SparseVector> {
   // expected-error at +1 {{input/output element-types don't match}}
-  %0 = sparse_tensor.assemble %values, %pos, %coordinates
-     : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<100xf32, #SparseVector>
+  %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+     : (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64> to tensor<100xf32, #SparseVector>
   return %0 : tensor<100xf32, #SparseVector>
 }
 
@@ -37,8 +37,8 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
 func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>)
                             -> tensor<100x2xf64, #SparseVector> {
   // expected-error at +1 {{input/output trailing COO level-ranks don't match}}
-  %0 = sparse_tensor.assemble %values, %pos, %coordinates
-     : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32> to tensor<100x2xf64, #SparseVector>
+  %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+     : (tensor<2xi32>, tensor<6x3xi32>), tensor<6xf64> to tensor<100x2xf64, #SparseVector>
   return %0 : tensor<100x2xf64, #SparseVector>
 }
 
@@ -49,8 +49,8 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
 func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tensor<6xi32>)
                                      -> tensor<2x100xf64, #CSR> {
   // expected-error at +1 {{inconsistent number of fields between input/output}}
-  %0 = sparse_tensor.assemble %values, %coordinates
-     : tensor<6xf64>, tensor<6xi32> to tensor<2x100xf64, #CSR>
+  %0 = sparse_tensor.assemble (%coordinates), %values
+     : (tensor<6xi32>), tensor<6xf64> to tensor<2x100xf64, #CSR>
   return %0 : tensor<2x100xf64, #CSR>
 }
 
@@ -61,8 +61,9 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
 func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
   // expected-error at +1 {{input/output element-types don't match}}
   %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf32, #SparseVector>
-                  outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
-                  -> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index)
+                  out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x1xi32>)
+                  out_vals(%values : tensor<6xf64>)
+                  -> (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
   return
 }
 
@@ -73,8 +74,9 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten
 func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
   // expected-error at +1 {{input/output trailing COO level-ranks don't match}}
   %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100x2xf64, #SparseVector>
-                  outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
-                  -> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index)
+                  out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x3xi32> )
+                  out_vals(%values : tensor<6xf64>)
+                  -> (tensor<2xi32>, tensor<6x3xi32>), tensor<6xf64>, (index, index), index
   return
 }
 
@@ -85,8 +87,9 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t
 func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
   // expected-error at +1 {{inconsistent number of fields between input/output}}
   %rv, %rc, %vl, %pl = sparse_tensor.disassemble %sp : tensor<2x100xf64, #CSR>
-             outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
-             -> tensor<6xf64>, (tensor<6xi32>), index, (index)
+             out_lvls(%coordinates : tensor<6xi32>)
+             out_vals(%values : tensor<6xf64>)
+             -> (tensor<6xi32>), tensor<6xf64>, (index), index
   return
 }
 

diff  --git a/mlir/test/Dialect/SparseTensor/pack_copy.mlir b/mlir/test/Dialect/SparseTensor/pack_copy.mlir
index e60f9bb7149b32..ec8f0b531fb218 100644
--- a/mlir/test/Dialect/SparseTensor/pack_copy.mlir
+++ b/mlir/test/Dialect/SparseTensor/pack_copy.mlir
@@ -19,26 +19,25 @@
 // This forces a copy for the values and positions.
 //
 // CHECK-LABEL: func.func @foo(
-// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>,
 // CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
-// CHECK-SAME: %[[POS:.*]]: memref<11xi32>)
-// CHECK:      %[[ALLOC1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xf64>
-// CHECK:      memref.copy %[[VAL]], %[[ALLOC1]] : memref<3xf64> to memref<3xf64>
+// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
+// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
 // CHECK:      %[[ALLOC2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<11xi32>
 // CHECK:      memref.copy %[[POS]], %[[ALLOC2]] : memref<11xi32> to memref<11xi32>
+// CHECK:      %[[ALLOC1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xf64>
+// CHECK:      memref.copy %[[VAL]], %[[ALLOC1]] : memref<3xf64> to memref<3xf64>
 // CHECK-NOT:  memref.copy
 // CHECK:      return
 //
-func.func @foo(%arg0: tensor<3xf64>  {bufferization.writable = false},
-               %arg1: tensor<3xi32>  {bufferization.writable = false},
-               %arg2: tensor<11xi32> {bufferization.writable = false}) -> (index) {
+func.func @foo(%arg1: tensor<3xi32>  {bufferization.writable = false},
+               %arg2: tensor<11xi32> {bufferization.writable = false},
+               %arg0: tensor<3xf64>  {bufferization.writable = false}) -> (index) {
     //
     // Pack the buffers into a sparse tensors.
     //
-    %pack = sparse_tensor.assemble %arg0, %arg2, %arg1
-      : tensor<3xf64>,
-        tensor<11xi32>,
-        tensor<3xi32> to tensor<10x10xf64, #CSR>
+    %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
+      : (tensor<11xi32>, tensor<3xi32>),
+         tensor<3xf64> to tensor<10x10xf64, #CSR>
 
     //
     // Scale the sparse tensor "in-place" (this has no impact on the final
@@ -64,22 +63,21 @@ func.func @foo(%arg0: tensor<3xf64>  {bufferization.writable = false},
 // Pass in the buffers of the sparse tensor, marked writable.
 //
 // CHECK-LABEL: func.func @bar(
-// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>,
 // CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
-// CHECK-SAME: %[[POS:.*]]: memref<11xi32>)
+// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
+// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
 // CHECK-NOT:  memref.copy
 // CHECK:      return
 //
-func.func @bar(%arg0: tensor<3xf64>  {bufferization.writable = true},
-               %arg1: tensor<3xi32>  {bufferization.writable = true},
-               %arg2: tensor<11xi32> {bufferization.writable = true}) -> (index) {
+func.func @bar(%arg1: tensor<3xi32>  {bufferization.writable = true},
+               %arg2: tensor<11xi32> {bufferization.writable = true},
+               %arg0: tensor<3xf64>  {bufferization.writable = true}) -> (index) {
     //
     // Pack the buffers into a sparse tensors.
     //
-    %pack = sparse_tensor.assemble %arg0, %arg2, %arg1
-      : tensor<3xf64>,
-        tensor<11xi32>,
-        tensor<3xi32> to tensor<10x10xf64, #CSR>
+    %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
+      : (tensor<11xi32>, tensor<3xi32>),
+         tensor<3xf64> to tensor<10x10xf64, #CSR>
 
     //
     // Scale the sparse tensor "in-place" (this has no impact on the final

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index f4a58df1d4d2d6..41094fbad9218f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -16,14 +16,14 @@ func.func @sparse_new(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
 #SparseVector = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed), posWidth=32, crdWidth=32}>
 
 // CHECK-LABEL: func @sparse_pack(
-// CHECK-SAME: %[[D:.*]]: tensor<6xf64>,
 // CHECK-SAME: %[[P:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>)
-//       CHECK: %[[R:.*]] = sparse_tensor.assemble %[[D]], %[[P]], %[[I]]
+// CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>,
+// CHECK-SAME: %[[D:.*]]: tensor<6xf64>)
+//       CHECK: %[[R:.*]] = sparse_tensor.assemble (%[[P]], %[[I]]), %[[D]]
 //       CHECK: return %[[R]] : tensor<100xf64, #{{.*}}>
-func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor<6x1xi32>)
+func.func @sparse_pack(%pos: tensor<2xi32>, %index: tensor<6x1xi32>, %data: tensor<6xf64>)
                             -> tensor<100xf64, #SparseVector> {
-  %0 = sparse_tensor.assemble %data, %pos, %index : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+  %0 = sparse_tensor.assemble (%pos, %index), %data: (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64>
                                              to tensor<100xf64, #SparseVector>
   return %0 : tensor<100xf64, #SparseVector>
 }
@@ -36,17 +36,18 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
 //  CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
 //  CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
 //  CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
-//       CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.disassemble %[[T]]
-//       CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
+//       CHECK: %[[P:.*]]:2, %[[D:.*]], %[[PL:.*]]:2, %[[DL:.*]] = sparse_tensor.disassemble %[[T]]
+//       CHECK: return %[[P]]#0, %[[P]]#1, %[[D]]
 func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
                          %od : tensor<6xf64>,
                          %op : tensor<2xindex>,
                          %oi : tensor<6x1xi32>)
-                       -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
-  %rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
-                  outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
-                  -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index)
-  return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+                       -> (tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>) {
+  %rp, %ri, %rd, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
+                  out_lvls(%op, %oi : tensor<2xindex>, tensor<6x1xi32>)
+                  out_vals(%od : tensor<6xf64>)
+                  -> (tensor<2xindex>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
+  return %rp, %ri, %rd : tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 7cb699092f8833..a90194a74ee4a3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -31,8 +31,8 @@
 // CHECK:         }
 func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
                     -> tensor<100x100xf64, #COO> {
-  %0 = sparse_tensor.assemble %values, %pos, %coordinates
-     : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
+  %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+     : (tensor<2xindex>, tensor<6x2xi32>), tensor<6xf64> to tensor<100x100xf64, #COO>
   return %0 : tensor<100x100xf64, #COO>
 }
 
@@ -60,9 +60,9 @@ func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinate
 // CHECK:           %[[VAL_18:.*]] = memref.subview %[[VAL_17]][0] {{\[}}%[[VAL_16]]] [1] : memref<6xf64> to memref<?xf64>
 // CHECK:           %[[VAL_19:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_16]]] [1] : memref<?xf64> to memref<?xf64>
 // CHECK:           memref.copy %[[VAL_19]], %[[VAL_18]] : memref<?xf64> to memref<?xf64>
-// CHECK:           %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
-// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
-// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
+// CHECK-DAG:       %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
+// CHECK-DAG:       %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
+// CHECK-DAG:       %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
 // CHECK:           return %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
 // CHECK:         }
 func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
@@ -70,8 +70,9 @@ func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
                          %op : tensor<2xindex>,
                          %oi : tensor<6x2xi32>)
                        -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
-  %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<100x100xf64, #COO>
-                                 outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
-                                 -> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index)
+  %rp, %ri, %rd, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<100x100xf64, #COO>
+                                 out_lvls(%op, %oi : tensor<2xindex>, tensor<6x2xi32>)
+                                 out_vals(%od : tensor<6xf64>)
+                                 -> (tensor<2xindex>, tensor<6x2xi32>), tensor<6xf64>, (index, index), index
   return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 54de1024323b5f..aa17261724dbc0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -99,13 +99,13 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
 // CHECK-SAME:        %[[VAL_0:.*]]: tensor<?xf64>,
 // CHECK-SAME:        %[[VAL_1:.*]]: tensor<?xindex>,
 // CHECK-SAME:        %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
-// CHECK:           %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.assemble {{.*}} to tensor<1x2x2x2xf64, #[[$demap]]>
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
 // CHECK:           return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
 // CHECK:         }
 func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xindex>, %crd:tensor<?xindex>) -> tensor<2x4xf64, #BSR> {
-  %0 = sparse_tensor.assemble %val, %pos, %crd
-     : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR>
+  %0 = sparse_tensor.assemble (%pos, %crd), %val
+     : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<2x4xf64, #BSR>
   return %0 : tensor<2x4xf64, #BSR>
 }
 
@@ -115,7 +115,7 @@ func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xi
 // CHECK-SAME:         %[[VAL_2:.*]]: tensor<?xindex>,
 // CHECK-SAME:         %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
-// CHECK:           %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %{{.*}} = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
 // CHECK:           return
 // CHECK:         }
 func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
@@ -123,8 +123,9 @@ func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
                                               %op : tensor<?xindex>,
                                               %oi : tensor<?xindex>)
                                             -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
-  %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
-                                 outs(%od, %op, %oi : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>)
-                                 -> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index)
+  %rp, %ri, %rd, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
+                                 out_lvls(%op, %oi : tensor<?xindex>, tensor<?xindex>)
+                                 out_vals(%od : tensor<?xf64>)
+                                 -> (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>, (index, index), index
   return %rd, %rp, %ri : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 2b9b73a1990e65..b792d00681ddb4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -87,9 +87,9 @@ module {
         [  7,  8]]
     > : tensor<3x2xi32>
 
-    %s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
+    %s4 = sparse_tensor.assemble (%pos, %index), %data : (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>
                                           to tensor<10x10xf64, #SortedCOO>
-    %s5 = sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+    %s5 = sparse_tensor.assemble (%pos32, %index32), %data : (tensor<2xi32>, tensor<3x2xi32>), tensor<3xf64>
                                           to tensor<10x10xf64, #SortedCOOI32>
 
     //
@@ -107,7 +107,7 @@ module {
     %csr_index32 = arith.constant dense<
        [1, 0, 1]
     > : tensor<3xi32>
-    %csr = sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+    %csr = sparse_tensor.assemble (%csr_pos32, %csr_index32), %csr_data : (tensor<3xi32>, tensor<3xi32>), tensor<3xf64>
                                            to tensor<2x2xf64, #CSR>
 
     //
@@ -131,8 +131,8 @@ module {
        [ 10, 10]]
     > : tensor<6x2xindex>
 
-    %bs = sparse_tensor.assemble %bdata, %bpos, %bindex :
-          tensor<5xf64>, tensor<4xindex>,  tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
+    %bs = sparse_tensor.assemble (%bpos, %bindex), %bdata :
+          (tensor<4xindex>,  tensor<6x2xindex>), tensor<5xf64> to tensor<2x10x10xf64, #BCOO>
 
     //
     // Verify results.
@@ -231,9 +231,10 @@ module {
     %od = tensor.empty() : tensor<3xf64>
     %op = tensor.empty() : tensor<2xi32>
     %oi = tensor.empty() : tensor<3x2xi32>
-    %d, %p, %i, %dl, %pl, %il = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
-                 outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
+    %p, %i, %d, %dl, %pl, %il = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
+                 out_lvls(%op, %oi : tensor<2xi32>, tensor<3x2xi32>)
+                 out_vals(%od : tensor<3xf64>)
+                 -> (tensor<2xi32>, tensor<3x2xi32>), tensor<3xf64>, (i32, i64), index
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -246,9 +247,10 @@ module {
     %d_csr = tensor.empty() : tensor<4xf64>
     %p_csr = tensor.empty() : tensor<3xi32>
     %i_csr = tensor.empty() : tensor<3xi32>
-    %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
-                 outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
+    %rp_csr, %ri_csr, %rd_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
+                 out_lvls(%p_csr, %i_csr : tensor<3xi32>, tensor<3xi32>)
+                 out_vals(%d_csr : tensor<4xf64>)
+                 -> (tensor<3xi32>, tensor<3xi32>), tensor<4xf64>, (i32, i64), index
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<3xf64>
@@ -257,9 +259,10 @@ module {
     %bod = tensor.empty() : tensor<6xf64>
     %bop = tensor.empty() : tensor<4xindex>
     %boi = tensor.empty() : tensor<6x2xindex>
-    %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.disassemble %bs : tensor<2x10x10xf64, #BCOO>
-                    outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
+    %bp, %bi, %bd, %lp, %li, %ld = sparse_tensor.disassemble %bs : tensor<2x10x10xf64, #BCOO>
+                    out_lvls(%bop, %boi : tensor<4xindex>, tensor<6x2xindex>)
+                    out_vals(%bod : tensor<6xf64>)
+                    -> (tensor<4xindex>, tensor<6x2xindex>), tensor<6xf64>, (i32, tensor<i64>), index
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5 )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<5xf64>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index da816c7fbb1172..8a65e2449c1574 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -71,11 +71,10 @@ module {
     %crd02 = arith.constant dense<
        [ 0, 1, 0, 1, 0, 0, 1, 0  ]> : tensor<8xi32>
 
-    %s0 = sparse_tensor.assemble %data0, %pos00, %crd00, %pos01, %crd01, %pos02, %crd02 :
-       tensor<8xf32>,
-       tensor<2xi64>, tensor<3xi32>,
-       tensor<4xi64>, tensor<5xi32>,
-       tensor<6xi64>, tensor<8xi32> to tensor<4x3x2xf32, #CCC>
+    %s0 = sparse_tensor.assemble (%pos00, %crd00, %pos01, %crd01, %pos02, %crd02), %data0 :
+       (tensor<2xi64>, tensor<3xi32>,
+        tensor<4xi64>, tensor<5xi32>,
+        tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
     // Setup BatchedCSR.
@@ -89,7 +88,7 @@ module {
     %crd1 = arith.constant dense<
        [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble %data1, %pos1, %crd1 : tensor<16xf32>, tensor<13xi64>, tensor<16xi32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
 
     //
     // Setup CSRDense.
@@ -103,7 +102,7 @@ module {
     %crd2 = arith.constant dense<
       [ 0, 1, 2, 0, 2, 0, 1, 2, 0, 1, 2 ]> : tensor<11xi32>
 
-    %s2 = sparse_tensor.assemble %data2, %pos2, %crd2 : tensor<22xf32>, tensor<5xi64>, tensor<11xi32> to tensor<4x3x2xf32, #CSRDense>
+    %s2 = sparse_tensor.assemble (%pos2, %crd2), %data2  : (tensor<5xi64>, tensor<11xi32>), tensor<22xf32> to tensor<4x3x2xf32, #CSRDense>
 
     //
     // Verify.


        


More information about the Mlir-commits mailing list