[Mlir-commits] [mlir] [mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules. (PR #69362)
Peiming Liu
llvmlistbot at llvm.org
Tue Oct 17 10:59:52 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69362
Avoid tensor to memref conversion in sparse tensor rewriting rules.
From f461a5c264367a8e9eeef5b5dd3107696a5aa24d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 17 Oct 2023 17:37:27 +0000
Subject: [PATCH] [mlir][sparse] avoid tensor->memref conversion in sparse tensor rewriting rules.
---
.../Transforms/SparseTensorRewriting.cpp | 107 +++++--------
.../SparseTensor/convert_sparse2dense.mlir | 35 ++---
.../Dialect/SparseTensor/sparse_concat.mlir | 148 +++++++++---------
3 files changed, 132 insertions(+), 158 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bfee3aa1d7ee8e..e50b14975e83d63 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
}
};
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
struct TensorLike {
TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
- ValueRange sizes)
- : isSparse(rtt.getEncoding() != nullptr) {
+ ValueRange sizes) {
SmallVector<Value> dynSzs;
getDynamicSizes(rtt, sizes, dynSzs);
- if (isSparse)
- val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
- else
- val = allocDenseTensor(builder, loc, rtt, sizes);
- };
-
- void insertOrStore(OpBuilder &builder, Location loc, Value v,
- ValueRange crds) {
- if (isSparse)
- val = builder.create<InsertOp>(loc, v, val, crds);
- else
- builder.create<memref::StoreOp>(loc, v, val, crds);
+ val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+ if (!isSparse()) {
+ Value c0 = constantZero(builder, loc, rtt.getElementType());
+ val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+ }
}
- Value getSSA() const {
- // We don't need to maintain the SSA chain for a memref value.
- return isSparse ? val : nullptr;
+ void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+ // TODO: Unify these two.
+ if (isSparse())
+ val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+ else
+ val = builder.create<tensor::InsertOp>(loc, v, val, crds);
}
Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
- if (isSparse)
+ if (isSparse())
return builder.create<LoadOp>(loc, val, true);
- return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+ return val;
}
- void updateSSA(Value v) {
- // Dense memref is a non-SSA value.
- assert(isSparse);
- val = v;
+ bool isSparse() const {
+ return getSparseTensorEncoding(val.getType()) != nullptr;
}
-private:
- bool isSparse;
- Value val; // either a memref (for dense tensor) or a sparse tensor.
+ Value val;
};
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
Value offset = constantIndex(rewriter, loc, 0);
- Value iterArg = dstBuf.getSSA();
+ Value iterArg = dstBuf.val;
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
// Builds a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
- loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+ loc, input, iterArg,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// FIXME: `toStoredDim` is deprecated
dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
}
-
- if (!reduc.empty())
- dstBuf.updateSSA(reduc.front());
-
+ // Enters foreach, updates the SSA chain.
+ dstBuf.val = reduc.front();
if (!dstTp.isAllDense()) {
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
/*else*/ true);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+ builder.create<scf::YieldOp>(loc, dstBuf.val);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- dstBuf.insertOrStore(builder, loc, v, dstLcvs);
- builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+ dstBuf.insert(builder, loc, v, dstLcvs);
+ builder.create<scf::YieldOp>(loc, dstBuf.val);
// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
- assert(!reduc.empty());
- dstBuf.updateSSA(ifOp.getResult(0));
+ dstBuf.val = ifOp.getResult(0);
} else {
- dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+ dstBuf.insert(builder, loc, v, dstLcvs);
}
- if (reduc.empty())
- builder.create<sparse_tensor::YieldOp>(loc);
- else
- builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+ builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
offset = rewriter.create<arith::AddIOp>(
loc, offset, constantIndex(rewriter, loc, *sh));
- if (!foreachOp.getResults().empty()) {
- iterArg = foreachOp.getResult(0);
- dstBuf.updateSSA(iterArg);
- }
+ iterArg = foreachOp.getResult(0);
+ dstBuf.val = iterArg;
}
- if (!foreachOp.getResults().empty())
- dstBuf.updateSSA(iterArg);
-
+ dstBuf.val = iterArg;
Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
rewriter.replaceOp(op, ret);
return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
ValueRange vs;
TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
- Value iterArg = dstBuf.getSSA();
auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+ loc, src, dstBuf.val, foreachOrder,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
- if (!reduc.empty())
- dstBuf.updateSSA(reduc.front());
-
+ dstBuf.val = reduc.front();
const Dimension dimRank = dstStt.getDimRank();
const Level lvlRank = dstStt.getLvlRank();
SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
}
if (!skipZeroCheck) {
- assert(!reduc.empty());
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
/*else*/ true);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+ builder.create<scf::YieldOp>(loc, dstBuf.val);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- dstBuf.insertOrStore(builder, loc, v, lcvs);
- builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+ dstBuf.insert(builder, loc, v, lcvs);
+ builder.create<scf::YieldOp>(loc, dstBuf.val);
// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
- dstBuf.updateSSA(ifOp.getResult(0));
+ dstBuf.val = ifOp.getResult(0);
} else {
- dstBuf.insertOrStore(builder, loc, v, lcvs);
+ dstBuf.insert(builder, loc, v, lcvs);
}
- if (reduc.empty())
- builder.create<sparse_tensor::YieldOp>(loc);
- else
- builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+ builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
});
rewriter.setInsertionPointAfter(foreachOp);
// Exits the for loop, links the SSA chain.
- if (!foreachOp.getResults().empty())
- dstBuf.updateSSA(foreachOp.getResult(0));
+ dstBuf.val = foreachOp.getResult(0);
Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
rewriter.replaceOp(op, ret);
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index c22f051a0d5854d..e2dcb068e11851e 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,11 +14,10 @@
// CHECK-LABEL: func.func @sparse_convert_1d
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
%0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
return %0 : tensor<13xi32>
@@ -26,11 +25,10 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
// CHECK-LABEL: func.func @sparse_convert_1d_dyn
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
return %0 : tensor<?xi32>
@@ -38,11 +36,10 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
// CHECK-LABEL: func.func @sparse_convert_2d
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
return %0 : tensor<2x4xf64>
@@ -50,11 +47,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
// CHECK-LABEL: func.func @sparse_convert_2d_dyn
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
return %0 : tensor<?x4xf64>
@@ -62,11 +58,10 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
// CHECK-LABEL: func.func @sparse_convert_2d_dyn1
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
return %0 : tensor<2x?xf64>
@@ -74,11 +69,10 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK-LABEL: func.func @sparse_convert_2d_dyn2
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
return %0 : tensor<?x?xf64>
@@ -86,11 +80,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
// CHECK-LABEL: func.func @sparse_convert_3d
// CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
// CHECK: linalg.fill
// CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
return %0 : tensor<2x3x4xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index bdfab54dc6daeb5..f3d3dd28563e891 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -176,77 +176,83 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
return %0 : tensor<?x?xf64, #DCSR>
}
-// CHECK-LABEL: @concat_sparse_sparse_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
-// CHECK: linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
-// CHECK: return %[[R]] : tensor<?x?xf64>
+// CHECK-LABEL: func.func @concat_sparse_sparse_dense(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_3]]) : tensor<?x?xf64>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : f64) outs(%[[VAL_10]] : tensor<?x?xf64>) -> tensor<?x?xf64>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_8]] iter_args(%[[VAL_21:.*]] = %[[VAL_11]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_8]] iter_args(%[[VAL_28:.*]] = %[[VAL_21]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xf64>
+// CHECK: %[[VAL_31:.*]] = tensor.insert %[[VAL_30]] into %[[VAL_28]]{{\[}}%[[VAL_22]], %[[VAL_29]]] : tensor<?x?xf64>
+// CHECK: scf.yield %[[VAL_31]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_26]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[VAL_33:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[VAL_34:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[VAL_35:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[VAL_36:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK: %[[VAL_39:.*]] = scf.for %[[VAL_40:.*]] = %[[VAL_37]] to %[[VAL_38]] step %[[VAL_8]] iter_args(%[[VAL_41:.*]] = %[[VAL_19]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_40]], %[[VAL_8]] : index
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK: %[[VAL_46:.*]] = scf.for %[[VAL_47:.*]] = %[[VAL_43]] to %[[VAL_45]] step %[[VAL_8]] iter_args(%[[VAL_48:.*]] = %[[VAL_41]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_47]]] : memref<?xf64>
+// CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_42]], %[[VAL_9]] : index
+// CHECK: %[[VAL_52:.*]] = tensor.insert %[[VAL_50]] into %[[VAL_48]]{{\[}}%[[VAL_51]], %[[VAL_49]]] : tensor<?x?xf64>
+// CHECK: scf.yield %[[VAL_52]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_46]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: %[[VAL_53:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[VAL_54:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[VAL_55:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[VAL_56:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[VAL_57:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK: %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_58]] to %[[VAL_59]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_39]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_54]]{{\[}}%[[VAL_61]]] : memref<?xindex>
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_55]]{{\[}}%[[VAL_61]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_61]], %[[VAL_8]] : index
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_55]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_8]] iter_args(%[[VAL_69:.*]] = %[[VAL_62]]) -> (tensor<?x?xf64>) {
+// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_56]]{{\[}}%[[VAL_68]]] : memref<?xindex>
+// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_57]]{{\[}}%[[VAL_68]]] : memref<?xf64>
+// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_63]], %[[VAL_5]] : index
+// CHECK: %[[VAL_73:.*]] = tensor.insert %[[VAL_71]] into %[[VAL_69]]{{\[}}%[[VAL_72]], %[[VAL_70]]] : tensor<?x?xf64>
+// CHECK: scf.yield %[[VAL_73]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_67]] : tensor<?x?xf64>
+// CHECK: }
+// CHECK: return %[[VAL_60]] : tensor<?x?xf64>
+// CHECK: }
func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
%arg2: tensor<4x4xf64, #DCSR>)
More information about the Mlir-commits mailing list