[Mlir-commits] [mlir] f14c8eb - [mlir][sparse][gpu] refine SDDMM pattern for cuSPARSE
Aart Bik
llvmlistbot at llvm.org
Wed Jun 21 18:32:03 PDT 2023
Author: Aart Bik
Date: 2023-06-21T18:31:55-07:00
New Revision: f14c8eb595d0ba84dde09a0d5860f117c11e9ca1
URL: https://github.com/llvm/llvm-project/commit/f14c8eb595d0ba84dde09a0d5860f117c11e9ca1
DIFF: https://github.com/llvm/llvm-project/commit/f14c8eb595d0ba84dde09a0d5860f117c11e9ca1.diff
LOG: [mlir][sparse][gpu] refine SDDMM pattern for cuSPARSE
The old pattern missed some cases (e.g. swapped arguments) but also
admitted too many (e.g. a non-empty "absent" region, or add/mul taking
arguments other than the expected block arguments). This fixes both
issues.
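
For reference, a minimal sketch of the kernel body shape the refined
matcher now accepts, adapted from the gpu_sampled_matmul_lib.mlir test
updated in this commit (%f0 is the f64 zero constant used there):

  %result = linalg.generic #trait_sampled_dense_dense
      ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>)
      outs(%argS: tensor<8x8xf64, #CSR>) {
    ^bb(%a: f64, %b: f64, %s: f64):
      // spy(s) x (a * b): the "present" region multiplies the two
      // block arguments (in either order); "absent" must stay empty.
      %u = sparse_tensor.unary %s : f64 to f64
        present = {
          ^bb0(%p: f64):
            %mul = arith.mulf %a, %b : f64
            sparse_tensor.yield %mul : f64
        }
        absent = {}
      // c += ...: the reduction must be a plain sum of its own two
      // block arguments.
      %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
          ^bb0(%p: f64, %q: f64):
            %add = arith.addf %p, %q : f64
            sparse_tensor.yield %add : f64
        }
      linalg.yield %r : f64
  } -> tensor<8x8xf64, #CSR>
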
Reviewed By: K-Wu
Differential Revision: https://reviews.llvm.org/D153487
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 86c429adf1d04..073aa8419ea53 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -301,12 +301,25 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
// Library helper methods.
//===----------------------------------------------------------------------===//
-/// Helper to detect a * b.
-static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
+/// Helper to detect a + b with arguments taken from the given block.
+static bool matchAddOfArgs(Block *block, Value val) {
if (auto *def = val.getDefiningOp()) {
- if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
- Value a = op.getBlock()->getArguments()[0];
- Value b = op.getBlock()->getArguments()[1];
+ if (isa<arith::AddFOp, arith::AddIOp>(def)) {
+ Value a = block->getArguments()[0];
+ Value b = block->getArguments()[1];
+ return (def->getOperand(0) == a && def->getOperand(1) == b) ||
+ (def->getOperand(0) == b && def->getOperand(1) == a);
+ }
+ }
+ return false;
+}
+
+/// Helper to detect a * b with arguments taken from the given block.
+static bool matchMulOfArgs(Block *block, Value val) {
+ if (auto *def = val.getDefiningOp()) {
+ if (isa<arith::MulFOp, arith::MulIOp>(def)) {
+ Value a = block->getArguments()[0];
+ Value b = block->getArguments()[1];
return (def->getOperand(0) == a && def->getOperand(1) == b) ||
(def->getOperand(0) == b && def->getOperand(1) == a);
}
@@ -318,67 +331,47 @@ static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
- if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+ if (isa<arith::AddFOp, arith::AddIOp>(def)) {
Value x = op.getBlock()->getArguments()[2];
return (def->getOperand(0) == x &&
- matchMulOfArgs(op, def->getOperand(1))) ||
+ matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
(def->getOperand(1) == x &&
- matchMulOfArgs(op, def->getOperand(0)));
+ matchMulOfArgs(op.getBlock(), def->getOperand(0)));
}
}
return false;
}
-// Helper to detect c = c \spy (a * b)
+// Helper to detect c += spy(s) x (a * b)
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
- auto def = yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>();
- if (!def)
- return false;
+ // The linalg yields a custom reduce result.
Value s_out = op.getBlock()->getArguments()[2];
- for (auto *use : s_out.getUsers()) {
- if (!(isa<sparse_tensor::ReduceOp>(use) ||
- isa<sparse_tensor::UnaryOp>(use)))
- return false;
- // The sparse matrix should be specified as the pattern in the two
- // operators.
- if (s_out != use->getOperand(0))
+ if (auto redOp =
+ yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
+ // The reduce consumes the output.
+ Value other;
+ if (s_out == redOp->getOperand(0))
+ other = redOp->getOperand(1);
+ else if (s_out == redOp->getOperand(1))
+ other = redOp->getOperand(0);
+ else
return false;
-
- // the above logic makes sure the pattern involves reduction and unary,
- // i.e.,
- // %1 = sparse_tensor.unary
- // %2 = sparse_tensor.reduce
- // we need to make sure %1 produces A*B and %2 uses summation as the
- // reduction operator.
- if (isa<sparse_tensor::ReduceOp>(use)) {
- auto reduceSpOp = cast<sparse_tensor::ReduceOp>(use);
- auto yieldSpOp = cast<sparse_tensor::YieldOp>(
- reduceSpOp.getRegion().front().getTerminator());
- auto *reduce = yieldSpOp.getOperand(0).getDefiningOp();
- if (!isa_and_nonnull<arith::AddFOp>(reduce) &&
- !isa_and_nonnull<arith::AddIOp>(reduce))
- return false;
- }
- if (isa<sparse_tensor::UnaryOp>(use)) {
- auto unarySpOp = cast<sparse_tensor::UnaryOp>(use);
- auto yieldSpOp = cast<sparse_tensor::YieldOp>(
- unarySpOp.getRegion(0).front().getTerminator());
- auto *unary = yieldSpOp.getOperand(0).getDefiningOp();
- if (!isa_and_nonnull<arith::MulIOp>(unary) &&
- !isa_and_nonnull<arith::MulFOp>(unary))
+ // The reduce op also consumes a unary op, which in turn consumes the
+ // output and does not define an absent value.
+ if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
+ if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
return false;
-
- // we also need to make sure the unary operation is used by the reduction
- // operation.
- for (auto *useUnary : unarySpOp->getUsers()) {
- if (!isa<sparse_tensor::ReduceOp>(useUnary)) {
- return false;
- }
- }
+ // And the bodies are as expected.
+ auto yieldUn = cast<sparse_tensor::YieldOp>(
+ unOp.getRegion(0).front().getTerminator());
+ auto yieldRed = cast<sparse_tensor::YieldOp>(
+ redOp.getRegion().front().getTerminator());
+ return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
+ matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
}
}
- return true;
+ return false;
}
/// Test for sorted COO with suitable data and coordinate types.
@@ -679,37 +672,24 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
return success();
}
-// TODO: identify alpha and beta and pass them to the CUDA calls
/// Match and rewrite SDDMM kernel.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
linalg::GenericOp op, bool enableRT) {
-
- // For now, this pass reuses C and copies the result non-zero elements to
- // overwrite C's.
- // As an ad hoc solution, this pass also assumes the linalg takes a,b,c as
- // input argument, and c as the output. It recognizes this pattern and rewrite
- // it.
-
Location loc = op.getLoc();
Value a = op.getOperand(0);
Value b = op.getOperand(1);
Value c = op.getOperand(2);
-
SmallVector<Value> tokens;
- // Only admissible sparse matrix format and dense matrices.
+ // Only admissible sparse matrix format and dense matrices, no COO.
bool isCOO = false;
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
-
if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
return failure();
-
- // cusparse currently does not support COO in its SDDMM kernel.
- if (isCOO) {
+ if (isCOO)
return failure();
- }
// The SDDMM performs the operation in place.
// Start sparse kernel and copy data from host to device.
@@ -798,8 +778,8 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
+ // Done.
rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
-
return success();
}
@@ -933,6 +913,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
bindDims(getContext(), i, j, k);
// TODO: more robust patterns, transposed versions, more kernels...
+ // TODO: identify alpha and beta and pass them to the CUDA calls
// Recognize a SpMV kernel.
if (numLoops == 2 && numTensors == 3 &&
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
index 0ac9b4d76b38c..2cd9e2847a623 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops \
-// RUN: --sparsification="enable-gpu-libgen" | FileCheck %s
+// RUN: mlir-opt %s --sparsification="enable-gpu-libgen" | FileCheck %s
#trait_sampled_dense_dense = {
indexing_maps = [
@@ -22,8 +21,6 @@
#CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
-module {
-
// CHECK-LABEL: func.func @sparse_sampled_dd(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
@@ -82,7 +79,7 @@ module {
// A kernel that computes a direct sampled matrix matrix multiplication
// (with sparse result).
// Compute SDDMM C = C\spy AB
-//
+//
func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
%argA: tensor<8x8xf64>,
%argB: tensor<8x8xf64>) -> tensor<8x8xf64, #CSR> {
@@ -106,6 +103,4 @@ func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
linalg.yield %r : f64
} -> tensor<8x8xf64, #CSR>
return %result : tensor<8x8xf64, #CSR>
- }
-
}