[Mlir-commits] [mlir] [mlir][sparse] use a consistent order between [dis]assembleOp and storage layout (PR #84079)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 5 14:20:55 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout.
---
Patch is 44.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84079.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+17-14)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+13-17)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+2-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-7)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+13-12)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/external.mlir (+33-33)
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+17-14)
- (modified) mlir/test/Dialect/SparseTensor/pack_copy.mlir (+18-20)
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+13-12)
- (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+9-8)
- (modified) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+8-7)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+17-14)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir (+6-7)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
}
def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
- Arguments<(ins TensorOf<[AnyType]>:$values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+ Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+ TensorOf<[AnyType]>:$values)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor assembled from the given values and levels";
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
}];
let assemblyFormat =
- "$values `,` $levels attr-dict"
- "`:` type($values) `,` type($levels) `to` type($result)";
+ "` ` `(` $levels `)` `,` $values attr-dict"
+ " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
- TensorOf<[AnyType]>:$out_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
- Results<(outs TensorOf<[AnyType]>:$ret_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- AnyIndexingScalarLike:$val_len,
- Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+ TensorOf<[AnyType]>:$out_values)>,
+ Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ TensorOf<[AnyType]>:$ret_values,
+ Variadic<AnyIndexingScalarLike>:$lvl_lens,
+ AnyIndexingScalarLike:$val_len)> {
let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
// |0.0, 0.0, 0.0, 0.0|
%v, %p, %c, %v_len, %p_len, %c_len =
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
- -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+ out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+ out_vals(%od) : tensor<3xf64> ->
+ tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
// %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
// %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
let assemblyFormat =
"$tensor `:` type($tensor) "
- "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
- "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+ "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+ "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+ "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+ "`(` type($lvl_lens) `)` `,` type($val_len)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..617ff7d39dcfbd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,12 +33,12 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
- auto vtp = RankedTensorType::get(shape, stt.getElementType());
- convTypes.push_back(vtp);
- if (extraTypes)
- extraTypes->push_back(vtp);
+ // auto shape = stt.getBatchLvlShape();
+ // shape.push_back(ShapedType::kDynamic);
+ // auto vtp = RankedTensorType::get(shape, stt.getElementType());
+ // convTypes.push_back(vtp);
+ // if (extraTypes)
+ // extraTypes->push_back(vtp);
// Convert the external representation of the position/coordinate array.
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
@@ -46,7 +46,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
convTypes.push_back(rtp);
@@ -78,21 +79,16 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
- // Collect the external representation of the values array for
- // input or the outgoing sparse tensor for output.
- inputs.push_back(fromVals[idx++]);
- if (!isIn) {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType()); // nnz
- }
+ if (!isIn)
+ inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
@@ -100,7 +96,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
- cntTypes.push_back(rtp.getElementType());
+ cntTypes.push_back(builder.getIndexType());
}
}
return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
- rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
- ValueRange{rt, ct});
+ rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+ vt);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
sz = desc.getValMemSize(rewriter, loc);
src = desc.getValMemRef();
dst = genToMemref(rewriter, loc, op.getOutValues());
- // Values is the last field in descriptor, but it is the first
- // operand in unpack operation.
- // TODO: maybe change unpack/pack operation instead to be
- // consistent.
- retMem.insert(retMem.begin(), dst);
+
+ retMem.push_back(dst);
Type valLenTp = op.getValLen().getType();
- retLen.insert(retLen.begin(),
- genScalarToTensor(rewriter, loc, sz, valLenTp));
+ retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
SmallVector<Value> retLen;
- // Get the values buffer first.
- auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
- auto valLenTp = op.getValLen().getType();
- auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
- retVal.push_back(vals);
- retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
- auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
- auto buf =
- genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+ auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
rewriter.setInsertionPointAfter(forOp);
- auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
+ // Get the values buffer last.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
// Converts MemRefs back to Tensors.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
+
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
// -----
// CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
// -----
// CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]#1
// CHECK: return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
// -----
// CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME: %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*5]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
// -----
// CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*7]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
diff --git a/...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/84079
More information about the Mlir-commits
mailing list