[Mlir-commits] [mlir] aedf5d5 - [mlir][sparse] Improve concatenate operator rewriting for dense tensor results.
llvmlistbot at llvm.org
Mon Nov 28 07:56:07 PST 2022
Author: bixia1
Date: 2022-11-28T07:56:01-08:00
New Revision: aedf5d58312dd3b41864186ab7f4fc108b0c8bf3
URL: https://github.com/llvm/llvm-project/commit/aedf5d58312dd3b41864186ab7f4fc108b0c8bf3
DIFF: https://github.com/llvm/llvm-project/commit/aedf5d58312dd3b41864186ab7f4fc108b0c8bf3.diff
LOG: [mlir][sparse] Improve concatenate operator rewriting for dense tensor results.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D138465
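
Previously, the concatenate rewriting always staged the result through an
unordered COO tensor followed by a sparse_tensor.convert, even when the
destination tensor is dense. With this change, a dense destination skips the
COO detour: values are stored straight into a freshly allocated dense buffer.
Roughly (a sketch, not the verbatim lowering; %buf, %zero, %v, %i, %j are
illustrative names):

    %buf = memref.alloc(%d0, %d1) : memref<?x?xf64>
    linalg.fill ins(%zero : f64) outs(%buf : memref<?x?xf64>)
    // For each input: iterate its stored entries, shift the index on the
    // concatenation dimension by the accumulated offset, and store:
    //   memref.store %v, %buf[%i, %j] : memref<?x?xf64>
    %t = bufferization.to_tensor %buf : memref<?x?xf64>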
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 958aab6698e9c..fba9c78f02ff8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -199,6 +199,29 @@ static LogicalResult genForeachOnSparseConstant(ForeachOp op,
return success();
}
+/// Populates the given sizes array for concatenation from types (for static
+/// sizes) and from the source tensors (for dynamic sizes).
+static void concatSizesFromInputs(OpBuilder &builder,
+ SmallVectorImpl<Value> &sizes, Location loc,
+ ShapedType dstTp, ValueRange srcs,
+ unsigned dim) {
+ auto dstShape = dstTp.getShape();
+ sizesFromSrc(builder, sizes, loc, srcs[0]);
+
+ // Sum up along `dim` when that dimension is dynamic.
+ if (dstShape[dim] != ShapedType::kDynamic) {
+ // Faithfully take the static size.
+ sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
+ } else {
+ // Else, compute the shape dynamically.
+ for (const auto &src : srcs.drop_front()) {
+ Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
+ // Sum up all the sizes.
+ sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
+ }
+ }
+}
+
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
@@ -458,83 +481,94 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto rtp = op.getType().cast<RankedTensorType>();
- size_t conDim = op.getDimension().getZExtValue();
- SmallVector<Value> dynSizes;
- if (!rtp.hasStaticShape()) {
- ArrayRef<int64_t> rShape = rtp.getShape();
- for (const auto &d : llvm::enumerate(rShape)) {
- if (d.value() == ShapedType::kDynamic) {
- Value v =
- createOrFoldDimOp(rewriter, loc, op.getOperand(0), d.index());
- rewriter.create<tensor::DimOp>(loc, op.getOperand(0), d.index());
- if (conDim == d.index()) {
- // Adding the size of the concatenating dimension.
- for (const auto &opnd : op.getOperands().drop_front()) {
- Value t = createOrFoldDimOp(rewriter, loc, opnd, d.index());
- v = rewriter.create<arith::AddIOp>(loc, v, t);
- }
- }
- dynSizes.push_back(v);
- }
- }
- }
+ Location loc = op.getLoc();
+ auto dstTp = op.getType().cast<RankedTensorType>();
+ uint64_t conDim = op.getDimension().getZExtValue();
+ SmallVector<Value> sizes;
+ concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
// %t = concatenate %s1, %s2, %s3 {dim = 1}
// ==>
- // %tmp = bufferization.alloc_tensor : unordered COO
+ // if (isSparseDst)
+ // %tmp = bufferization.alloc_tensor : unordered COO
+ // else
+ // %tmp = memref.alloc : dense tensor
// foreach in %s1 : insert d0, d1, %tmp
// foreach in %s2 : insert d0, d1 + size(s1), %tmp
// foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
- // %t = sparse_tensor.cast %tmp
- auto cooTp = getUnorderedCOOFromType(rtp);
- auto cooBuffer =
- rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
- auto rank = rtp.getRank();
+ // %t = convert_to_dest_tensor(%tmp)
+ SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+ Value dst; // Destination tensor for inserting source tensor values.
+ if (encDst) {
+ SmallVector<Value> dynSizes;
+ getDynamicSizes(dstTp, sizes, dynSizes);
+ RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+ dst = rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
+ } else {
+ // TODO: Dense buffers should be allocated/deallocated via the callback
+ // in BufferizationOptions.
+ dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
+ }
+
+ int64_t rank = dstTp.getRank();
Value offset = constantIndex(rewriter, loc, 0);
+ SmallVector<Value> initArgs;
+ if (encDst)
+ initArgs.push_back(dst);
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
// Build a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
- loc, input, cooBuffer,
+ loc, input, initArgs,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
SmallVector<Value> indices;
for (int64_t i = 0; i < rank; i++) {
Value idx = args[i];
if (i == static_cast<int64_t>(conDim))
- // transform coordinates on matching dim
+ // Transform coordinates for the concatenating dim.
idx = builder.create<arith::AddIOp>(loc, idx, offset);
indices.push_back(idx);
}
- Value cond = genIsNonzero(rewriter, loc, v);
- scf::IfOp ifOp = builder.create<scf::IfOp>(
- loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
- rewriter.create<scf::YieldOp>(loc, t);
- rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- rewriter.create<scf::YieldOp>(loc, reduc.front());
- rewriter.setInsertionPointAfter(ifOp);
- rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+ if (encDst) {
+ Value cond = genIsNonzero(rewriter, loc, v);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(
+ loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value t =
+ builder.create<InsertOp>(loc, v, reduc.front(), indices);
+ rewriter.create<scf::YieldOp>(loc, t);
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ rewriter.create<scf::YieldOp>(loc, reduc.front());
+ rewriter.setInsertionPointAfter(ifOp);
+ rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+ } else {
+ builder.create<memref::StoreOp>(loc, v, dst, indices);
+ builder.create<sparse_tensor::YieldOp>(loc);
+ }
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by the concatenate op verifier, which saves us from computing the offset
// dynamically.
- auto d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+ int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
assert(!ShapedType::isDynamic(d));
offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, d));
- cooBuffer = foreachOp.getResult(0);
+ if (encDst) {
+ dst = foreachOp.getResult(0);
+ initArgs[0] = dst;
+ }
}
- cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
- Value converted =
- rewriter.create<ConvertOp>(loc, rtp, cooBuffer).getResult();
- rewriter.create<DeallocTensorOp>(loc, cooBuffer);
- rewriter.replaceOp(op, converted);
+ if (encDst) {
+ dst = rewriter.create<LoadOp>(loc, dst, true);
+ Value converted = rewriter.create<ConvertOp>(loc, dstTp, dst).getResult();
+ rewriter.create<DeallocTensorOp>(loc, dst);
+ rewriter.replaceOp(op, converted);
+ } else {
+ rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
+ }
return success();
}
};
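
For a dynamic result dimension, the new concatSizesFromInputs helper sums the
per-input extents along the concatenation dimension. For the 2x4 + 3x4 + 4x4
inputs used in the tests below, the dim-0 size accumulates as follows (a
sketch; SSA names are illustrative, and the dim ops on static shapes fold to
constants):

    %s0 = tensor.dim %src0, %c0 : tensor<2x4xf64, #DCSR>  // folds to 2
    %d1 = tensor.dim %src1, %c0 : tensor<3x4xf64, #DCSR>  // folds to 3
    %s1 = arith.addi %s0, %d1 : index
    %d2 = tensor.dim %src2, %c0 : tensor<4x4xf64, #DCSR>  // folds to 4
    %sz = arith.addi %s1, %d2 : index                     // 2 + 3 + 4 = 9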
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 9d26f3b561a01..af03df6e4a55a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -19,12 +19,12 @@
// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
+// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
@@ -39,12 +39,12 @@
// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
+// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
+// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
@@ -60,12 +60,12 @@
// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
+// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
+// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
@@ -106,12 +106,12 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
+// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
@@ -126,12 +126,12 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
+// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
+// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
@@ -147,12 +147,12 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
+// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
+// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
@@ -173,4 +173,86 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
tensor<3x4xf64, #DCSR>,
tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DCSR>
return %0 : tensor<?x?xf64, #DCSR>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: @concat_sparse_sparse_dense(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
+// CHECK: linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.pointers %[[TMP_arg1]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_9:.*]] = sparse_tensor.indices %[[TMP_arg1]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.pointers %[[TMP_arg1]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.indices %[[TMP_arg1]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.pointers %[[TMP_arg2]] {dimension = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_16:.*]] = sparse_tensor.indices %[[TMP_arg2]] {dimension = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_17:.*]] = sparse_tensor.pointers %[[TMP_arg2]] {dimension = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_18:.*]] = sparse_tensor.indices %[[TMP_arg2]] {dimension = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
+// CHECK: memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
+// CHECK: return %[[R]] : tensor<?x?xf64>
+func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
+ %arg1: tensor<3x4xf64, #DCSR>,
+ %arg2: tensor<4x4xf64, #DCSR>)
+ -> tensor<?x?xf64> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+ : tensor<2x4xf64, #DCSR>,
+ tensor<3x4xf64, #DCSR>,
+ tensor<4x4xf64, #DCSR> to tensor<?x?xf64>
+ return %0 : tensor<?x?xf64>
+}
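
Condensed, the dense-result CHECK lines above amount to the following shape
(eliding the three per-input entry loops):

    %buf = memref.alloc(%c9, %c4) : memref<?x?xf64>
    linalg.fill ins(%d0 : f64) outs(%buf : memref<?x?xf64>)
    // Three loop nests over the stored entries of %arg0/%arg1/%arg2, each
    // storing the loaded value into %buf at [row + offset, col].
    %0 = bufferization.to_tensor %buf : memref<?x?xf64>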