[Mlir-commits] [mlir] b60cf8c - [mlir][sparse] support coiteration with fused reshape tensor
Peiming Liu
llvmlistbot at llvm.org
Wed Mar 1 12:55:51 PST 2023
Author: Peiming Liu
Date: 2023-03-01T20:55:46Z
New Revision: b60cf8c972c6986ff80547e49f2943266e6b2615
URL: https://github.com/llvm/llvm-project/commit/b60cf8c972c6986ff80547e49f2943266e6b2615
DIFF: https://github.com/llvm/llvm-project/commit/b60cf8c972c6986ff80547e49f2943266e6b2615.diff
LOG: [mlir][sparse] support coiteration with fused reshape tensor
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145091
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb2468580a57a..28f68622e3e95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -127,6 +127,25 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
return add;
}
+Value LoopEmitter::genSparseCoord(OpBuilder &builder, Location loc, size_t tid,
+ size_t l) {
+ Value c = constantIndex(builder, loc, 0);
+ auto reass = getCollapseReassociation(tid, l);
+ for (unsigned i = 0; i < reass.size(); i++) {
+ auto lvl = reass[i];
+ // A load on the indices array yields the coordinate.
+ Value ptr = idxBuffer[tid][lvl];
+ Value off = genIndexLoad(builder, loc, ptr, pidxs[tid][l]);
+ // Linearize the coordinates within the same collapse reassociation.
+ c = builder.create<arith::AddIOp>(loc, c, off);
+ if (i != reass.size() - 1) {
+ c = builder.create<arith::MulIOp>(loc, c,
+ this->lvlSizes[tid][reass[i + 1]]);
+ }
+ }
+ return c;
+}
+
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, ArrayRef<unsigned> topSort) {
initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
@@ -383,20 +402,9 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
Value c;
if (isSparseInput) {
assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
- c = constantIndex(builder, loc, 0);
- for (unsigned i = 0; i < reass.size(); i++) {
- auto lvl = reass[i];
- // For COO, the pidxs are always the same across consecutive levels.
- pidxs[tid][lvl] = iv;
- // Generating a load on the indices array yields the coordinate.
- Value ptr = idxBuffer[tid][lvl];
- Value off = genIndexLoad(builder, loc, ptr, iv);
- c = builder.create<arith::AddIOp>(loc, c, off);
- if (i != reass.size() - 1) {
- c = builder.create<arith::MulIOp>(loc, c,
- this->lvlSizes[tid][reass[i + 1]]);
- }
- }
+ // For COO, the position is the same across consecutive levels.
+ llvm::for_each(reass, [this, tid, iv](int lvl) { pidxs[tid][lvl] = iv; });
+ c = genSparseCoord(builder, loc, tid, dim);
} else {
// Dense tensor, the coordinate is the induction variable.
c = iv;
@@ -555,16 +563,22 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
builder.setInsertionPointToStart(&whileOp.getBefore().front());
Value cond;
unsigned o = 0;
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
+ for (auto [t, lvl] : llvm::zip(tids, dims)) {
+ unsigned tid = t; // Why `t` can not be captured by lambda?
+ if (isCompressedDLT(dimTypes[tid][lvl]) ||
+ isSingletonDLT(dimTypes[tid][lvl])) {
Value op1 = before->getArgument(o);
- Value op2 = highs[tid][dim];
+ Value op2 = highs[tid][lvl];
Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
op1, op2);
cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
- // Update
- pidxs[tid][dim] = after->getArgument(o++);
+ // Update positions
+ Value pos = after->getArgument(o++);
+ auto reass = getCollapseReassociation(tid, lvl);
+ assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+ // For COO, the position is the same across consecutive levels.
+ llvm::for_each(reass,
+ [this, tid, pos](int lvl) { pidxs[tid][lvl] = pos; });
}
}
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
@@ -578,11 +592,10 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// Prepares for next level.
if (isCompressedDLT(dimTypes[tid][dim]) ||
isSingletonDLT(dimTypes[tid][dim])) {
- Value ptr = idxBuffer[tid][dim];
- Value s = pidxs[tid][dim];
- Value load = genIndexLoad(builder, loc, ptr, s);
- coord[tid][dim] = load;
+ coord[tid][dim] = genSparseCoord(builder, loc, tid, dim);
if (isSparseSlices[tid]) {
+ Value load =
+ genIndexLoad(builder, loc, idxBuffer[tid][dim], pidxs[tid][dim]);
auto enc = getSparseTensorEncoding(tensors[tid].getType());
auto [trans, pred] =
genSliceLegitPredicate(builder, loc, load, enc, dim);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index b8c287140b4c6..09d4afbf5dc96 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -190,6 +190,10 @@ class LoopEmitter {
Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim,
Value iv);
+ /// Generates instructions to compute the coordinate of tensors[tid] on `l`
+ /// under the current loop context.
+ Value genSparseCoord(OpBuilder &builder, Location loc, size_t tid, size_t l);
+
bool isOutputTensor(size_t tid) {
return hasOutput && tid == tensors.size() - 1;
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
index 33606b01c89b7..d163e3e980ac8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
@@ -23,7 +23,7 @@ module {
func.func private @printMemref3dF32(%ptr : tensor<?x?x?xf32>) attributes { llvm.emit_c_interface }
func.func private @printMemref2dF32(%ptr : tensor<?x?xf32>) attributes { llvm.emit_c_interface }
- func.func @test_sparse(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+ func.func @test_sparse_rhs(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
%collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
%0 = tensor.empty() : tensor<5x6xf32>
%cst = arith.constant 0.000000e+00 : f32
@@ -34,6 +34,17 @@ module {
return %ret1 : tensor<?x?x?xf32>
}
+ func.func @test_sparse_all(%arg0: tensor<5x6xf32, #COO_2D>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+ %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
+ %0 = tensor.empty() : tensor<5x6xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+ %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+ return %ret1 : tensor<?x?x?xf32>
+ }
+
func.func @test_dense(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32>) -> tensor<?x?x?xf32> {
%collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32> into tensor<6x6xf32>
%0 = tensor.empty() : tensor<5x6xf32>
@@ -58,6 +69,9 @@ module {
[ 6.0, 7.0, 8.0]
> : tensor<6x2x3xf32>
+ %s1 = sparse_tensor.convert %d1 : tensor<5x6xf32> to tensor<5x6xf32, #COO_2D>
+ %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
+
// CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
// CHECK-NEXT:[
// CHECK-SAME: [
@@ -96,13 +110,35 @@ module {
// CHECK-NEXT: [
// CHECK-SAME: [0, 0, 0],
// CHECK-NEXT: [0, 0, 0]]]
- %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
- %so1 = call @test_sparse(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
+ %so1 = call @test_sparse_rhs(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
call @printMemref3dF32(%so1) : (tensor<?x?x?xf32>) -> ()
+ // Same results.
+ // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
+ // CHECK-NEXT:[
+ // CHECK-SAME: [
+ // CHECK-SAME: [6, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 14, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 24, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]]]
+ %so2 = call @test_sparse_all(%s1, %s2): (tensor<5x6xf32, #COO_2D>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
+ call @printMemref3dF32(%so2) : (tensor<?x?x?xf32>) -> ()
+
+ bufferization.dealloc_tensor %s1 : tensor<5x6xf32, #COO_2D>
bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D>
bufferization.dealloc_tensor %do1 : tensor<?x?x?xf32>
bufferization.dealloc_tensor %so1 : tensor<?x?x?xf32>
+ bufferization.dealloc_tensor %so2 : tensor<?x?x?xf32>
return
}
}
More information about the Mlir-commits
mailing list