[Mlir-commits] [mlir] 4fa00ce - [mlir][sparse] extend foreach operation to accept reduction arguments; fix sparse tensor rewriting patterns that do not propagate sparse tensor SSA properly.
Peiming Liu
llvmlistbot at llvm.org
Mon Nov 7 13:40:34 PST 2022
Author: Peiming Liu
Date: 2022-11-07T21:40:30Z
New Revision: 4fa00ce15c842aa8be495759723e2e2450591380
URL: https://github.com/llvm/llvm-project/commit/4fa00ce15c842aa8be495759723e2e2450591380
DIFF: https://github.com/llvm/llvm-project/commit/4fa00ce15c842aa8be495759723e2e2450591380.diff
LOG: [mlir][sparse] extend foreach operation to accept reduction arguments; fix sparse tensor rewriting patterns that do not propagate sparse tensor SSA properly.
This patch re-commits D137468 and D137463, which were reverted by mistake.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137579
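For illustration, a minimal sketch of the extended foreach form (adapted from the
updated op description and tests in this patch; %A and the #DCSR encoding are
placeholders assumed to be defined elsewhere):

  %c0  = arith.constant 0 : i32
  // Sum all stored values of a sparse matrix via a loop-carried i32.
  %sum = sparse_tensor.foreach in %A init(%c0)
      : tensor<?x?xi32, #DCSR>, i32 -> i32 do {
  ^bb0(%i: index, %j: index, %v: i32, %iter: i32):
    %acc = arith.addi %iter, %v : i32
    sparse_tensor.yield %acc : i32
  }

The rewriting patterns below use the same mechanism to thread the destination
sparse tensor through the generated loops as a loop-carried value (and through
scf.for iter_args), finalizing it with sparse_tensor.load ... hasInserts instead
of relying on in-place updates that break the SSA chain.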
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5d667448e2f37..52a6aff752792 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -896,21 +896,44 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
[SingleBlockImplicitTerminator<"YieldOp">]>,
- Arguments<(ins AnyTensor:$tensor)>{
+ Arguments<(ins AnyTensor:$tensor,
+ Variadic<AnyType>:$initArgs)>,
+ Results<(outs Variadic<AnyType>:$results)> {
let summary = "Iterates over elements in a tensor";
let description = [{
Iterates over stored elements in a tensor (which are typically, but not always,
non-zero for sparse tensors) and executes the block.
- For an input tensor with rank n, the block must take n + 1 arguments. The
- first n arguments must be Index type, together indicating the current coordinates
- of the element being visited. The last argument must have the same type as the
+ For an input tensor with rank n, the block must take n + 1 arguments (plus
+ additional loop-carried variables, as described below). The first n arguments
+ must be of Index type, together indicating the current coordinates of the
+ element being visited. The last argument must have the same type as the
tensor's element type, representing the actual value loaded from the input
tensor at the given coordinates.
- Note that foreach generated loop iterates over the stored elements in the storage
- order. However, no matter what storage order is used, the indices passed to the block
- always obey the original dimension order.
+ `sparse_tensor.foreach` can also operate on loop-carried variables and return
+ the final values after loop termination. The initial values of the variables are
+ passed as additional SSA operands to the "sparse_tensor.foreach" following the n + 1
+ SSA values mentioned above (n coordinates and 1 value).
+
+ The region must terminate with a "sparse_tensor.yield" that passes the current
+ values of all loop-carried variables to the next iteration, or to the
+ result on the last iteration. The number and static types of the loop-carried
+ variables may not change between iterations.
+
+ For example:
+ ```mlir
+ %c0 = arith.constant 0 : i32
+ %ret = sparse_tensor.foreach in %0 init(%c0): tensor<?x?xi32, #DCSR>, i32 -> i32 do {
+ ^bb0(%arg1: index, %arg2: index, %arg3: i32, %iter: i32):
+ %sum = arith.addi %iter, %arg3 : i32
+ sparse_tensor.yield %sum : i32
+ }
+ ```
+
+ It is important to note that the foreach-generated loop iterates over the stored
+ elements in the storage order. However, no matter which storage order is used, the
+ indices passed to the block always obey the original dimension order.
For example:
```mlir
@@ -918,10 +941,10 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
dimLevelType = [ "compressed", "compressed" ],
dimOrdering = affine_map<(i,j) -> (j,i)>
}>
-
+
// foreach on a column-major sparse tensor
sparse_tensor.foreach in %0 : tensor<2x3xf64, #COL_MAJOR> do {
- ^bb0(%row: index, %col: index, %arg3: f64):
+ ^bb0(%row: index, %col: index, %arg3: f64):
// [%row, %col] -> [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]
}
@@ -931,30 +954,25 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
// foreach on a row-major sparse tensor
sparse_tensor.foreach in %0 : tensor<2x3xf64, #ROW_MAJOR> do {
- ^bb0(%row: index, %col: index, %arg3: f64):
+ ^bb0(%row: index, %col: index, %arg3: f64):
// [%row, %col] -> [0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]
}
```
-
- Example:
-
- ```mlir
- sparse_tensor.foreach in %0 : tensor<?x?xf64, #DCSR> do {
- ^bb0(%arg1: index, %arg2: index, %arg3: f64):
- do something...
- }
- ```
}];
let builders = [
- OpBuilder<(
- ins "Value":$tensor,
- "function_ref<void(OpBuilder &, Location, ValueRange)>")>
+ OpBuilder<(ins "Value":$tensor,
+ "function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>")>,
+ OpBuilder<(ins "Value":$tensor,
+ "ValueRange":$iterArgs,
+ "function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>")>
];
- let regions = (region AnyRegion:$region);
- let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor) `do` $region";
+ let regions = (region SizedRegion<1>:$region);
+ let assemblyFormat = "`in` $tensor (`init``(`$initArgs^`)`)? attr-dict"
+ " `:` type($tensor) (`,` type($initArgs)^)?"
+ " (`->` type($results)^)? `do` $region";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6d6bd26251953..6a4177737df9f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -581,11 +581,20 @@ LogicalResult CompressOp::verify() {
void ForeachOp::build(
OpBuilder &builder, OperationState &result, Value tensor,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
- build(builder, result, tensor);
+ function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
+ bodyBuilder) {
+ build(builder, result, tensor, llvm::None, bodyBuilder);
+}
+
+void ForeachOp::build(
+ OpBuilder &builder, OperationState &result, Value tensor,
+ ValueRange initArgs,
+ function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
+ bodyBuilder) {
+ build(builder, result, initArgs.getTypes(), tensor, initArgs);
+ // Builds foreach body.
if (!bodyBuilder)
return;
-
auto rtp = tensor.getType().cast<RankedTensorType>();
int64_t rank = rtp.getRank();
@@ -594,31 +603,49 @@ void ForeachOp::build(
std::fill_n(std::back_inserter(blockArgTypes), rank, builder.getIndexType());
// Followed by one value.
blockArgTypes.push_back(rtp.getElementType());
+ // Followed by reduction variable.
+ blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
SmallVector<Location, 4> blockArgLocs;
- std::fill_n(std::back_inserter(blockArgLocs), rank + 1, tensor.getLoc());
+ std::fill_n(std::back_inserter(blockArgLocs), blockArgTypes.size(),
+ tensor.getLoc());
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
Block *bodyBlock =
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
- bodyBuilder(builder, result.location, bodyBlock->getArguments());
+ bodyBuilder(builder, result.location,
+ bodyBlock->getArguments().slice(0, rank),
+ bodyBlock->getArguments()[rank],
+ bodyBlock->getArguments().drop_front(rank + 1));
}
LogicalResult ForeachOp::verify() {
auto t = getTensor().getType().cast<RankedTensorType>();
auto args = getBody()->getArguments();
- if (static_cast<size_t>(t.getRank()) + 1 != args.size())
+ if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
+ args.size())
return emitError("Unmatched number of arguments in the block");
+ if (getNumResults() != getInitArgs().size())
+ return emitError("Mismatch in number of init arguments and results");
+
+ if (getResultTypes() != getInitArgs().getTypes())
+ return emitError("Mismatch in types of init arguments and results");
+
+ auto yield = cast<YieldOp>(getBody()->getTerminator());
+ if (yield.getNumOperands() != getNumResults() ||
+ yield.getOperands().getTypes() != getResultTypes())
+ return emitError("Mismatch in types of yield values and results");
+
for (int64_t i = 0, e = t.getRank(); i < e; i++)
if (args[i].getType() != IndexType::get(getContext()))
emitError(
llvm::formatv("Expecting Index type for argument at index {0}", i));
auto elemTp = t.getElementType();
- auto valueTp = args.back().getType();
+ auto valueTp = args[t.getRank()].getType();
if (elemTp != valueTp)
emitError(llvm::formatv("Unmatched element type between input tensor and "
"block argument, expected:{0}, got: {1}",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index fc240b0b10c08..fcddcd27ed40b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -956,6 +956,9 @@ Value mlir::sparse_tensor::genValueForDense(OpBuilder &builder, Location loc,
return val;
}
+// FIXME:
+// 1. The loop for dense tensors should be generated by the loop emitter.
+// 2. Support reduction variables to propagate SSA chains properly.
void mlir::sparse_tensor::genDenseTensorOrSparseConstantIterLoop(
OpBuilder &builder, Location loc, Value src, unsigned rank,
function_ref<void(OpBuilder &, Location, Value, ValueRange)> bodyBuilder) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9c002f1ae0ec8..d0613c09503c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -356,8 +356,10 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
- rewriter.create<ForeachOp>(
- loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ ForeachOp foreachOp = rewriter.create<ForeachOp>(
+ loc, srcTensor, cooBuffer,
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+ ValueRange reduc) {
SmallVector<Value, 4> srcIndices;
SmallVector<Value, 4> dstIndices;
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
@@ -366,11 +368,11 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
}
translateIndicesArray(builder, loc, op.getReassociationIndices(),
srcIndices, srcSizes, dstSizes, dstIndices);
- builder.create<InsertOp>(loc, args.back(), cooBuffer, dstIndices);
- builder.create<sparse_tensor::YieldOp>(loc);
+ auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstIndices);
+ builder.create<sparse_tensor::YieldOp>(loc, t);
});
-
- rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+ auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, t);
return success();
}
};
@@ -440,13 +442,16 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
rewriter.create<AllocTensorOp>(loc, cooTp, ValueRange()).getResult();
Value offset = constantIndex(rewriter, loc, 0);
+ ForeachOp foreachOp;
for (Value input : op.getInputs()) {
// Builds the indexing map.
// Build a for op for each input tensor to append new values into the
// output tensor.
- rewriter.create<ForeachOp>(
- loc, input, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ foreachOp = rewriter.create<ForeachOp>(
+ loc, input, cooBuffer,
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+ ValueRange reduc) {
SmallVector<Value, 4> indices;
for (int64_t i = 0; i < rank; i++) {
uint64_t dim =
@@ -457,8 +462,8 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
idx = builder.create<arith::AddIOp>(loc, idx, offset);
indices.push_back(idx);
}
- builder.create<InsertOp>(loc, args.back(), cooBuffer, indices);
- builder.create<sparse_tensor::YieldOp>(loc);
+ auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+ builder.create<sparse_tensor::YieldOp>(loc, t);
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
@@ -467,7 +472,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
assert(!ShapedType::isDynamic(d));
offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, d));
+ cooBuffer = foreachOp.getResult(0);
}
+
+ cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
rewriter.replaceOpWithNewOp<ConvertOp>(op, rtp, cooBuffer);
return success();
}
@@ -558,12 +566,13 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
sizesForTensor(rewriter, sizes, loc, srcTp, src);
Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
- rewriter.create<ForeachOp>(
- loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
- builder.create<memref::StoreOp>(loc, args.back(), dst,
- args.drop_back());
- builder.create<sparse_tensor::YieldOp>(loc);
- });
+ rewriter.create<ForeachOp>(loc, src, llvm::None,
+ [&](OpBuilder &builder, Location loc,
+ ValueRange args, Value v, ValueRange reduc) {
+ builder.create<memref::StoreOp>(loc, v, dst,
+ args);
+ builder.create<sparse_tensor::YieldOp>(loc);
+ });
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
return success();
@@ -597,17 +606,19 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
srcTp = getUnorderedCOOFromType(srcTp);
tmpCoo =
rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
- rewriter.create<ForeachOp>(
- loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ auto foreachOp = rewriter.create<ForeachOp>(
+ loc, src, tmpCoo,
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+ ValueRange reduc) {
SmallVector<Value, 4> indices;
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
uint64_t dim = toStoredDim(encSrc, i);
indices.push_back(args[dim]);
}
- builder.create<InsertOp>(loc, args.back(), tmpCoo, indices);
- builder.create<sparse_tensor::YieldOp>(loc);
+ auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+ builder.create<sparse_tensor::YieldOp>(loc, t);
});
- src = tmpCoo;
+ src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
}
// Sort the COO tensor so that its elements are ordered via increasing
@@ -646,27 +657,31 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
getDynamicSizes(dstTp, srcSizes, dynDstSizes);
Value dst =
rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
- rewriter.create<ForeachOp>(
- loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ auto foreachOp = rewriter.create<ForeachOp>(
+ loc, src, dst,
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+ ValueRange reduc) {
SmallVector<Value, 4> indices;
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
uint64_t dim = toStoredDim(encDst, i);
indices.push_back(args[dim]);
}
- builder.create<InsertOp>(loc, args.back(), dst, indices);
- builder.create<sparse_tensor::YieldOp>(loc);
+ auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+ builder.create<sparse_tensor::YieldOp>(loc, t);
});
- // Release the temporary COO if it is created.
+ // Release the temporary COO if it was created. Note that tmpCoo is
+ // invalidated by the foreach above; its updated value is now held by src.
if (tmpCoo)
- rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+ rewriter.create<DeallocTensorOp>(loc, src);
// Directly replace op with dst results in bufferization error message
// "sparse tensor allocation should not escape function".
// As such, we insert a trivial tensor convert which will be removed by
// codegen.
rewriter.setInsertionPointAfter(op);
- rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, dst);
+ auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, t);
return success();
}
};
@@ -685,6 +700,8 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
int64_t rank = rtp.getRank();
auto enc = getSparseTensorEncoding(rtp);
+ SmallVector<Value> reduc = op.getInitArgs();
+
// 1. Generates loop for the sparse input.
SparseTensorLoopEmitter loopEmitter(ValueRange{input});
loopEmitter.initializeLoopEmit(rewriter, loc);
@@ -692,7 +709,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
// TODO: provide utility function for loop sequences that only contains
// one for loop?
loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(i));
- loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
+ // Note that reduc will be taken care of by the loop emitter and gets
+ // updated in place.
+ loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i, reduc);
}
SmallVector<Value, 4> coords;
@@ -707,15 +726,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
: rewriter.create<memref::LoadOp>(loc, vals, coords);
// 2. Inline the block in the foreach operator.
- Block::iterator inlinePos = rewriter.getInsertionPoint();
Block *srcBlock = op.getBody();
- // Remove sparse_tensor.yield.
- rewriter.eraseOp(srcBlock->getTerminator());
-
- for (int64_t i = 0; i < rank; i++) {
- loopEmitter.exitCurrentLoop(rewriter, loc);
- loopEmitter.exitCurrentLoopSeq();
- }
SmallVector<Value, 4> args;
// Remap coordinates.
@@ -725,11 +736,33 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}
// Remap value.
args.push_back(val);
+ // Remap reduction variables.
+ args.append(reduc);
+
+ // Remove sparse_tensor.yield.
+ SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
+ rewriter.eraseOp(srcBlock->getTerminator());
// Inline body.
- rewriter.mergeBlockBefore(srcBlock, &*inlinePos, args);
- // delete the foreach operator.
- rewriter.eraseOp(op);
+ if (!reducValue.empty()) {
+ rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
+ } else {
+ // This is annoying, since scf.for inserts an implicit yield op when
+ // there is no reduction variable upon creation; in this case we need to
+ // merge the block *before* the yield op.
+ rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args);
+ }
+
+ for (int64_t i = 0; i < rank; i++) {
+ // Link the reduction chain. Note that the loop emitter updates reducValue
+ // in place.
+ loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
+ loopEmitter.exitCurrentLoopSeq();
+ }
+
+ // Replace the foreach operator with the value returned by the outermost
+ // for loop.
+ rewriter.replaceOp(op, reducValue);
return success();
}
};
@@ -792,7 +825,8 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
.getResult(0);
Type eltTp = dstTp.getElementType();
Value value = genAllocaScalar(rewriter, loc, eltTp);
- scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
+ ArrayRef<Value>(cooBuffer));
rewriter.setInsertionPointToStart(forOp.getBody());
SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
@@ -807,13 +841,17 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
loc, indices, constantIndex(rewriter, loc, i)));
}
Value v = rewriter.create<memref::LoadOp>(loc, value);
- rewriter.create<InsertOp>(loc, v, cooBuffer, indicesArray);
+ auto t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
+ indicesArray);
+ rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
rewriter.setInsertionPointAfter(forOp);
+ // Link SSA chain.
+ cooBuffer = forOp.getResult(0);
// Release the sparse tensor reader.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
-
+ cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
Value newOp = rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
// Release the unordered COO tensor buffer.
@@ -866,12 +904,14 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
ModuleOp module = op->getParentOfType<ModuleOp>();
// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
- loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ loc, src, llvm::None,
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+ ValueRange reduc) {
for (uint64_t i = 0; i < rank; i++) {
rewriter.create<memref::StoreOp>(loc, args[i], indices,
constantIndex(builder, loc, i));
}
- rewriter.create<memref::StoreOp>(loc, args.back(), value);
+ rewriter.create<memref::StoreOp>(loc, v, value);
SmallVector<Value, 4> operands{writer, rankValue, indices, value};
FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
EmitCInterface::On);
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index d67e11b92dd9c..cb1f16ef2cd20 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -116,6 +116,7 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-RWT: %[[V:.*]] = tensor.extract %[[A]]{{\[}}%[[FI]], %[[FJ]]] : tensor<2x4xf64>
// CHECK-RWT: %[[NZ:.*]] = arith.cmpf une, %[[V]], %[[F0]] : f64
// CHECK-RWT: scf.if %[[NZ]] {
+// // FIXME: the SSA chain is broken here!
// CHECK-RWT: %{{.*}} = sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[FI]], %[[FJ]]]
// CHECK-RWT: }
// CHECK-RWT: }
@@ -126,11 +127,13 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-RWT: %[[V2:.*]] = sparse_tensor.values %[[COO]]
// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V2]]
// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
-// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
-// CHECK-RWT: sparse_tensor.insert %[[FV]] into %[[DST]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64, %[[R0:.*]]: tensor
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.insert %[[FV]] into %[[R0]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: sparse_tensor.yield %[[RET]]
// CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[NT:.*]] = sparse_tensor.load %[[NEW_T]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[NT]]
// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
// CHECK-RWT: return %[[R]] : tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
@@ -179,6 +182,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-RWT: %[[I1r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C1]]] : tensor<2x2xi64>
// CHECK-RWT: %[[I1:.*]] = arith.index_cast %[[I1r]] : i64 to index
// CHECK-RWT: %[[V:.*]] = tensor.extract %[[SV]]{{\[}}%[[FI]]] : tensor<2xf32>
+// // FIXME: the SSA chain is broken here!
// CHECK-RWT: sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[I0]], %[[I1]]]
// CHECK-RWT: }
// CHECK-RWT: %[[TI0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
@@ -187,11 +191,13 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-RWT: %[[TV:.*]] = sparse_tensor.values %[[COO]]
// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[TI0]], %[[TI1]] jointly %[[TV]]
// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
-// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32):
-// CHECK-RWT: sparse_tensor.insert %[[F2V]] into %[[DST]]{{\[}}%[[F2I0]], %[[F2I1]]]
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32, %[[R0:.*]]: tensor
+// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.insert %[[F2V]] into %[[R0]]{{\[}}%[[F2I0]], %[[F2I1]]]
+// CHECK-RWT: sparse_tensor.yield %[[NEW_T]]
// CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]]
// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
// CHECK-RWT: return %[[R]] : tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 92f9e46b90938..17145f8d37380 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -94,11 +94,13 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]]
// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
-// CHECK-RWT: sparse_tensor.foreach in %[[A]]
-// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32):
-// CHECK-RWT: sparse_tensor.insert %[[FV2]] into %[[DST]]{{\[}}%[[FI2]]]
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[A]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32, %[[T:.*]]: tensor<?xf32,
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.insert %[[FV2]] into %[[T]]{{\[}}%[[FI2]]]
+// CHECK-RWT: sparse_tensor.yield %[[I]]
// CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]]
// CHECK-RWT: return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index dd27ce398c203..02fb97bc866c6 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -551,6 +551,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
// -----
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+ // expected-error at +1 {{Unmatched element type between input tensor and block argument}}
+ sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+ ^bb0(%1: index, %2: index, %v: f32) :
+ }
+ return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+ // expected-error at +1 {{Mismatch in number of init arguments and results}}
+ sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 do {
+ ^bb0(%1: index, %2: index, %v: f32, %r1 : i32) :
+ }
+ return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+ // expected-error at +1 {{Mismatch in types of init arguments and results}}
+ %1 = sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> i32 do {
+ ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) :
+ }
+ return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+ // expected-error at +1 {{Mismatch in types of yield values and results}}
+ %1 = sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> f32 do {
+ ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) :
+ sparse_tensor.yield %1 : index
+ }
+ return
+}
+
+// -----
+
// TODO: a test case with empty xs doesn't work due to some parser issues.
func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 79b616dec8304..3a6cf999df90a 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -18,18 +18,19 @@
// CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
// CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
// CHECK: %[[VB:.*]] = memref.alloca()
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
+// CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
// CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
// CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[V:.*]] = memref.load %[[VB]][]
-// CHECK: sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK: scf.yield %[[T1]]
// CHECK: }
// CHECK: call @delSparseTensorReader(%[[R]])
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T]]
-// CHECK: bufferization.dealloc_tensor %[[T]]
+// CHECK: %[[T3:.*]] = sparse_tensor.load %[[T2]] hasInserts
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T3]]
+// CHECK: bufferization.dealloc_tensor %[[T3]]
// CHECK: return %[[R]]
-// CHECK: }
func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
return %0 : tensor<?x?xf32, #CSR>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 0ef58db148525..bc664ae3d2d00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -411,6 +411,26 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
return
}
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_tensor_foreach(
+// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[A1:.*]]: f32
+// CHECK-NEXT: %[[RET:.*]] = sparse_tensor.foreach in %[[A0]] init(%[[A1]])
+// CHECK-NEXT: ^bb0(%[[TMP_1:.*]]: index, %[[TMP_2:.*]]: index, %[[TMP_v:.*]]: f64, %[[TMP_r:.*]]: f32)
+// CHECK: sparse_tensor.yield %[[TMP_r]] : f32
+// CHECK: }
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+ %ret = sparse_tensor.foreach in %arg0 init(%arg1): tensor<2x4xf64, #DCSR>, f32 -> f32
+ do {
+ ^bb0(%1: index, %2: index, %v: f64, %r: f32) :
+ sparse_tensor.yield %r : f32
+ }
+ return
+}
+
// ----
// CHECK-LABEL: func @sparse_sort_1d0v(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 7280c6f5e7ba3..717819bd0cb16 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -19,16 +19,18 @@
// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] {
+// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] {
+// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: sparse_tensor.insert %[[TMP_28]] into %[[TMP_0]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_1]]
// CHECK: }
+// CHECK: scf.yield %[[RET_4]]
// CHECK: }
// CHECK: %[[TMP_8:.*]] = sparse_tensor.pointers %[[TMP_arg1]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor
// CHECK: %[[TMP_9:.*]] = sparse_tensor.indices %[[TMP_arg1]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor
@@ -37,17 +39,19 @@
// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] {
+// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] {
+// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: sparse_tensor.insert %[[TMP_28]] into %[[TMP_0]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_2]]
// CHECK: }
+// CHECK: scf.yield %[[RET_5]]
// CHECK: }
// CHECK: %[[TMP_15:.*]] = sparse_tensor.pointers %[[TMP_arg2]] {dimension = 0 : index} : tensor<4x4xf64, #sparse_tensor
// CHECK: %[[TMP_16:.*]] = sparse_tensor.indices %[[TMP_arg2]] {dimension = 0 : index} : tensor<4x4xf64, #sparse_tensor
@@ -56,19 +60,22 @@
// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] {
+// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] {
+// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: sparse_tensor.insert %[[TMP_28]] into %[[TMP_0]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: scf.yield %[[NEW_3]]
// CHECK: }
+// CHECK: scf.yield %[[RET_6]]
// CHECK: }
-// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_0]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
+// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
// CHECK: return %[[TMP_22]] : tensor<9x4xf64, #sparse_tensor
func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index c162bacffac96..94ee50197fa9c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -52,14 +52,16 @@
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[C10]] : index
// CHECK-RWT: %[[DI1:.*]] = arith.remui %[[SI]], %[[C10]] : index
-// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: %[[NT:.*]] = sparse_tensor.insert %[[SV]] into %[[R]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: scf.yield %[[NT:.*]]
// CHECK-RWT: }
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
// CHECK-RWT: return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
@@ -111,25 +113,28 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
-// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
-// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
-// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
-// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
-// CHECK-RWT-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
-// CHECK-RWT-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
-// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
-// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
-// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
-// CHECK-RWT: %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
-// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
-// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
-// CHECK-RWT }
-// CHECK-RWT: }
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK-RWT: %[[R1:.*]] = sparse_tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[DI]]]
+// CHECK-RWT scf.yield %[[R1]]
+// CHECK-RWT }
+// CHECK-RWT scf.yield %[[RET_1]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
// CHECK-RWT: return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
@@ -191,7 +196,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK-RWT: %[[T1:.*]] = arith.muli %[[DD0]], %[[C10]] : index
@@ -200,9 +205,11 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-RWT: %[[T3:.*]] = arith.remui %[[SI]], %[[T2]] : index
// CHECK-RWT: %[[T4:.*]] = arith.divui %[[T2]], %[[C10]] : index
// CHECK-RWT: %[[DI1:.*]] = arith.divui %[[T3]], %[[T4]] : index
-// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: %[[NT:.*]] = sparse_tensor.insert %[[SV]] into %[[R]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: scf.yield %[[NT]]
// CHECK-RWT: }
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
// CHECK-RWT: return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
@@ -260,28 +267,31 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
-// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
-// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
-// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
-// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
-// CHECK-RWT-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
-// CHECK-RWT-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
-// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
-// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
-// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
-// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
-// CHECK-RWT: %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
-// CHECK-RWT: %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
-// CHECK-RWT: %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
-// CHECK-RWT: %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
-// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
-// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
-// CHECK-RWT }
-// CHECK-RWT: }
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R0:.*]] = %[[B]])
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[R1:.*]] = %[[R0]])
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
+// CHECK-RWT: %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
+// CHECK-RWT: %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
+// CHECK-RWT: %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
+// CHECK-RWT: %[[NT:.*]] = sparse_tensor.insert %[[SV]] into %[[R1]]{{\[}}%[[DI]]]
+// CHECK-RWT scf.yield %[[NT]]
+// CHECK-RWT }
+// CHECK-RWT scf.yield %[[RET_1]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
// CHECK-RWT: return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {