[Mlir-commits] [mlir] 1ea903e - [mlir][sparse][gpu] guard matvec COO AoS
Aart Bik
llvmlistbot at llvm.org
Mon Jun 12 16:50:06 PDT 2023
Author: Aart Bik
Date: 2023-06-12T16:49:58-07:00
New Revision: 1ea903e1641f907efc80278cbdf57b91b842716e
URL: https://github.com/llvm/llvm-project/commit/1ea903e1641f907efc80278cbdf57b91b842716e
DIFF: https://github.com/llvm/llvm-project/commit/1ea903e1641f907efc80278cbdf57b91b842716e.diff
LOG: [mlir][sparse][gpu] guard matvec COO AoS
Reviewed By: K-Wu
Differential Revision: https://reviews.llvm.org/D152738
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index e4e55574bbb68..fbe948da7445e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -350,13 +350,13 @@ static bool isAdmissibleCSR(SparseTensorType &aTp) {
/// Test for admissible types on operands (with output parameter `isCOO`).
static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
SparseTensorType cTp, bool enableRT,
- bool &isCOO) {
+ bool isMatVec, bool &isCOO) {
if (bTp.hasEncoding() || cTp.hasEncoding())
return false;
if (isAdmissibleCOO(aTp)) {
isCOO = true;
#ifdef CUSPARSE_COO_AOS
- return true;
+ return isMatVec;
#else
return enableRT;
#endif
@@ -424,7 +424,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType xTp = getSparseTensorType(x);
SparseTensorType yTp = getSparseTensorType(y);
- if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, isCOO))
+ if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
return failure();
// Start sparse kernel and copy data from host to device.
@@ -530,7 +530,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
- if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO))
+ if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
return failure();
// Start sparse kernel and copy data from host to device.
More information about the Mlir-commits mailing list