[Mlir-commits] [mlir] Peiming clean (PR #68057)
Peiming Liu
llvmlistbot at llvm.org
Mon Oct 2 17:14:32 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/68057
None
From 23c44d18db8be75f0b176d13a8209605946f5481 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 3 Oct 2023 00:06:46 +0000
Subject: [PATCH 1/2] [mlir][sparse] unify lib/codegen rewriting rules for
sparse tensor concatenate operations.
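
For context, the rules being unified rewrite the sparse_tensor.concatenate
operation. A minimal use of the op, adapted from the tests below (the
function and argument names are illustrative only):

  #DCSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>

  // Concatenate two DCSR matrices along dimension 0 (2x4 on top of 3x4
  // yields 5x4).
  func.func @example(%a: tensor<2x4xf64, #DCSR>,
                     %b: tensor<3x4xf64, #DCSR>) -> tensor<5x4xf64, #DCSR> {
    %0 = sparse_tensor.concatenate %a, %b {dimension = 0 : index}
        : tensor<2x4xf64, #DCSR>, tensor<3x4xf64, #DCSR> to tensor<5x4xf64, #DCSR>
    return %0 : tensor<5x4xf64, #DCSR>
  }

After this change the op is rewritten by the same ConcatenateRewriter pattern
regardless of enable-runtime-library, instead of going through the
runtime-library-only SparseTensorConcatConverter path.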
---
.../ExecutionEngine/SparseTensor/Storage.h | 8 +-
.../Transforms/SparseTensorConversion.cpp | 21 +-
.../Transforms/SparseTensorRewriting.cpp | 6 +-
.../Dialect/SparseTensor/sparse_concat.mlir | 864 +++++++++---------
.../SparseTensor/sparse_concat_codegen.mlir | 427 ---------
5 files changed, 437 insertions(+), 889 deletions(-)
delete mode 100644 mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 68dcab6e64c7e45..3c6d6c5bf5c998e 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -783,8 +783,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
for (uint64_t l = 0; l < lvlRank; ++l) {
const auto crd = lvlCoords[l];
const auto cur = lvlCursor[l];
- if (crd > cur || (crd == cur && !isUniqueLvl(l)))
+ if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
+ (crd < cur && !isOrderedLvl(l))) {
return l;
+ }
+
if (crd < cur) {
assert(false && "non-lexicographic insertion");
return -1u;
@@ -900,8 +903,7 @@ class SparseTensorEnumeratorBase {
//===----------------------------------------------------------------------===//
template <typename P, typename C, typename V>
-class SparseTensorEnumerator final
- : public SparseTensorEnumeratorBase<V> {
+class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
using Base = SparseTensorEnumeratorBase<V>;
using StorageImpl = SparseTensorStorage<P, C, V>;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 987706f2f127ab8..fb36cfbbb5adb3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1434,17 +1434,16 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
void mlir::populateSparseTensorConversionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
const SparseTensorConversionOptions &options) {
- patterns
- .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
- SparseCastConverter, SparseTensorNewConverter,
- SparseTensorConcatConverter, SparseTensorAllocConverter,
- SparseTensorEmptyConverter, SparseTensorDeallocConverter,
- SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
- SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
- SparseTensorLoadConverter, SparseTensorInsertConverter,
- SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorOutConverter, SparseTensorAssembleConverter>(
- typeConverter, patterns.getContext());
+ patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+ SparseCastConverter, SparseTensorNewConverter,
+ SparseTensorAllocConverter, SparseTensorEmptyConverter,
+ SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
+ SparseTensorToCoordinatesConverter,
+ SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
+ SparseTensorLoadConverter, SparseTensorInsertConverter,
+ SparseTensorExpandConverter, SparseTensorCompressConverter,
+ SparseTensorOutConverter, SparseTensorAssembleConverter>(
+ typeConverter, patterns.getContext());
patterns.add<SparseTensorConvertConverter>(typeConverter,
patterns.getContext(), options);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 55db34f7050d3f3..3b5b224434bead1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1474,17 +1474,17 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
- patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+ patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
TensorReshapeRewriter>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
+
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) {
- patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter>(
- patterns.getContext());
+ patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
if (enableConvert)
patterns.add<ConvertRewriter>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index a24315755944320..2fb4529e5695e58 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -1,456 +1,430 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
+// RUN: | FileCheck %s
-#SparseMatrix = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
-#SparseMatrix_P = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : compressed, d0 : compressed)}>
-
-#SparseMatrix_D_P = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : dense, d0 : dense)}>
-
-// CHECK-LABEL: func.func @concat_mix_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = memref.alloc() : memref<5x4xf64>
-// CHECK: linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_0]] : memref<5x4xf64>)
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_12:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<2x4xf64>
-// CHECK: %[[TMP_13:.*]] = arith.cmpf une, %[[TMP_12]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_13]] {
-// CHECK: memref.store %[[TMP_12]], %[[TMP_0]][%[[TMP_arg2]], %[[TMP_arg3]]] : memref<5x4xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_9:.*]] = memref.cast %[[TMP_8]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_10:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_12:.*]] = func.call @getNextF64(%[[TMP_7]], %[[TMP_9]], %[[TMP_10]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_12]])
-// CHECK: } do {
-// CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_13:.*]] = arith.addi %[[TMP_12]], %[[TMP_c2]] : index
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_15:.*]] = memref.load %[[TMP_10]][] : memref<f64>
-// CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_13]], %[[TMP_14]]] : memref<5x4xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<5x4xf64>
-// CHECK: return %[[TMP_11]] : tensor<5x4xf64>
-// CHECK: }
-func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #SparseMatrix>) -> tensor<5x4xf64> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 0 : index}
- : tensor<2x4xf64>, tensor<3x4xf64, #SparseMatrix> to tensor<5x4xf64>
- return %0 : tensor<5x4xf64>
-}
-
-// CHECK-LABEL: func.func @concat_mix_sparse(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c2_i32:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c4_i32:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_0:.*]] = memref.cast %[[LvlSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP_0:.*]] = memref.cast %[[Iota_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[IotaP_0]], %[[IotaP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK: %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
-// CHECK: memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_arg3]], %[[TMP_9]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_22:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<2x4xf64>
-// CHECK: %[[TMP_23:.*]] = arith.cmpf une, %[[TMP_22]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_23]] {
-// CHECK: memref.store %[[TMP_22]], %[[TMP_8]][] : memref<f64>
-// CHECK: %[[TMP_24:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_8]], %[[TMP_10]], %[[IotaP_0]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_1:.*]] = memref.cast %[[LvlSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP_1:.*]] = memref.cast %[[Iota_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_17:.*]] = call @newSparseTensor(%[[DimSizesP_1]], %[[LvlSizesP_1]], %[[LvlTypesP_1]], %[[IotaP_1]], %[[IotaP_1]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_18:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_19:.*]] = memref.cast %[[TMP_18]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_20:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_22:.*]] = func.call @getNextF64(%[[TMP_17]], %[[TMP_19]], %[[TMP_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_22]])
-// CHECK: } do {
-// CHECK: %[[TMP_22:.*]] = memref.load %[[TMP_18]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_23:.*]] = arith.addi %[[TMP_22]], %[[TMP_c2]] : index
-// CHECK: %[[TMP_24:.*]] = memref.load %[[TMP_18]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_23]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_24]], %[[TMP_9]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[IotaP_0]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[IotaP_0]], %[[IotaP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]])
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
-// CHECK: }
-func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #SparseMatrix>) -> tensor<5x4xf64, #SparseMatrix> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 0 : index}
- : tensor<2x4xf64>, tensor<3x4xf64, #SparseMatrix> to tensor<5x4xf64, #SparseMatrix>
- return %0 : tensor<5x4xf64, #SparseMatrix>
+#DCSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
+#DENSE = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : dense)}>
+#DENSE_P = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : dense, d0 : dense)}>
+// CHECK-LABEL: @concat_sparse_sparse(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_1]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_4]]
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_2]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_5]]
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_3]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_6]]
+// CHECK: }
+// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
+// CHECK: return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
+func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<9x4xf64, #DCSR> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<9x4xf64, #DCSR>
+ return %0 : tensor<9x4xf64, #DCSR>
}
-// CHECK-LABEL: func.func @concat_mix_sparse_perm_dim1(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x2xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c2_i32:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c4_i32:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_0:.*]] = memref.cast %[[LvlSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Lvl2Dim_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Lvl2DimP_0:.*]] = memref.cast %[[Lvl2Dim_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Dim2Lvl_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Dim2LvlP_0:.*]] = memref.cast %[[Dim2Lvl_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Dim2Lvl_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK: %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_arg3]], %[[TMP_9]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_22:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<4x2xf64>
-// CHECK: %[[TMP_23:.*]] = arith.cmpf une, %[[TMP_22]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_23]] {
-// CHECK: memref.store %[[TMP_22]], %[[TMP_8]][] : memref<f64>
-// CHECK: %[[TMP_24:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_8]], %[[TMP_10]], %[[Dim2LvlP_0]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_1:.*]] = memref.cast %[[LvlSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP_1:.*]] = memref.cast %[[Iota_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_17:.*]] = call @newSparseTensor(%[[DimSizesP_1]], %[[LvlSizesP_1]], %[[LvlTypesP_1]], %[[IotaP_1]], %[[IotaP_1]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_18:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_19:.*]] = memref.cast %[[TMP_18]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_20:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_22:.*]] = func.call @getNextF64(%[[TMP_17]], %[[TMP_19]], %[[TMP_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_22]])
-// CHECK: } do {
-// CHECK: %[[TMP_22:.*]] = memref.load %[[TMP_18]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_18]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_22]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_24]], %[[TMP_9]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[Dim2LvlP_0]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]])
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
-// CHECK: }
-func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3xf64, #SparseMatrix_P>) -> tensor<4x5xf64, #SparseMatrix_P> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
- : tensor<4x2xf64>, tensor<4x3xf64, #SparseMatrix_P> to tensor<4x5xf64, #SparseMatrix_P>
- return %0 : tensor<4x5xf64, #SparseMatrix_P>
+// CHECK-LABEL: @concat_sparse_sparse_dynamic(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_1]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_4]]
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_2]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_5]]
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_3]]
+// CHECK: }
+// CHECK: scf.yield %[[RET_6]]
+// CHECK: }
+// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
+// CHECK: return %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
+func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<?x?xf64, #DCSR> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DCSR>
+ return %0 : tensor<?x?xf64, #DCSR>
}
-// CHECK-LABEL: func.func @concat_mix_dense_perm_dim1(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x2xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = memref.alloc() : memref<4x5xf64>
-// CHECK: linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_0]] : memref<4x5xf64>)
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_12:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<4x2xf64>
-// CHECK: %[[TMP_13:.*]] = arith.cmpf une, %[[TMP_12]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_13]] {
-// CHECK: memref.store %[[TMP_12]], %[[TMP_0]][%[[TMP_arg2]], %[[TMP_arg3]]] : memref<4x5xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_9:.*]] = memref.cast %[[TMP_8]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_10:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_12:.*]] = func.call @getNextF64(%[[TMP_7]], %[[TMP_9]], %[[TMP_10]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_12]])
-// CHECK: } do {
-// CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_14:.*]] = arith.addi %[[TMP_13]], %[[TMP_c2]] : index
-// CHECK: %[[TMP_15:.*]] = memref.load %[[TMP_10]][] : memref<f64>
-// CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_12]], %[[TMP_14]]] : memref<4x5xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<4x5xf64>
-// CHECK: return %[[TMP_11]] : tensor<4x5xf64>
-// CHECK: }
-func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3xf64, #SparseMatrix_P>) -> tensor<4x5xf64> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
- : tensor<4x2xf64>, tensor<4x3xf64, #SparseMatrix_P> to tensor<4x5xf64>
- return %0 : tensor<4x5xf64>
+// CHECK-LABEL: @concat_sparse_sparse_dense(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
+// CHECK: linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
+// CHECK: return %[[R]] : tensor<?x?xf64>
+func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<?x?xf64> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<?x?xf64>
+ return %0 : tensor<?x?xf64>
}
-// CHECK-LABEL: func.func @concat_mix_dense_perm_dim1_dyn(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<3x2xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK: %[[TMP_0:.*]] = memref.alloc() : memref<3x5xf64>
-// CHECK: %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<3x5xf64> to memref<?x?xf64>
-// CHECK: linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_0]] : memref<3x5xf64>)
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c3]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_13:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<3x2xf64>
-// CHECK: %[[TMP_14:.*]] = arith.cmpf une, %[[TMP_13]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_14]] {
-// CHECK: memref.store %[[TMP_13]], %[[TMP_0]][%[[TMP_arg2]], %[[TMP_arg3]]] : memref<3x5xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_8:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_11:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_13:.*]] = func.call @getNextF64(%[[TMP_8]], %[[TMP_10]], %[[TMP_11]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_13]])
-// CHECK: } do {
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_9]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_15:.*]] = arith.addi %[[TMP_14]], %[[TMP_c2]] : index
-// CHECK: %[[TMP_16:.*]] = memref.load %[[TMP_11]][] : memref<f64>
-// CHECK: memref.store %[[TMP_16]], %[[TMP_0]][%[[TMP_13]], %[[TMP_15]]] : memref<3x5xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64>
-// CHECK: return %[[TMP_12]] : tensor<?x?xf64>
-// CHECK: }
-func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<3x3xf64, #SparseMatrix>) -> tensor<?x?xf64> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
- : tensor<3x2xf64>, tensor<3x3xf64, #SparseMatrix> to tensor<?x?xf64>
- return %0 : tensor<?x?xf64>
+// CHECK-LABEL: @concat_sparse_sparse_annotated_dense(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
+// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @concat_sparse_sparse_annotated_dense(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<?x?xf64, #DENSE> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE>
+ return %0 : tensor<?x?xf64, #DENSE>
}
-// CHECK-LABEL: func.func @concat_annotated_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x2xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c4_i8:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_0:.*]] = memref.cast %[[LvlSizes_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Lvl2Dim_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Lvl2DimP_0:.*]] = memref.cast %[[Lvl2Dim_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Dim2Lvl_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Dim2LvlP_0:.*]] = memref.cast %[[Dim2Lvl_0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Dim2Lvl_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[NullPtr]])
-// CHECK: %[[Values_r:.*]] = call @sparseValuesF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> memref<?xf64>
-// CHECK: %[[Values:.*]] = memref.reshape %[[Values_r]]
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_22:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<4x2xf64>
-// CHECK: %[[TMP_23:.*]] = arith.cmpf une, %[[TMP_22]], %[[TMP_cst]] : f64
-// CHECK: scf.if %[[TMP_23]] {
-// CHECK: memref.store %[[TMP_22]], %[[Values]][%[[TMP_arg3]], %[[TMP_arg2]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizesP_1:.*]] = memref.cast %[[LvlSizes_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[Iota_1:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[IotaP_1:.*]] = memref.cast %[[Iota_1]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota_1]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_1]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_17:.*]] = call @newSparseTensor(%[[DimSizesP_1]], %[[LvlSizesP_1]], %[[LvlTypesP_1]], %[[IotaP_1]], %[[IotaP_1]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
-// CHECK: %[[TMP_18:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_19:.*]] = memref.cast %[[TMP_18]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[TMP_20:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.while : () -> () {
-// CHECK: %[[TMP_22:.*]] = func.call @getNextF64(%[[TMP_17]], %[[TMP_19]], %[[TMP_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[TMP_22]])
-// CHECK: } do {
-// CHECK: %[[TMP_22:.*]] = memref.load %[[TMP_18]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_18]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_20]][] : memref<f64>
-// CHECK: memref.store %[[TMP_25]], %[[Values]][%[[TMP_24]], %[[TMP_22]]] : memref<?x?xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[TMP_7]] : !llvm.ptr<i8>
-// CHECK: }
-func.func @concat_annotated_dense(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3xf64, #SparseMatrix_P>) -> tensor<4x5xf64, #SparseMatrix_D_P> {
- %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
- : tensor<4x2xf64>, tensor<4x3xf64, #SparseMatrix_P> to tensor<4x5xf64, #SparseMatrix_D_P>
- return %0 : tensor<4x5xf64, #SparseMatrix_D_P>
+// CHECK-LABEL: @concat_sparse_sparse_annotated_dense_permute(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_23]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
+// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @concat_sparse_sparse_annotated_dense_permute(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<?x?xf64, #DENSE_P> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE_P>
+ return %0 : tensor<?x?xf64, #DENSE_P>
}
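For readers comparing the two annotated all-dense tests above: the permuted variant stores the level sizes as (4, 9) instead of (9, 4), because #DENSE_P maps (d0, d1) -> (d1 : dense, d0 : dense), so the level order reverses the dimension order and every store swaps its coordinates. A minimal standalone sketch of the values-buffer viewing step both lowerings rely on (function name hypothetical):

  func.func @reshape_values_sketch(%vals: memref<?xf64>,
                                   %lvl_sizes: memref<2xindex>) -> memref<?x?xf64> {
    // View the flat values buffer of an annotated all-dense sparse tensor
    // as a rank-2 memref with the given level sizes.
    %v = memref.reshape %vals(%lvl_sizes)
       : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
    return %v : memref<?x?xf64>
  }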
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
deleted file mode 100644
index 6f9c45842ffb2ee..000000000000000
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ /dev/null
@@ -1,427 +0,0 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
-// RUN: | FileCheck %s
-
-#DCSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
-#DENSE = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : dense)}>
-#DENSE_P = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : dense, d0 : dense)}>
-// CHECK-LABEL: @concat_sparse_sparse(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<9x4xf64, #sparse_tensor
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_1]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_4]]
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_2]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_5]]
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_3]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_6]]
-// CHECK: }
-// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-// CHECK: return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
-func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<9x4xf64, #DCSR> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<9x4xf64, #DCSR>
- return %0 : tensor<9x4xf64, #DCSR>
-}
-
-// CHECK-LABEL: @concat_sparse_sparse_dynamic(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_1]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_4]]
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_2]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_5]]
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: scf.yield %[[NEW_3]]
-// CHECK: }
-// CHECK: scf.yield %[[RET_6]]
-// CHECK: }
-// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-// CHECK: return %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
-func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64, #DCSR> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DCSR>
- return %0 : tensor<?x?xf64, #DCSR>
-}
-
-// CHECK-LABEL: @concat_sparse_sparse_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
-// CHECK: linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
-// CHECK: return %[[R]] : tensor<?x?xf64>
-func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64>
- return %0 : tensor<?x?xf64>
-}
-
-// CHECK-LABEL: @concat_sparse_sparse_annotated_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
-// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
-func.func @concat_sparse_sparse_annotated_dense(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64, #DENSE> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE>
- return %0 : tensor<?x?xf64, #DENSE>
-}
-
-// CHECK-LABEL: @concat_sparse_sparse_annotated_dense_permute(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_23]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
-// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
-func.func @concat_sparse_sparse_annotated_dense_permute(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64, #DENSE_P> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE_P>
- return %0 : tensor<?x?xf64, #DENSE_P>
-}
From 1f2ed031c25057fb539638913b654ee31d7adf7f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 3 Oct 2023 00:11:58 +0000
Subject: [PATCH 2/2] remove dead code after cleanup
---
.../Transforms/SparseTensorConversion.cpp | 295 ------------------
1 file changed, 295 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index fb36cfbbb5adb3f..a3361c2cd48c6dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -162,38 +162,6 @@ static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
return out;
}
-/// Populates the given sizes array for concatenation from type (for static
-/// sizes) and from an already-converted opaque pointer source (for dynamic
-/// sizes).
-static void concatDimSizesFromInputs(OpBuilder &builder, Location loc,
- SparseTensorType dstTp, ValueRange srcs,
- Dimension dim,
- SmallVectorImpl<Value> &dimSizes) {
- assert(dim < dstTp.getDimRank() && "Dimension is out of bounds");
- dimSizes.clear();
-
- // We first fill the sizes from an input tensor, and then
- // compute the size of the concatenation dimension if necessary.
- const auto srcTp = getSparseTensorType(srcs[0]);
- if (srcTp.hasEncoding())
- // Reusing sizes from an arbitrary input tensor is fine.
- fillDimSizes(builder, loc, srcTp, srcs[0], dimSizes);
- else
- sizesFromSrc(builder, dimSizes, loc, srcs[0]);
-
- if (const auto sz = dstTp.getStaticDimSize(dim)) {
- // Faithfully take the static size.
- dimSizes[dim] = constantIndex(builder, loc, *sz);
- } else {
- // Else, dynamically compute the size.
- for (const auto src : srcs.drop_front()) {
- const auto srcTp = getSparseTensorType(src);
- Value srcSz = createOrFoldDimCall(builder, loc, srcTp, src, dim);
- dimSizes[dim] = builder.create<arith::AddIOp>(loc, dimSizes[dim], srcSz);
- }
- }
-}
-
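As a reference for the removed helper: it filled all sizes from the first input and, when the concatenation dimension had no static size, accumulated that dimension's extent across the remaining inputs. A minimal sketch of the accumulation it emitted, assuming three input extents are already available as index values:

  func.func @concat_dim_size_sketch(%sz0: index, %sz1: index, %sz2: index) -> index {
    // Dynamic size of the concatenated dimension = sum of the input extents.
    %s01 = arith.addi %sz0, %sz1 : index
    %sum = arith.addi %s01, %sz2 : index
    return %sum : index
  }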
/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -467,107 +435,6 @@ static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
return true;
}
-// Generates a while loop that iterates over the COO list extracted
-// from `t`, using `bodyBuilder` to build the loop body.
-// while (elem = coo->getNext()) {
-// bodyBuilder
-// }
-// TODO: It can be reused by the conversion of other operators (ReshapeOp,
-// ConvertOp) to reduce code repetition!
-// TODO: rename to `genSparseIterationLoop`?
-static void genSparseCOOIterationLoop(
- ConversionPatternRewriter &rewriter, Location loc, Value t,
- SparseTensorType stt,
- function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilder) {
- assert(stt.hasEncoding() &&
- "Generating Sparse Tensor COO Loop on a Dense Tensor!");
- const Dimension dimRank = stt.getDimRank();
- const Type elemTp = stt.getElementType();
-
- // Start an iterator over the tensor (in coordinate order).
- const auto noPerm = stt.withoutDimToLvl();
- SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, t);
- Value iter = NewCallParams(rewriter, loc)
- .genBuffers(noPerm, dimSizes)
- .genNewCall(Action::kToIterator, t);
-
- // Construct a while loop over the iterator.
- const Type iTp = rewriter.getIndexType();
- Value srcDimCoords = genAlloca(rewriter, loc, dimRank, iTp);
- Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
- const SmallVector<Value> noArgs;
- const SmallVector<Type> noTypes;
- auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
- rewriter.setInsertionPointToEnd(before);
- Value cond = genGetNextCall(rewriter, loc, iter, srcDimCoords, elemPtr);
- rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
- Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
- rewriter.setInsertionPointToStart(after);
-
- const bool hasDenseDim =
- llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT);
- if (hasDenseDim) {
- Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
- Value isZero = genIsNonzero(rewriter, loc, elemV);
- scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, isZero, /*else*/ false);
- rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- }
- // Callback here to build loop body.
- bodyBuilder(rewriter, loc, srcDimCoords, elemPtr);
-
- // Exit the scope from the IfOp.
- if (hasDenseDim)
- rewriter.setInsertionPointToEnd(after);
-
- rewriter.create<scf::YieldOp>(loc);
- // Finish generating loop.
- rewriter.setInsertionPointAfter(whileOp);
-
- // Free memory for iterator.
- genDelIteratorCall(rewriter, loc, elemTp, iter);
-}
-
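For reference, this removed helper materialized the COO traversal as an scf.while over a runtime iterator, the same shape that appears in the deleted runtime-library test earlier in this patch (getNextF64). A minimal sketch, with a hypothetical @get_next standing in for the runtime call:

  func.func private @get_next(memref<?xindex>, memref<f64>) -> i1

  func.func @coo_iteration_sketch(%coords: memref<?xindex>, %elem: memref<f64>) {
    // Loop until the iterator reports no more elements.
    scf.while : () -> () {
      %more = func.call @get_next(%coords, %elem)
         : (memref<?xindex>, memref<f64>) -> i1
      scf.condition(%more)
    } do {
      // The body-builder callback would read %coords and %elem here.
      scf.yield
    }
    return
  }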
-// Generate loop that iterates over a dense tensor.
-// for i1 in dim1
-// ..
-// for ik in dimk
-// val = a[i1,..,ik]
-// if val != 0
-// bodyBuilder(v, [i1, ..., ik])
-// TODO: It can be reused by the conversion of other operators (ReshapeOp,
-// ConvertOp) to reduce code repetition!
-static void genDenseTensorIterationLoop(
- ConversionPatternRewriter &rewriter, Location loc, Value t,
- SparseTensorType stt,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
- assert(!stt.hasEncoding() &&
- "Generating Dense Tensor Loop on a Sparse Tensor!");
-
- const Dimension dimRank = stt.getDimRank();
- Value zero = constantIndex(rewriter, loc, 0);
- Value one = constantIndex(rewriter, loc, 1);
-
- SmallVector<Value> lo;
- SmallVector<Value> hi;
- SmallVector<Value> st;
-
- // Fill out loop iteration information.
- for (Dimension d = 0; d < dimRank; d++) {
- lo.push_back(zero);
- hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d));
- st.push_back(one);
- }
-
- scf::buildLoopNest(rewriter, loc, lo, hi, st, {},
- [&](OpBuilder &builder, Location loc, ValueRange ivs,
- ValueRange args) -> scf::ValueVector {
- // Invoke callback to build the body of the loop.
- bodyBuilder(builder, loc, ivs);
- return {};
- });
-}
-
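Similarly, the removed dense-tensor helper built a perfect loop nest guarded by a nonzero test, matching the loops visible in the deleted @concat_annotated_dense test. A minimal self-contained sketch for a 4x2 input:

  func.func @dense_iteration_sketch(%t: tensor<4x2xf64>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %zero = arith.constant 0.000000e+00 : f64
    scf.for %i = %c0 to %c4 step %c1 {
      scf.for %j = %c0 to %c2 step %c1 {
        %v = tensor.extract %t[%i, %j] : tensor<4x2xf64>
        %nonzero = arith.cmpf une, %v, %zero : f64
        scf.if %nonzero {
          // The body-builder callback would consume %v at (%i, %j) here.
        }
      }
    }
    return
  }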
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -1198,168 +1065,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
}
};
-/// Sparse conversion rule for the concatenate operator.
-class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(ConcatenateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // The conversion works as follows:
- // (1). When the output is sparse and not all dims are dense, with a mix of inputs:
- // a_sparse = concat (b_dense, c_sparse, ....)
- // =>
- // coo_for_a = newSparseCOO(shapeOf(a))
- // for i, j, k // dense input
- // coo->add(adjustForOffset(i,j,k), b[i,j,k])
- //
- // for elem in sparse_input
- // coo->add(adjustForOffset(elem.coords), elem.value)
- // ...
- // a = newSparseTensor(coo_for_a)
- // return a
- //
- // (2). When the output is dense or annotated all-dense, with a mix of inputs:
- // a_dense = concat (b_dense, c_sparse, ....)
- // =>
- // a = malloc(shapeOf(a)) or newSparseAllDense(shapeOf(a))
- // for i, j, k // dense input
- // a[ adjustForOffset(i,j,k) ] = b[i,j,k]
- //
- // for elem in sparse_input
- // a[ adjustForOffset(elem.coords) ] = elem.value
- // return a
- Location loc = op.getLoc();
- const auto dstTp = getSparseTensorType(op);
- const auto dstEnc = dstTp.getEncoding();
- const Type elemTp = dstTp.getElementType();
- const Dimension concatDim = op.getDimension();
- const Dimension dimRank = dstTp.getDimRank();
-
- Value dst; // destination tensor
- Value dstDimToLvl; // destination tensor permutation (if sparse out)
- // A pointer to the value being inserted (if dense => sparse)
- Value elemPtr;
- // Memory that holds the dim-coords for destination tensor (if sparse out)
- Value dstDimCoords;
- // The offset applied to the dimension to be concatenated (starting from 0)
- Value offset = constantIndex(rewriter, loc, 0);
-
- SmallVector<Value> dimSizes;
- concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim,
- dimSizes);
-
- NewCallParams params(rewriter, loc);
- const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
- Value dstTensor;
- if (dstTp.hasEncoding()) {
- // Start a new COO or an initialized annotated all dense sparse tensor.
- dst = params.genBuffers(dstTp, dimSizes)
- .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
- dstDimCoords = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType());
- if (allDense) {
- dstTensor = dst;
- // Get the values buffer for the sparse tensor and reshape it to the
- // corresponding dense tensor shape.
- dst = genValuesCall(rewriter, loc,
- MemRefType::get({ShapedType::kDynamic}, elemTp),
- {dst});
- // Pass the `dstDimCoords` buffer for `reshapeValuesToLevels`
- // to reuse for storing level-sizes (yes, "level-sizes").
- // This is safe to do because `dstTp` is a dense-tensor type,
- // and therefore lvlRank == dimRank.
- dst = reshapeValuesToLevels(rewriter, loc, dstEnc, dimSizes, dst,
- dstDimCoords);
- } else {
- dstDimToLvl = params.getDimToLvl();
- elemPtr = genAllocaScalar(rewriter, loc, elemTp);
- }
- } else {
- // TODO: Dense buffers should be allocated/deallocated via the callback
- // in BufferizationOptions.
- dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
- }
- const Level lvlRank = dstTp.getLvlRank();
- const auto dcvs2lcvs = [&](ValueRange dcvs) -> SmallVector<Value> {
- SmallVector<Value> lcvs;
- lcvs.reserve(lvlRank);
- for (Level l = 0; l < lvlRank; l++)
- // FIXME: `toOrigDim` is deprecated
- lcvs.push_back(dcvs[toOrigDim(dstEnc, l)]);
- return lcvs;
- };
- for (const auto &it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
- Value originalOp = std::get<0>(it); // Input (with encoding) from Op
- Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
- const auto srcTp = getSparseTensorType(originalOp);
- if (srcTp.hasEncoding()) {
- genSparseCOOIterationLoop(
- rewriter, loc, adaptedOp, srcTp,
- [&](OpBuilder &builder, Location loc, Value dimCoords,
- Value elemPtr) -> void {
- const auto dcvs =
- loadAll(builder, loc, dimRank, dimCoords, concatDim, offset);
- if (dstTp.hasEncoding() && !allDense) {
- // Case: sparse => sparse, except for annotated all dense.
- storeAll(builder, loc, dstDimCoords, dcvs);
- genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
- dstDimToLvl);
- } else {
- // Case: sparse => dense, or annotated all dense.
- const auto lcvs = allDense ? dcvs2lcvs(dcvs) : dcvs;
- insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lcvs);
- }
- });
- } else {
- genDenseTensorIterationLoop(
- rewriter, loc, adaptedOp, srcTp,
- [&](OpBuilder &builder, Location loc, ValueRange dcvs) -> void {
- if (dstTp.hasEncoding() && !allDense) {
- // Case: dense => sparse, except for annotated all dense.
- assert(dcvs.size() == static_cast<size_t>(dimRank));
- storeAll(builder, loc, dstDimCoords, dcvs, concatDim, offset);
- Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
- builder.create<memref::StoreOp>(loc, val, elemPtr);
- genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
- dstDimToLvl);
- } else {
- // Case: dense => dense, or annotated all dense.
- Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
- // Despite the name, this isn't actually level-cvs until
- // after the `dcvs2lcvs` call.
- SmallVector<Value> lcvs(dcvs);
- // Apply offset.
- lcvs[concatDim] =
- builder.create<arith::AddIOp>(loc, lcvs[concatDim], offset);
- if (allDense)
- lcvs = dcvs2lcvs(lcvs);
- builder.create<memref::StoreOp>(loc, val, dst, lcvs);
- }
- });
- }
- // Accumulate offset.
- // TODO: avoid calling sparseDimSize multiple times by caching the result!
- Value curDim =
- createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
- offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
- }
- if (!dstTp.hasEncoding()) {
- rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
- op, dstTp.getRankedTensorType(), dst);
- } else if (allDense) {
- rewriter.replaceOp(op, dstTensor);
- } else {
- // In sparse output case, the destination holds the COO.
- Value coo = dst;
- dst = params.genNewCall(Action::kFromCOO, coo);
- // Release resources.
- genDelCOOCall(rewriter, loc, elemTp, coo);
- rewriter.replaceOp(op, dst);
- }
- return success();
- }
-};
-
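The strategy comment above describes the two paths that the unified rewriting rules now cover: sparse output goes through element-wise insertion, while dense (or annotated all-dense) output is written directly at adjusted offsets. For reference, a self-contained example of the operation these rules lower, with one dense and one sparse input (the encoding here is chosen for illustration only):

  #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>

  func.func @mixed_concat_example(%d: tensor<4x2xf64>,
                                  %s: tensor<4x3xf64, #CSR>) -> tensor<4x5xf64, #CSR> {
    // Concatenate a dense and a sparse matrix along dimension 1;
    // the result extent is the sum of the input extents (2 + 3 = 5).
    %0 = sparse_tensor.concatenate %d, %s {dimension = 1 : index}
       : tensor<4x2xf64>, tensor<4x3xf64, #CSR> to tensor<4x5xf64, #CSR>
    return %0 : tensor<4x5xf64, #CSR>
  }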
/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public: