[Mlir-commits] [mlir] 65bfd5c - [mlir][sparse] proper in-place SDDMM with spy function
Author: Aart Bik
Date: 2023-06-15T13:59:38-07:00
New Revision: 65bfd5cb2530c36fe08fc003132e44f95f0f5ae6
URL: https://github.com/llvm/llvm-project/commit/65bfd5cb2530c36fe08fc003132e44f95f0f5ae6
DIFF: https://github.com/llvm/llvm-project/commit/65bfd5cb2530c36fe08fc003132e44f95f0f5ae6.diff
LOG: [mlir][sparse] proper in-place SDDMM with spy function
This operation, a sampled dense-dense matrix multiplication (SDDMM) with a
unary "spy" sampling function and in-place update of the sparse output,
matches the cuSPARSE SDDMM semantics exactly.
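
For reference, cuSPARSE's cusparseSDDMM computes
  C = alpha * (op(A) x op(B)) o spy(C) + beta * C,
where spy(C) keeps only the positions already stored in the sparse output.
Below is a minimal C++ sketch of the alpha = beta = 1 case targeted here,
assuming a CSR-stored sampling matrix S; the names pos/crd/val are
illustrative stand-ins for the sparse_tensor positions, coordinates, and
values buffers, not the generated code:

  #include <cstdint>
  #include <vector>

  // S(i,j) += SUM_k A(i,k) * B(k,j), only for entries (i,j) stored in S.
  // A is m x p row-major, B is p x n row-major, S is m x n in CSR form.
  void sddmmSpy(int64_t m, int64_t n, int64_t p,
                const std::vector<double> &A, const std::vector<double> &B,
                const std::vector<int64_t> &pos,  // CSR row positions of S
                const std::vector<int64_t> &crd,  // CSR column coordinates of S
                std::vector<double> &val) {       // values of S, updated in place
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t q = pos[i]; q < pos[i + 1]; ++q) {
        const int64_t j = crd[q];
        double sum = val[q]; // accumulate into the stored value in place
        for (int64_t k = 0; k < p; ++k)
          sum += A[i * p + k] * B[k * n + j];
        val[q] = sum;
      }
    }
  }

The sparsifier emits the same computation with the k-loop hoisted outside
the sparse j-loop (see the CHECK lines in the new test); the result is
still accumulated in place into the values buffer of S.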
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D152969
Added:
mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 428bc49d14ac9..881e02ea0f91c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1136,10 +1136,26 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
Value e, LoopId ldx) {
- if (Operation *def = e.getDefiningOp()) {
+ if (auto arg = dyn_cast<BlockArgument>(e)) {
+ // Direct arguments of the original linalg op must be converted
+ // into dense tensor loads. Note that we should not encounter
+ // anything else. This needs to be verified by semi-ring ops.
+ linalg::GenericOp op = env.op();
+ if (arg.getOwner()->getParentOp() == op) {
+ const TensorId tid = env.makeTensorId(arg.getArgNumber());
+ OpOperand *t = &op->getOpOperand(tid);
+ assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
+ SmallVector<Value> args;
+ Value ptr = genSubscript(env, rewriter, t, args);
+ return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
+ }
+ } else if (Operation *def = e.getDefiningOp()) {
+ // Handle index computation.
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
+ // When still defined in new body, recurse into operands.
if (def->getBlock() == block) {
+ rewriter.setInsertionPoint(def);
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
def->setOperand(
@@ -1195,8 +1211,10 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
if (ee &&
(kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
kind == TensorExp::Kind::kBinaryBranch ||
- kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
+ kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect)) {
+ OpBuilder::InsertionGuard guard(rewriter);
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+ }
}
if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
new file mode 100755
index 0000000000000..8bc405a4ccf52
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// An SDDMM implementation with a "spy" function and
+// in-place update of the sampling sparse matrix.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait_sampled_dense_dense = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,k)>, // A
+ affine_map<(i,j,k) -> (k,j)>, // B
+ affine_map<(i,j,k) -> (i,j)> // S
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
+}
+
+// CHECK-LABEL: func.func @sparse_sampled_dd(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<8x8xf64>
+// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf64>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x8xf64>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<8x8xf64>
+// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f64
+// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_18]], %[[VAL_21]] : f64
+// CHECK: memref.store %[[VAL_22]], %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_2]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[VAL_23]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
+func.func @sparse_sampled_dd(%argA: tensor<8x8xf64>,
+ %argB: tensor<8x8xf64>,
+ %argS: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+ %f0 = arith.constant 0.0 : f64
+ %result = linalg.generic #trait_sampled_dense_dense
+ ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>) outs(%argS: tensor<8x8xf64, #SM>) {
+ ^bb(%a: f64, %b: f64, %s: f64):
+ %u = sparse_tensor.unary %s : f64 to f64
+ present={
+ ^bb0(%p: f64):
+ %mul = arith.mulf %a, %b : f64
+ sparse_tensor.yield %mul : f64
+ }
+ absent={}
+ %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
+ ^bb0(%p: f64, %q: f64):
+ %add = arith.addf %p, %q : f64
+ sparse_tensor.yield %add : f64
+ }
+ linalg.yield %r : f64
+ } -> tensor<8x8xf64, #SM>
+ return %result : tensor<8x8xf64, #SM>
+}