[Mlir-commits] [mlir] [mlir][sparse] use a consistent order between [dis]assembleOp and storage layout (PR #84079)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 5 14:20:54 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

Use a consistent order between [dis]assembleOp and storage layout.

---

Patch is 44.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84079.diff


14 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+17-14) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+13-17) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-7) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+13-12) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/external.mlir (+33-33) 
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+17-14) 
- (modified) mlir/test/Dialect/SparseTensor/pack_copy.mlir (+18-20) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+13-12) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+9-8) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+8-7) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+17-14) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir (+6-7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }
 
 def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
-    Arguments<(ins TensorOf<[AnyType]>:$values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+    Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                   TensorOf<[AnyType]>:$values)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor assembled from the given values and levels";
 
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
   }];
 
   let assemblyFormat =
-    "$values `,` $levels attr-dict"
-    "`:` type($values) `,` type($levels) `to` type($result)";
+    "` ` `(` $levels `)` `,` $values attr-dict"
+    " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
 
   let hasVerifier = 1;
 }
 
 def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   TensorOf<[AnyType]>:$out_values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
-    Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnyIndexingScalarLike:$val_len,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                   TensorOf<[AnyType]>:$out_values)>,
+    Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  TensorOf<[AnyType]>:$ret_values,
+                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                  AnyIndexingScalarLike:$val_len)> {
   let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
 
   let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
     //                  |0.0, 0.0, 0.0, 0.0|
-    %v, %p, %c, %v_len, %p_len, %c_len =
+    %p, %c, %v, %p_len, %c_len, %v_len =
         sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
-          outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-                            -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+          out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
+          out_vals(%od : tensor<3xf64>) ->
+          (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index
     // %v = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
     // %p = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
     // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
 
   let assemblyFormat =
     "$tensor `:` type($tensor) "
-    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
-    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+    "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+    "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+    "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+    "`(` type($lvl_lens) `)` `,` type($val_len)";
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..617ff7d39dcfbd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,12 +33,12 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     }
     // Convert the external representation of the values array.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
-    auto vtp = RankedTensorType::get(shape, stt.getElementType());
-    convTypes.push_back(vtp);
-    if (extraTypes)
-      extraTypes->push_back(vtp);
+    //    auto shape = stt.getBatchLvlShape();
+    //    shape.push_back(ShapedType::kDynamic);
+    //    auto vtp = RankedTensorType::get(shape, stt.getElementType());
+    //    convTypes.push_back(vtp);
+    //    if (extraTypes)
+    //      extraTypes->push_back(vtp);
 
     // Convert the external representation of the position/coordinate array.
     foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
@@ -46,7 +46,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                                                SparseTensorFieldKind kind,
                                                Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         ShapedType st = t.cast<ShapedType>();
         auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
         convTypes.push_back(rtp);
@@ -78,21 +79,16 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
@@ -100,7 +96,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
         }
       }
       return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
-  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
-                                          ValueRange{rt, ct});
+  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+                                          vt);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
         sz = desc.getValMemSize(rewriter, loc);
         src = desc.getValMemRef();
         dst = genToMemref(rewriter, loc, op.getOutValues());
-        // Values is the last field in descriptor, but it is the first
-        // operand in unpack operation.
-        // TODO: maybe change unpack/pack operation instead to be
-        // consistent.
-        retMem.insert(retMem.begin(), dst);
+
+        retMem.push_back(dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(),
-                      genScalarToTensor(rewriter, loc, sz, valLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
     SmallVector<Value> retLen;
-    // Get the values buffer first.
-    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
-    auto valLenTp = op.getValLen().getType();
-    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
-    retVal.push_back(vals);
-    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
     // Then get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //    for (i = 0; i < crdLen; i++)
       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
     assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
+
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
 // CHECK:           %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK:           %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK:           %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK:           return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK:         }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----
 
 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME:    %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*5]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*7]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
diff --git a/...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/84079


More information about the Mlir-commits mailing list