[Mlir-commits] [mlir] [mlir][sparse] use a consistent order between [dis]assembleOp and storage layout. (PR #84079)
Peiming Liu
llvmlistbot at llvm.org
Tue Mar 5 14:38:48 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/84079
>From 550eb4547ea17e085917c92e4a3717294dbb6031 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 5 Mar 2024 22:19:49 +0000
Subject: [PATCH] [mlir][sparse] use a consistent order between [dis]assembleOp
and storage layout.
---
.../SparseTensor/IR/SparseTensorOps.td | 31 +++++----
.../Transforms/SparseAssembler.cpp | 27 ++------
.../Transforms/SparseGPUCodegen.cpp | 4 +-
.../Transforms/SparseTensorCodegen.cpp | 10 +--
.../Transforms/SparseTensorConversion.cpp | 25 +++----
.../SparseTensor/GPU/gpu_spgemm_lib.mlir | 2 +-
mlir/test/Dialect/SparseTensor/external.mlir | 66 +++++++++----------
mlir/test/Dialect/SparseTensor/invalid.mlir | 31 +++++----
mlir/test/Dialect/SparseTensor/pack_copy.mlir | 38 +++++------
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 25 +++----
.../Dialect/SparseTensor/sparse_pack.mlir | 17 ++---
.../SparseTensor/sparse_reinterpret_map.mlir | 15 +++--
.../Dialect/SparseTensor/CPU/sparse_pack.mlir | 31 +++++----
.../SparseTensor/CPU/sparse_pack_d.mlir | 13 ++--
14 files changed, 164 insertions(+), 171 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
}
def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
- Arguments<(ins TensorOf<[AnyType]>:$values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+ Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+ TensorOf<[AnyType]>:$values)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor assembled from the given values and levels";
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
}];
let assemblyFormat =
- "$values `,` $levels attr-dict"
- "`:` type($values) `,` type($levels) `to` type($result)";
+ "` ` `(` $levels `)` `,` $values attr-dict"
+ " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
- TensorOf<[AnyType]>:$out_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
- Results<(outs TensorOf<[AnyType]>:$ret_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- AnyIndexingScalarLike:$val_len,
- Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+ TensorOf<[AnyType]>:$out_values)>,
+ Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ TensorOf<[AnyType]>:$ret_values,
+ Variadic<AnyIndexingScalarLike>:$lvl_lens,
+ AnyIndexingScalarLike:$val_len)> {
let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
// |0.0, 0.0, 0.0, 0.0|
- %v, %p, %c, %v_len, %p_len, %c_len =
+ %p, %c, %v, %p_len, %c_len, %v_len =
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
- -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+ out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
+ out_vals(%od : tensor<3xf64>) ->
+ (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
// %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
// %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
let assemblyFormat =
"$tensor `:` type($tensor) "
- "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
- "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+ "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+ "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+ "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+ "`(` type($lvl_lens) `)` `,` type($val_len)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..b003b5adb26a65 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,20 +33,14 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
- auto vtp = RankedTensorType::get(shape, stt.getElementType());
- convTypes.push_back(vtp);
- if (extraTypes)
- extraTypes->push_back(vtp);
-
// Convert the external representation of the position/coordinate array.
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
convTypes.push_back(rtp);
@@ -73,26 +67,19 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Convert the external representation of the values array.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
- // Collect the external representation of the values array for
- // input or the outgoing sparse tensor for output.
- inputs.push_back(fromVals[idx++]);
- if (!isIn) {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType()); // nnz
- }
+ if (!isIn)
+ inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
@@ -100,7 +87,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
- cntTypes.push_back(rtp.getElementType());
+ cntTypes.push_back(builder.getIndexType());
}
}
return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
- rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
- ValueRange{rt, ct});
+ rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+ vt);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
sz = desc.getValMemSize(rewriter, loc);
src = desc.getValMemRef();
dst = genToMemref(rewriter, loc, op.getOutValues());
- // Values is the last field in descriptor, but it is the first
- // operand in unpack operation.
- // TODO: maybe change unpack/pack operation instead to be
- // consistent.
- retMem.insert(retMem.begin(), dst);
+
+ retMem.push_back(dst);
Type valLenTp = op.getValLen().getType();
- retLen.insert(retLen.begin(),
- genScalarToTensor(rewriter, loc, sz, valLenTp));
+ retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
SmallVector<Value> retLen;
- // Get the values buffer first.
- auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
- auto valLenTp = op.getValLen().getType();
- auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
- retVal.push_back(vals);
- retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
- auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
- auto buf =
- genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+ auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
rewriter.setInsertionPointAfter(forOp);
- auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
+ // Get the values buffer last.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
// Converts MemRefs back to Tensors.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
+
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
// -----
// CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
// -----
// CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]#1
// CHECK: return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
// -----
// CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME: %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*5]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
// -----
// CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*7]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 395b812a7685b7..eac97f702f58bd 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -13,8 +13,8 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> {
func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
-> tensor<?xf64, #SparseVector> {
// expected-error at +1 {{the sparse-tensor must have static shape}}
- %0 = sparse_tensor.assemble %values, %pos, %coordinates
- : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<?xf64, #SparseVector>
+ %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+ : (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64> to tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
@@ -25,8 +25,8 @@ func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coo
func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
-> tensor<100xf32, #SparseVector> {
// expected-error at +1 {{input/output element-types don't match}}
- %0 = sparse_tensor.assemble %values, %pos, %coordinates
- : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<100xf32, #SparseVector>
+ %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+ : (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64> to tensor<100xf32, #SparseVector>
return %0 : tensor<100xf32, #SparseVector>
}
@@ -37,8 +37,8 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>)
-> tensor<100x2xf64, #SparseVector> {
// expected-error at +1 {{input/output trailing COO level-ranks don't match}}
- %0 = sparse_tensor.assemble %values, %pos, %coordinates
- : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32> to tensor<100x2xf64, #SparseVector>
+ %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+ : (tensor<2xi32>, tensor<6x3xi32>), tensor<6xf64> to tensor<100x2xf64, #SparseVector>
return %0 : tensor<100x2xf64, #SparseVector>
}
@@ -49,8 +49,8 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tensor<6xi32>)
-> tensor<2x100xf64, #CSR> {
// expected-error at +1 {{inconsistent number of fields between input/output}}
- %0 = sparse_tensor.assemble %values, %coordinates
- : tensor<6xf64>, tensor<6xi32> to tensor<2x100xf64, #CSR>
+ %0 = sparse_tensor.assemble (%coordinates), %values
+ : (tensor<6xi32>), tensor<6xf64> to tensor<2x100xf64, #CSR>
return %0 : tensor<2x100xf64, #CSR>
}
@@ -61,8 +61,9 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
// expected-error at +1 {{input/output element-types don't match}}
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf32, #SparseVector>
- outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
- -> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index)
+ out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x1xi32>)
+ out_vals(%values : tensor<6xf64>)
+ -> (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
return
}
@@ -73,8 +74,9 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten
func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
// expected-error at +1 {{input/output trailing COO level-ranks don't match}}
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100x2xf64, #SparseVector>
- outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
- -> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index)
+ out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x3xi32> )
+ out_vals(%values : tensor<6xf64>)
+ -> (tensor<2xi32>, tensor<6x3xi32>), tensor<6xf64>, (index, index), index
return
}
@@ -85,8 +87,9 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t
func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
// expected-error at +1 {{inconsistent number of fields between input/output}}
%rv, %rc, %vl, %pl = sparse_tensor.disassemble %sp : tensor<2x100xf64, #CSR>
- outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
- -> tensor<6xf64>, (tensor<6xi32>), index, (index)
+ out_lvls(%coordinates : tensor<6xi32>)
+ out_vals(%values : tensor<6xf64>)
+ -> (tensor<6xi32>), tensor<6xf64>, (index), index
return
}
diff --git a/mlir/test/Dialect/SparseTensor/pack_copy.mlir b/mlir/test/Dialect/SparseTensor/pack_copy.mlir
index e60f9bb7149b32..ec8f0b531fb218 100644
--- a/mlir/test/Dialect/SparseTensor/pack_copy.mlir
+++ b/mlir/test/Dialect/SparseTensor/pack_copy.mlir
@@ -19,26 +19,25 @@
// This forces a copy for the values and positions.
//
// CHECK-LABEL: func.func @foo(
-// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>,
// CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
-// CHECK-SAME: %[[POS:.*]]: memref<11xi32>)
-// CHECK: %[[ALLOC1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xf64>
-// CHECK: memref.copy %[[VAL]], %[[ALLOC1]] : memref<3xf64> to memref<3xf64>
+// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
+// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
// CHECK: %[[ALLOC2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<11xi32>
// CHECK: memref.copy %[[POS]], %[[ALLOC2]] : memref<11xi32> to memref<11xi32>
+// CHECK: %[[ALLOC1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xf64>
+// CHECK: memref.copy %[[VAL]], %[[ALLOC1]] : memref<3xf64> to memref<3xf64>
// CHECK-NOT: memref.copy
// CHECK: return
//
-func.func @foo(%arg0: tensor<3xf64> {bufferization.writable = false},
- %arg1: tensor<3xi32> {bufferization.writable = false},
- %arg2: tensor<11xi32> {bufferization.writable = false}) -> (index) {
+func.func @foo(%arg1: tensor<3xi32> {bufferization.writable = false},
+ %arg2: tensor<11xi32> {bufferization.writable = false},
+ %arg0: tensor<3xf64> {bufferization.writable = false}) -> (index) {
//
// Pack the buffers into a sparse tensors.
//
- %pack = sparse_tensor.assemble %arg0, %arg2, %arg1
- : tensor<3xf64>,
- tensor<11xi32>,
- tensor<3xi32> to tensor<10x10xf64, #CSR>
+ %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
+ : (tensor<11xi32>, tensor<3xi32>),
+ tensor<3xf64> to tensor<10x10xf64, #CSR>
//
// Scale the sparse tensor "in-place" (this has no impact on the final
@@ -64,22 +63,21 @@ func.func @foo(%arg0: tensor<3xf64> {bufferization.writable = false},
// Pass in the buffers of the sparse tensor, marked writable.
//
// CHECK-LABEL: func.func @bar(
-// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>,
// CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
-// CHECK-SAME: %[[POS:.*]]: memref<11xi32>)
+// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
+// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
// CHECK-NOT: memref.copy
// CHECK: return
//
-func.func @bar(%arg0: tensor<3xf64> {bufferization.writable = true},
- %arg1: tensor<3xi32> {bufferization.writable = true},
- %arg2: tensor<11xi32> {bufferization.writable = true}) -> (index) {
+func.func @bar(%arg1: tensor<3xi32> {bufferization.writable = true},
+ %arg2: tensor<11xi32> {bufferization.writable = true},
+ %arg0: tensor<3xf64> {bufferization.writable = true}) -> (index) {
//
// Pack the buffers into a sparse tensors.
//
- %pack = sparse_tensor.assemble %arg0, %arg2, %arg1
- : tensor<3xf64>,
- tensor<11xi32>,
- tensor<3xi32> to tensor<10x10xf64, #CSR>
+ %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
+ : (tensor<11xi32>, tensor<3xi32>),
+ tensor<3xf64> to tensor<10x10xf64, #CSR>
//
// Scale the sparse tensor "in-place" (this has no impact on the final
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index f4a58df1d4d2d6..41094fbad9218f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -16,14 +16,14 @@ func.func @sparse_new(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
#SparseVector = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed), posWidth=32, crdWidth=32}>
// CHECK-LABEL: func @sparse_pack(
-// CHECK-SAME: %[[D:.*]]: tensor<6xf64>,
// CHECK-SAME: %[[P:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>)
-// CHECK: %[[R:.*]] = sparse_tensor.assemble %[[D]], %[[P]], %[[I]]
+// CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>,
+// CHECK-SAME: %[[D:.*]]: tensor<6xf64>)
+// CHECK: %[[R:.*]] = sparse_tensor.assemble (%[[P]], %[[I]]), %[[D]]
// CHECK: return %[[R]] : tensor<100xf64, #{{.*}}>
-func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor<6x1xi32>)
+func.func @sparse_pack(%pos: tensor<2xi32>, %index: tensor<6x1xi32>, %data: tensor<6xf64>)
-> tensor<100xf64, #SparseVector> {
- %0 = sparse_tensor.assemble %data, %pos, %index : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+ %0 = sparse_tensor.assemble (%pos, %index), %data: (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64>
to tensor<100xf64, #SparseVector>
return %0 : tensor<100xf64, #SparseVector>
}
@@ -36,17 +36,18 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
-// CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.disassemble %[[T]]
-// CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
+// CHECK: %[[P:.*]]:2, %[[D:.*]], %[[PL:.*]]:2, %[[DL:.*]] = sparse_tensor.disassemble %[[T]]
+// CHECK: return %[[P]]#0, %[[P]]#1, %[[D]]
func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
%od : tensor<6xf64>,
%op : tensor<2xindex>,
%oi : tensor<6x1xi32>)
- -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
- %rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
- outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
- -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index)
- return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+ -> (tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>) {
+ %rp, %ri, %rd, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
+ out_lvls(%op, %oi : tensor<2xindex>, tensor<6x1xi32>)
+ out_vals(%od : tensor<6xf64>)
+ -> (tensor<2xindex>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
+ return %rp, %ri, %rd : tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 7cb699092f8833..a90194a74ee4a3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -31,8 +31,8 @@
// CHECK: }
func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
-> tensor<100x100xf64, #COO> {
- %0 = sparse_tensor.assemble %values, %pos, %coordinates
- : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
+ %0 = sparse_tensor.assemble (%pos, %coordinates), %values
+ : (tensor<2xindex>, tensor<6x2xi32>), tensor<6xf64> to tensor<100x100xf64, #COO>
return %0 : tensor<100x100xf64, #COO>
}
@@ -60,9 +60,9 @@ func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinate
// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_17]][0] {{\[}}%[[VAL_16]]] [1] : memref<6xf64> to memref<?xf64>
// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_16]]] [1] : memref<?xf64> to memref<?xf64>
// CHECK: memref.copy %[[VAL_19]], %[[VAL_18]] : memref<?xf64> to memref<?xf64>
-// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
-// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
-// CHECK: %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
+// CHECK-DAG: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
+// CHECK-DAG: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
+// CHECK-DAG: %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
// CHECK: return %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
// CHECK: }
func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
@@ -70,8 +70,9 @@ func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
%op : tensor<2xindex>,
%oi : tensor<6x2xi32>)
-> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
- %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<100x100xf64, #COO>
- outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
- -> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index)
+ %rp, %ri, %rd, %pl, %il, %dl = sparse_tensor.disassemble %sp : tensor<100x100xf64, #COO>
+ out_lvls(%op, %oi : tensor<2xindex>, tensor<6x2xi32>)
+ out_vals(%od : tensor<6xf64>)
+ -> (tensor<2xindex>, tensor<6x2xi32>), tensor<6xf64>, (index, index), index
return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 54de1024323b5f..aa17261724dbc0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -99,13 +99,13 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.assemble (%[[VAL_1]], %[[VAL_2]]), %[[VAL_0]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<1x2x2x2xf64, #[[$demap]]>
// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
// CHECK: return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
// CHECK: }
func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xindex>, %crd:tensor<?xindex>) -> tensor<2x4xf64, #BSR> {
- %0 = sparse_tensor.assemble %val, %pos, %crd
- : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR>
+ %0 = sparse_tensor.assemble (%pos, %crd), %val
+ : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<2x4xf64, #BSR>
return %0 : tensor<2x4xf64, #BSR>
}
@@ -115,7 +115,7 @@ func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xi
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
-// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK: %{{.*}} = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]> out_lvls(%[[VAL_2]], %[[VAL_3]] : tensor<?xindex>, tensor<?xindex>) out_vals(%[[VAL_1]] : tensor<?xf64>)
// CHECK: return
// CHECK: }
func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
@@ -123,8 +123,9 @@ func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
%op : tensor<?xindex>,
%oi : tensor<?xindex>)
-> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
- %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
- outs(%od, %op, %oi : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>)
- -> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index)
+ %rp, %ri, %rd, %pl, %il, %dl = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
+ out_lvls(%op, %oi : tensor<?xindex>, tensor<?xindex>)
+ out_vals(%od : tensor<?xf64>)
+ -> (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>, (index, index), index
return %rd, %rp, %ri : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 2b9b73a1990e65..b792d00681ddb4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -87,9 +87,9 @@ module {
[ 7, 8]]
> : tensor<3x2xi32>
- %s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
+ %s4 = sparse_tensor.assemble (%pos, %index), %data : (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>
to tensor<10x10xf64, #SortedCOO>
- %s5 = sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+ %s5 = sparse_tensor.assemble (%pos32, %index32), %data : (tensor<2xi32>, tensor<3x2xi32>), tensor<3xf64>
to tensor<10x10xf64, #SortedCOOI32>
//
@@ -107,7 +107,7 @@ module {
%csr_index32 = arith.constant dense<
[1, 0, 1]
> : tensor<3xi32>
- %csr = sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+ %csr = sparse_tensor.assemble (%csr_pos32, %csr_index32), %csr_data : (tensor<3xi32>, tensor<3xi32>), tensor<3xf64>
to tensor<2x2xf64, #CSR>
//
@@ -131,8 +131,8 @@ module {
[ 10, 10]]
> : tensor<6x2xindex>
- %bs = sparse_tensor.assemble %bdata, %bpos, %bindex :
- tensor<5xf64>, tensor<4xindex>, tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
+ %bs = sparse_tensor.assemble (%bpos, %bindex), %bdata :
+ (tensor<4xindex>, tensor<6x2xindex>), tensor<5xf64> to tensor<2x10x10xf64, #BCOO>
//
// Verify results.
@@ -231,9 +231,10 @@ module {
%od = tensor.empty() : tensor<3xf64>
%op = tensor.empty() : tensor<2xi32>
%oi = tensor.empty() : tensor<3x2xi32>
- %d, %p, %i, %dl, %pl, %il = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
- -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
+ %p, %i, %d, %dl, %pl, %il = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
+ out_lvls(%op, %oi : tensor<2xi32>, tensor<3x2xi32>)
+ out_vals(%od : tensor<3xf64>)
+ -> (tensor<2xi32>, tensor<3x2xi32>), tensor<3xf64>, (i32, i64), index
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -246,9 +247,10 @@ module {
%d_csr = tensor.empty() : tensor<4xf64>
%p_csr = tensor.empty() : tensor<3xi32>
%i_csr = tensor.empty() : tensor<3xi32>
- %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
- outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
- -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
+ %rp_csr, %ri_csr, %rd_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
+ out_lvls(%p_csr, %i_csr : tensor<3xi32>, tensor<3xi32>)
+ out_vals(%d_csr : tensor<4xf64>)
+ -> (tensor<3xi32>, tensor<3xi32>), tensor<4xf64>, (i32, i64), index
// CHECK-NEXT: ( 1, 2, 3 )
%vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<3xf64>
@@ -257,9 +259,10 @@ module {
%bod = tensor.empty() : tensor<6xf64>
%bop = tensor.empty() : tensor<4xindex>
%boi = tensor.empty() : tensor<6x2xindex>
- %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.disassemble %bs : tensor<2x10x10xf64, #BCOO>
- outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
- -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
+ %bp, %bi, %bd, %lp, %li, %ld = sparse_tensor.disassemble %bs : tensor<2x10x10xf64, #BCOO>
+ out_lvls(%bop, %boi : tensor<4xindex>, tensor<6x2xindex>)
+ out_vals(%bod : tensor<6xf64>)
+ -> (tensor<4xindex>, tensor<6x2xindex>), tensor<6xf64>, (i32, tensor<i64>), index
// CHECK-NEXT: ( 1, 2, 3, 4, 5 )
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<5xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index da816c7fbb1172..8a65e2449c1574 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -71,11 +71,10 @@ module {
%crd02 = arith.constant dense<
[ 0, 1, 0, 1, 0, 0, 1, 0 ]> : tensor<8xi32>
- %s0 = sparse_tensor.assemble %data0, %pos00, %crd00, %pos01, %crd01, %pos02, %crd02 :
- tensor<8xf32>,
- tensor<2xi64>, tensor<3xi32>,
- tensor<4xi64>, tensor<5xi32>,
- tensor<6xi64>, tensor<8xi32> to tensor<4x3x2xf32, #CCC>
+ %s0 = sparse_tensor.assemble (%pos00, %crd00, %pos01, %crd01, %pos02, %crd02), %data0 :
+ (tensor<2xi64>, tensor<3xi32>,
+ tensor<4xi64>, tensor<5xi32>,
+ tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
//
// Setup BatchedCSR.
@@ -89,7 +88,7 @@ module {
%crd1 = arith.constant dense<
[ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
- %s1 = sparse_tensor.assemble %data1, %pos1, %crd1 : tensor<16xf32>, tensor<13xi64>, tensor<16xi32> to tensor<4x3x2xf32, #BatchedCSR>
+ %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
//
// Setup CSRDense.
@@ -103,7 +102,7 @@ module {
%crd2 = arith.constant dense<
[ 0, 1, 2, 0, 2, 0, 1, 2, 0, 1, 2 ]> : tensor<11xi32>
- %s2 = sparse_tensor.assemble %data2, %pos2, %crd2 : tensor<22xf32>, tensor<5xi64>, tensor<11xi32> to tensor<4x3x2xf32, #CSRDense>
+ %s2 = sparse_tensor.assemble (%pos2, %crd2), %data2 : (tensor<5xi64>, tensor<11xi32>), tensor<22xf32> to tensor<4x3x2xf32, #CSRDense>
//
// Verify.
More information about the Mlir-commits
mailing list