[Mlir-commits] [mlir] 35b3a0c - [mlir][sparse] support foreach on dense tensor.
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 20 17:12:42 PDT 2022
Author: Peiming Liu
Date: 2022-10-21T00:12:37Z
New Revision: 35b3a0ce8d3fb38db07d216036ec341551456b6c
URL: https://github.com/llvm/llvm-project/commit/35b3a0ce8d3fb38db07d216036ec341551456b6c
DIFF: https://github.com/llvm/llvm-project/commit/35b3a0ce8d3fb38db07d216036ec341551456b6c.diff
LOG: [mlir][sparse] support foreach on dense tensor.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136384
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index edd19d22732b..e8e5a3e9f6a8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -834,16 +834,16 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
[SingleBlockImplicitTerminator<"YieldOp">]>,
- Arguments<(ins AnySparseTensor:$tensor)>{
- let summary = "Iterates over non-zero elements in a sparse tensor";
+ Arguments<(ins AnyTensor:$tensor)>{
+ let summary = "Iterates over elements in a tensor";
let description = [{
- Iterates over every non-zero element in the given sparse tensor and executes
- the block.
+ Iterates over stored elements in a tensor (which are typically, but not always,
+ non-zero for sparse tensors) and executes the block.
- For a input sparse tensor with rank n, the block must take n + 1 arguments. The
+ For an input tensor with rank n, the block must take n + 1 arguments. The
first n arguments must be Index type, together indicating the current coordinates
of the element being visited. The last argument must have the same type as the
- sparse tensor's element type, representing the actual value loaded from the input
+ tensor's element type, representing the actual value loaded from the input
tensor at the given coordinates.
Example:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4707115c41fc..fc5fb767f516 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -480,14 +480,17 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
for (int64_t i = 0; i < rank; i++)
loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
- Value vals = loopEmitter.getTensorValueBuffer(0);
- Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
- Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
SmallVector<Value, 4> coords;
coords.reserve(rank);
loopEmitter.getCoordinateArray(coords);
+ Value vals = loopEmitter.getTensorValueBuffer(0);
+ Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0);
+ // Loads the value from sparse tensor using pointer index;
+ // loads the value from dense tensor using coordinate array.
+ Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
+ : rewriter.create<memref::LoadOp>(loc, vals, coords);
+
for (int64_t i = 0; i < rank; i++)
loopEmitter.exitCurrentLoop();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
index 685b34ba5a0c..aeb63a037932 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -78,6 +78,16 @@ module {
return
}
+ func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+ sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+
//
// Main driver.
//
@@ -109,6 +119,19 @@ module {
// CHECK-NEXT: 5
// CHECK-NEXT: 1
// CHECK-NEXT: 1
+ // CHECK-NEXT: 6
+ call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 1
// CHECK-NEXT: 6
call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
// CHECK-NEXT: 0
More information about the Mlir-commits mailing list