[Mlir-commits] [mlir] 7175f9d - [mlir][sparse] extend foreach operation to iterate over sparse constant.

Peiming Liu llvmlistbot at llvm.org
Tue Nov 8 17:50:39 PST 2022


Author: Peiming Liu
Date: 2022-11-09T01:50:34Z
New Revision: 7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3

URL: https://github.com/llvm/llvm-project/commit/7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3
DIFF: https://github.com/llvm/llvm-project/commit/7175f9dde19b5e38c8a68f8b6a4aee723c6a92b3.diff

LOG: [mlir][sparse] extend foreach operation to iterate over sparse constant.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137679

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d0613c09503c0..829b0453d53fa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -696,11 +696,49 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     auto loc = op.getLoc();
     Value input = op.getTensor();
+    SmallVector<Value> reduc = op.getInitArgs();
     auto rtp = input.getType().cast<RankedTensorType>();
     int64_t rank = rtp.getRank();
-    auto enc = getSparseTensorEncoding(rtp);
 
-    SmallVector<Value> reduc = op.getInitArgs();
+    // Special-case: for each over a sparse constant uses its own rewriting
+    // rule.
+    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+        // Foreach on constant.
+        DenseElementsAttr indicesAttr = attr.getIndices();
+        DenseElementsAttr valuesAttr = attr.getValues();
+
+        SmallVector<Value> args;
+        for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+          auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+          for (int j = 0; j < rank; j++) {
+            auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+            auto coord = rewriter.create<arith::ConstantIndexOp>(
+                loc, coordAttr.getInt());
+            // Remaps coordinates.
+            args.push_back(coord);
+          }
+          // Remaps value.
+          auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+          args.push_back(val);
+          // Remaps iteration args.
+          args.append(reduc);
+          auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+          Operation *yield = cloned.getBody()->getTerminator();
+          rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+          // clean up
+          args.clear();
+          rewriter.eraseOp(cloned);
+          reduc = yield->getOperands();
+          rewriter.eraseOp(yield);
+        }
+        rewriter.replaceOp(op, reduc);
+        return success();
+      }
+    }
+
+    // Otherwise, use loop emitter to generate loops.
+    auto enc = getSparseTensorEncoding(rtp);
 
     // 1. Generates loop for the sparse input.
     SparseTensorLoopEmitter loopEmitter(ValueRange{input});

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
index aeb63a0379322..fed119edbb2d3 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -26,6 +26,18 @@
 }>
 
 module {
+  /// Uses the foreach operator to print the coords and value of every stored entry of a sparse constant.
+  func.func @foreach_print_const() {
+    // Sparse constant with two stored entries: (0,0) = 1.0 and (1,6) = 5.0.
+    %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+    sparse_tensor.foreach in %0 : tensor<8x7xf32> do {
+      ^bb0(%1: index, %2: index, %v: f32) : // %1, %2: coordinates; %v: stored value.
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f32
+     }
+     return
+  }
 
   /// uses foreach operator to print coords and values.
   func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) {
@@ -111,6 +123,13 @@ module {
     // CHECK: 0
     // CHECK-NEXT: 0
     // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 5
+    call @foreach_print_const() : () -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 0
     // CHECK-NEXT: 1
     // CHECK-NEXT: 2


        


More information about the Mlir-commits mailing list