[Mlir-commits] [mlir] 7175f9d - [mlir][sparse] extend foreach operation to iterate over sparse constants.
Peiming Liu
llvmlistbot at llvm.org
Tue Nov 8 17:50:39 PST 2022
Author: Peiming Liu
Date: 2022-11-09T01:50:34Z
New Revision: 7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3
URL: https://github.com/llvm/llvm-project/commit/7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3
DIFF: https://github.com/llvm/llvm-project/commit/7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3.diff
LOG: [mlir][sparse] extend foreach operation to iterate over sparse constants.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137679
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d0613c09503c0..829b0453d53fa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -696,11 +696,49 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
auto loc = op.getLoc();
Value input = op.getTensor();
+ SmallVector<Value> reduc = op.getInitArgs();
auto rtp = input.getType().cast<RankedTensorType>();
int64_t rank = rtp.getRank();
- auto enc = getSparseTensorEncoding(rtp);
- SmallVector<Value> reduc = op.getInitArgs();
+ // Special-case: for each over a sparse constant uses its own rewriting
+ // rule.
+ if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
+ if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ // Foreach on constant.
+ DenseElementsAttr indicesAttr = attr.getIndices();
+ DenseElementsAttr valuesAttr = attr.getValues();
+
+ SmallVector<Value> args;
+ for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+ auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+ for (int j = 0; j < rank; j++) {
+ auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+ auto coord = rewriter.create<arith::ConstantIndexOp>(
+ loc, coordAttr.getInt());
+ // Remaps coordinates.
+ args.push_back(coord);
+ }
+ // Remaps value.
+ auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+ args.push_back(val);
+ // Remaps iteration args.
+ args.append(reduc);
+ auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+ Operation *yield = cloned.getBody()->getTerminator();
+ rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+ // clean up
+ args.clear();
+ rewriter.eraseOp(cloned);
+ reduc = yield->getOperands();
+ rewriter.eraseOp(yield);
+ }
+ rewriter.replaceOp(op, reduc);
+ return success();
+ }
+ }
+
+ // Otherwise, use loop emitter to generate loops.
+ auto enc = getSparseTensorEncoding(rtp);
// 1. Generates loop for the sparse input.
SparseTensorLoopEmitter loopEmitter(ValueRange{input});
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
index aeb63a0379322..fed119edbb2d3 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -26,6 +26,18 @@
}>
module {
+ /// uses foreach operator to print coords and values.
+ func.func @foreach_print_const() {
+ // Initialize a tensor.
+ %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+ sparse_tensor.foreach in %0 : tensor<8x7xf32> do {
+ ^bb0(%1: index, %2: index, %v: f32) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f32
+ }
+ return
+ }
/// uses foreach operator to print coords and values.
func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) {
@@ -111,6 +123,13 @@ module {
// CHECK: 0
// CHECK-NEXT: 0
// CHECK-NEXT: 1
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 6
+ // CHECK-NEXT: 5
+ call @foreach_print_const() : () -> ()
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
// CHECK-NEXT: 0
// CHECK-NEXT: 1
// CHECK-NEXT: 2
More information about the Mlir-commits mailing list