[Mlir-commits] [mlir] 35b3a0c - [mlir][sparse] support foreach on dense tensor.

Peiming Liu llvmlistbot at llvm.org
Thu Oct 20 17:12:42 PDT 2022


Author: Peiming Liu
Date: 2022-10-21T00:12:37Z
New Revision: 35b3a0ce8d3fb38db07d216036ec341551456b6c

URL: https://github.com/llvm/llvm-project/commit/35b3a0ce8d3fb38db07d216036ec341551456b6c
DIFF: https://github.com/llvm/llvm-project/commit/35b3a0ce8d3fb38db07d216036ec341551456b6c.diff

LOG: [mlir][sparse] support foreach on dense tensor.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136384

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index edd19d22732b..e8e5a3e9f6a8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -834,16 +834,16 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
-    Arguments<(ins AnySparseTensor:$tensor)>{
-  let summary = "Iterates over non-zero elements in a sparse tensor";
+    Arguments<(ins AnyTensor:$tensor)>{
+  let summary = "Iterates over elements in a tensor";
   let description = [{
-     Iterates over every non-zero element in the given sparse tensor and executes
-     the block.
+     Iterates over the stored elements of a tensor (every element of a dense tensor, and
+     typically, but not always, the non-zeros of a sparse tensor) and executes the block.
 
-     For a input sparse tensor with rank n, the block must take n + 1 arguments. The
+     For an input tensor with rank n, the block must take n + 1 arguments. The
      first n arguments must be Index type, together indicating the current coordinates
      of the element being visited. The last argument must have the same type as the
-     sparse tensor's element type, representing the actual value loaded from the input
+     tensor's element type, representing the actual value loaded from the input
      tensor at the given coordinates.
 
      Example:

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4707115c41fc..fc5fb767f516 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -480,14 +480,17 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
 
-    Value vals = loopEmitter.getTensorValueBuffer(0);
-    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
-    Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
     SmallVector<Value, 4> coords;
     coords.reserve(rank);
     loopEmitter.getCoordinateArray(coords);
 
+    Value vals = loopEmitter.getTensorValueBuffer(0);
+    Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0);
+    // Load the value from a sparse tensor using the pointer index, or from
+    // a dense tensor using the coordinate array.
+    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
+                    : rewriter.create<memref::LoadOp>(loc, vals, coords);
+
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.exitCurrentLoop();
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
index 685b34ba5a0c..aeb63a037932 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -78,6 +78,16 @@ module {
      return
   }
 
+  func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+   }
+   return
+  }
+  
   //
   // Main driver.
   //
@@ -109,6 +119,19 @@ module {
     // CHECK-NEXT: 5
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1
+    // CHECK-NEXT: 6    
+    call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 6
     call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
     // CHECK-NEXT: 0


        


More information about the Mlir-commits mailing list