[Mlir-commits] [mlir] 1ea903e - [mlir][sparse][gpu] guard matvec COO AoS

Aart Bik llvmlistbot at llvm.org
Mon Jun 12 16:50:06 PDT 2023


Author: Aart Bik
Date: 2023-06-12T16:49:58-07:00
New Revision: 1ea903e1641f907efc80278cbdf57b91b842716e

URL: https://github.com/llvm/llvm-project/commit/1ea903e1641f907efc80278cbdf57b91b842716e
DIFF: https://github.com/llvm/llvm-project/commit/1ea903e1641f907efc80278cbdf57b91b842716e.diff

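LOG: [mlir][sparse][gpu] guard matvec COO AoS

When CUSPARSE_COO_AOS is defined, a COO sparse operand is now admissible only for matrix-vector (SpMV) rewriting; SpMM rewriting rejects it.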

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D152738

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index e4e55574bbb68..fbe948da7445e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -350,13 +350,13 @@ static bool isAdmissibleCSR(SparseTensorType &aTp) {
 /// Test for admissible types on operands (with output parameter `isCOO`).
 static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
                                SparseTensorType cTp, bool enableRT,
-                               bool &isCOO) {
+                               bool isMatVec, bool &isCOO) {
   if (bTp.hasEncoding() || cTp.hasEncoding())
     return false;
   if (isAdmissibleCOO(aTp)) {
     isCOO = true;
 #ifdef CUSPARSE_COO_AOS
-    return true;
+    return isMatVec;
 #else
     return enableRT;
 #endif
@@ -424,7 +424,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
   SparseTensorType yTp = getSparseTensorType(y);
-  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, isCOO))
+  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
     return failure();
 
   // Start sparse kernel and copy data from host to device.
@@ -530,7 +530,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO))
+  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
     return failure();
 
   // Start sparse kernel and copy data from host to device.


        


More information about the Mlir-commits mailing list