[Mlir-commits] [mlir] [mlir][sparse][gpu] add CSC to libgen GPU sparsification using cuSparse (PR #67713)
Aart Bik
llvmlistbot at llvm.org
Thu Sep 28 10:35:04 PDT 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/67713
Adds CSC, and also adds BSR as a future format. Coming soon!
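
For reference, a minimal sketch of the kind of CSC kernel the libgen path now accepts, taken from the integration tests below (the 64-bit posWidth/crdWidth are simply the values the tests use; per the new metadata check any width of 16 bits or more is admissible):

  #CSC = #sparse_tensor.encoding<{
    map = (d0, d1) -> (d1 : dense, d0 : compressed),
    posWidth = 64,
    crdWidth = 64
  }>

  // C = A x B with A stored as CSC; lowered to cuSparse via the GPU libgen path.
  %D = linalg.matmul
         ins(%A, %B: tensor<8x8xf32, #CSC>, tensor<8x8xf32>)
         outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>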
>From 081a71e9a217f928d36776f5352aaec2ec0e1733 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 28 Sep 2023 10:22:50 -0700
Subject: [PATCH] [mlir][sparse][gpu] add CSC to libgen GPU sparsification
using cuSparse
Also adds BSR as a future format. Coming soon!
---
.../Transforms/SparseGPUCodegen.cpp | 155 ++++++++++--------
.../GPU/CUDA/sparse-matmul-lib.mlir | 48 +++++-
.../GPU/CUDA/sparse-matvec-lib.mlir | 34 +++-
3 files changed, 167 insertions(+), 70 deletions(-)
mode change 100644 => 100755 mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
mode change 100644 => 100755 mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 91b346c8a9b4c4d..ba3fd751ce1c39e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;
namespace {
+// Sparse formats supported by cuSparse.
+enum class CuSparseFormat {
+ kNone,
+ kCOO,
+ kCSR,
+ kCSC,
+ kBSR, // TODO: coming soon!
+};
+
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
return false;
}
-/// Determines if the given value is a dense tensor instead of a sparse one.
+/// Test for dense tensor.
static bool isDenseTensor(Value v) {
- return (sparse_tensor::getSparseTensorType(v).isAllDense());
+ auto sTp = getSparseTensorType(v);
+ return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
+}
+
+/// Test for suitable positions/coordinates width.
+static bool isAdmissibleMetaData(SparseTensorType &aTp) {
+ return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
+ (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}
-/// Test for sorted COO with suitable data and coordinates types.
+/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
- return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
+ return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+ aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
- (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
- (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
- aTp.getCrdWidth() == 64);
+ isAdmissibleMetaData(aTp);
}
-/// Test for CSR with suitable data and coordinates types.
+/// Test for CSR matrix with suitable metadata.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
- return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
- aTp.isUniqueLvl(1) &&
- (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
- (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
- aTp.getCrdWidth() == 64);
+ return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+ aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
+ aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}
-/// Test for admissible types on operands (with output parameter `isCOO`).
-static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
- SparseTensorType cTp, bool enableRT,
- bool isMatVec, bool &isCOO) {
+/// Test for CSC matrix with suitable metadata.
+static bool isAdmissibleCSC(SparseTensorType &aTp) {
+ return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
+ aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
+ aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
+}
+
+/// Returns a suitable sparse format for the operation and given operand
+/// types with cuSparse, or kNone if none is available.
+static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
+ SparseTensorType bTp,
+ SparseTensorType cTp, bool enableRT,
+ bool isMatVec) {
+ // The other operands have a dense type.
if (bTp.hasEncoding() || cTp.hasEncoding())
- return false;
- if (isAdmissibleCOO(aTp)) {
- isCOO = true;
+ return CuSparseFormat::kNone;
+ // Now check for suitable operand type for the main operand.
+ if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
- return isMatVec;
+ return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
- return enableRT;
+ return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
- }
- return isAdmissibleCSR(aTp);
+ if (isAdmissibleCSR(aTp))
+ return CuSparseFormat::kCSR;
+ if (isAdmissibleCSC(aTp))
+ return CuSparseFormat::kCSC;
+ return CuSparseFormat::kNone;
}
/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
- bool isCOO, bool enableRT) {
- if (isCOO) {
+ CuSparseFormat format, bool enableRT) {
+ if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
return genToCoordinatesBuffer(builder, loc, a);
}
- // CSR uses positions.
+ // Formats CSR/CSC and BSR use positions at 1.
return genToPositions(builder, loc, a, 1);
}
/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
- bool isCOO, bool enableRT) {
+ CuSparseFormat format, bool enableRT) {
+ bool isCOO = format == CuSparseFormat::kCOO;
if (isCOO && !enableRT)
return Value(); // nothing needed
+ // Formats CSR/CSC and BSR use coordinates at 1.
return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
}
-/// Generates the sparse matrix multiplication.
+/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
Type tokenTp, Value token, Value sz1, Value sz2,
Value nseA, Value rowA, Value colA, Value valA,
- bool isCOO, bool enableRT) {
- if (isCOO) {
+ CuSparseFormat format, bool enableRT) {
+ if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT) {
assert(colA);
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
#endif
}
assert(colA);
- return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+ if (format == CuSparseFormat::kCSR)
+ return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+ sz2, nseA, rowA, colA, valA);
+ assert(format == CuSparseFormat::kCSC);
+ return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
sz2, nseA, rowA, colA, valA);
}
@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
- // Only admissible sparse matrix format and dense vectors.
- bool isCOO = false;
+ // Only admissible sparse matrix format and dense vectors (no BSR).
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType xTp = getSparseTensorType(x);
SparseTensorType yTp = getSparseTensorType(y);
- if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
+ auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
return failure();
// Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
- Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
- Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+ Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+ Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value memV = genToValues(rewriter, loc, a);
Value memX, memY;
Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
- rowA, colA, valA, isCOO, enableRT);
+ rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
Value dnY = dvecY.getResult(0);
token = dvecY.getAsyncToken();
-
auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
// Precompute buffersize for SpMV.
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
- // Only admissible sparse matrix format and dense matrices.
- bool isCOO = false;
+ // Only admissible sparse matrix format and dense matrices (no BSR).
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
- if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
+ auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
return failure();
// Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
- Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
- Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+ Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+ Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value memV = genToValues(rewriter, loc, a);
Value bufB, bufC;
Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
- rowA, colA, valA, isCOO, enableRT);
+ rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
SmallVector<Value>{szm, szn});
Value dnC = dmatC.getResult(0);
token = dmatC.getAsyncToken();
-
auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
// Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
Value buffer = buf.getResult(0);
token = buf.getAsyncToken();
-
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
// Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
SmallVector<Value> tokens;
// Only CSR <- CSR x CSR supported.
- bool isCOO = false;
+ auto format = CuSparseFormat::kCSR;
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
- Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
- Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+ Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+ Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value amemV = genToValues(rewriter, loc, a);
- Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
- Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+ Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
+ Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
Value bmemV = genToValues(rewriter, loc, b);
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
- rowA, colA, valA, isCOO, enableRT);
+ rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
Operation *spGenB =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
- rowB, colB, valB, isCOO, enableRT);
+ rowB, colB, valB, format, enableRT);
Value spMatB = spGenB->getResult(0);
token = spGenB->getResult(1);
@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
token = e3.getAsyncToken();
Operation *spGenC =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
- rowC, colC, valC, isCOO, enableRT);
+ rowC, colC, valC, format, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
@@ -1045,14 +1074,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
- // Only admissible sparse matrix format and dense matrices, no COO.
- bool isCOO = false;
+ // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
- if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
- return failure();
- if (isCOO)
+ auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
+ format == CuSparseFormat::kCSC)
return failure();
// The SDDMM does the in-place operation.
@@ -1071,8 +1099,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value bufB = genTensorToMemref(rewriter, loc, b);
if (!isZeroCopy)
matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
- Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
- Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
+ Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
+ Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
Value memV = genToValues(rewriter, loc, c);
Value castB, castA, castR, castC, castV;
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
@@ -1107,10 +1135,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
Value dnB = dmatB.getResult(0);
token = dmatB.getAsyncToken();
-
Operation *spGenC =
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
- rowC, colC, valC, isCOO, enableRT);
+ rowC, colC, valC, format, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
old mode 100644
new mode 100755
index 6782f2d0e2014d7..9e397e0ad5b5dc2
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
@@ -34,6 +34,12 @@
crdWidth = 32
}>
+#CSC = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d1 : dense, d0 : compressed),
+ posWidth = 64,
+ crdWidth = 64
+}>
+
module {
llvm.func @mgpuCreateSparseEnv()
llvm.func @mgpuDestroySparseEnv()
@@ -58,6 +64,16 @@ module {
return %D: tensor<8x8xf32>
}
+ // Computes C = A x B with A sparse CSC.
+ func.func @matmulCSC(%A: tensor<8x8xf32, #CSC>,
+ %B: tensor<8x8xf32>,
+ %C: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %D = linalg.matmul
+ ins(%A, %B: tensor<8x8xf32, #CSC>, tensor<8x8xf32>)
+ outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>
+ return %D: tensor<8x8xf32>
+ }
+
func.func @dump(%mat: tensor<8x8xf32>) {
%f0 = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
@@ -107,6 +123,7 @@ module {
// Convert to a "sparse" matrix A.
%Acoo = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
%Acsr = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+ %Acsc = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSC>
// Initial C matrices.
%C0 = tensor.generate {
@@ -125,10 +142,16 @@ module {
%1 = call @matmulCSR(%Acsr, %DA, %C0) : (tensor<8x8xf32, #CSR>,
tensor<8x8xf32>,
tensor<8x8xf32>) -> tensor<8x8xf32>
- %2 = call @matmulCOO(%Acoo, %DA, %C1) : (tensor<8x8xf32, #SortedCOO>,
+ %2 = call @matmulCSC(%Acsc, %DA, %C0) : (tensor<8x8xf32, #CSC>,
+ tensor<8x8xf32>,
+ tensor<8x8xf32>) -> tensor<8x8xf32>
+ %3 = call @matmulCOO(%Acoo, %DA, %C1) : (tensor<8x8xf32, #SortedCOO>,
+ tensor<8x8xf32>,
+ tensor<8x8xf32>) -> tensor<8x8xf32>
+ %4 = call @matmulCSR(%Acsr, %DA, %C1) : (tensor<8x8xf32, #CSR>,
tensor<8x8xf32>,
tensor<8x8xf32>) -> tensor<8x8xf32>
- %3 = call @matmulCSR(%Acsr, %DA, %C1) : (tensor<8x8xf32, #CSR>,
+ %5 = call @matmulCSC(%Acsc, %DA, %C1) : (tensor<8x8xf32, #CSC>,
tensor<8x8xf32>,
tensor<8x8xf32>) -> tensor<8x8xf32>
@@ -153,6 +176,24 @@ module {
// CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
// CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
//
+ // CHECK: ( 140, 168, 196, 224, 252, 280, 308, 336 )
+ // CHECK-NEXT: ( 168, 204, 240, 276, 312, 348, 384, 420 )
+ // CHECK-NEXT: ( 196, 240, 284, 328, 372, 416, 460, 504 )
+ // CHECK-NEXT: ( 224, 276, 328, 380, 432, 484, 536, 588 )
+ // CHECK-NEXT: ( 252, 312, 372, 432, 492, 552, 612, 672 )
+ // CHECK-NEXT: ( 280, 348, 416, 484, 552, 620, 688, 756 )
+ // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
+ // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
+ //
+ // CHECK: ( 141, 169, 197, 225, 253, 281, 309, 337 )
+ // CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
+ // CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
+ // CHECK-NEXT: ( 225, 277, 329, 381, 433, 485, 537, 589 )
+ // CHECK-NEXT: ( 253, 313, 373, 433, 493, 553, 613, 673 )
+ // CHECK-NEXT: ( 281, 349, 417, 485, 553, 621, 689, 757 )
+ // CHECK-NEXT: ( 309, 385, 461, 537, 613, 689, 765, 841 )
+ // CHECK-NEXT: ( 337, 421, 505, 589, 673, 757, 841, 925 )
+ //
// CHECK: ( 141, 169, 197, 225, 253, 281, 309, 337 )
// CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
// CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
@@ -175,10 +216,13 @@ module {
call @dump(%1) : (tensor<8x8xf32>) -> ()
call @dump(%2) : (tensor<8x8xf32>) -> ()
call @dump(%3) : (tensor<8x8xf32>) -> ()
+ call @dump(%4) : (tensor<8x8xf32>) -> ()
+ call @dump(%5) : (tensor<8x8xf32>) -> ()
// Release the resources.
bufferization.dealloc_tensor %Acoo : tensor<8x8xf32, #SortedCOO>
bufferization.dealloc_tensor %Acsr : tensor<8x8xf32, #CSR>
+ bufferization.dealloc_tensor %Acsc : tensor<8x8xf32, #CSC>
llvm.call @mgpuDestroySparseEnv(): () -> ()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
old mode 100644
new mode 100755
index eadc408fcc441b5..b569806b4028a8a
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
@@ -34,6 +34,12 @@
crdWidth = 32
}>
+#CSC = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d1 : dense, d0 : compressed),
+ posWidth = 64,
+ crdWidth = 64
+}>
+
module {
llvm.func @mgpuCreateSparseEnv()
llvm.func @mgpuDestroySparseEnv()
@@ -54,6 +60,14 @@ module {
return %y_out : tensor<?xf64>
}
+ // Compute matrix vector y = Ax on CSC with 64-bit positions and coordinates.
+ func.func @matvecCSC(%A: tensor<?x?xf64, #CSC>, %x: tensor<?xf64>, %y_in: tensor<?xf64>) -> tensor<?xf64> {
+ %y_out = linalg.matvec
+ ins(%A, %x: tensor<?x?xf64, #CSC>, tensor<?xf64>)
+ outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
+ return %y_out : tensor<?xf64>
+ }
+
func.func @main() {
llvm.call @mgpuCreateSparseEnv() : () -> ()
%f0 = arith.constant 0.0 : f64
@@ -73,6 +87,7 @@ module {
// Convert to a "sparse" m x n matrix A.
%Acoo = sparse_tensor.convert %DA : tensor<64x64xf64> to tensor<?x?xf64, #SortedCOO>
%Acsr = sparse_tensor.convert %DA : tensor<64x64xf64> to tensor<?x?xf64, #CSR>
+ %Acsc = sparse_tensor.convert %DA : tensor<64x64xf64> to tensor<?x?xf64, #CSC>
// Initialize dense vector with n elements:
// (1, 2, 3, 4, ..., n)
@@ -103,19 +118,25 @@ module {
%1 = call @matvecCSR(%Acsr, %x, %y0) : (tensor<?x?xf64, #CSR>,
tensor<?xf64>,
tensor<?xf64>) -> tensor<?xf64>
- %2 = call @matvecCOO(%Acoo, %x, %y1) : (tensor<?x?xf64, #SortedCOO>,
+ %2 = call @matvecCSC(%Acsc, %x, %y0) : (tensor<?x?xf64, #CSC>,
+ tensor<?xf64>,
+ tensor<?xf64>) -> tensor<?xf64>
+ %3 = call @matvecCOO(%Acoo, %x, %y1) : (tensor<?x?xf64, #SortedCOO>,
+ tensor<?xf64>,
+ tensor<?xf64>) -> tensor<?xf64>
+ %4 = call @matvecCSR(%Acsr, %x, %y1) : (tensor<?x?xf64, #CSR>,
tensor<?xf64>,
tensor<?xf64>) -> tensor<?xf64>
- %3 = call @matvecCSR(%Acsr, %x, %y1) : (tensor<?x?xf64, #CSR>,
+ %5 = call @matvecCSC(%Acsc, %x, %y1) : (tensor<?x?xf64, #CSC>,
tensor<?xf64>,
tensor<?xf64>) -> tensor<?xf64>
//
// Sanity check on the results.
//
- // CHECK-COUNT-2: ( 87360, 89440, 91520, 93600, 95680, 97760, 99840, 101920, 104000, 106080, 108160, 110240, 112320, 114400, 116480, 118560, 120640, 122720, 124800, 126880, 128960, 131040, 133120, 135200, 137280, 139360, 141440, 143520, 145600, 147680, 149760, 151840, 153920, 156000, 158080, 160160, 162240, 164320, 166400, 168480, 170560, 172640, 174720, 176800, 178880, 180960, 183040, 185120, 187200, 189280, 191360, 193440, 195520, 197600, 199680, 201760, 203840, 205920, 208000, 210080, 212160, 214240, 216320, 218400 )
+ // CHECK-COUNT-3: ( 87360, 89440, 91520, 93600, 95680, 97760, 99840, 101920, 104000, 106080, 108160, 110240, 112320, 114400, 116480, 118560, 120640, 122720, 124800, 126880, 128960, 131040, 133120, 135200, 137280, 139360, 141440, 143520, 145600, 147680, 149760, 151840, 153920, 156000, 158080, 160160, 162240, 164320, 166400, 168480, 170560, 172640, 174720, 176800, 178880, 180960, 183040, 185120, 187200, 189280, 191360, 193440, 195520, 197600, 199680, 201760, 203840, 205920, 208000, 210080, 212160, 214240, 216320, 218400 )
//
- // CHECK-COUNT-2: ( 87361, 89441, 91521, 93601, 95681, 97761, 99841, 101921, 104001, 106081, 108161, 110241, 112321, 114401, 116481, 118561, 120641, 122721, 124801, 126881, 128961, 131041, 133121, 135201, 137281, 139361, 141441, 143521, 145601, 147681, 149761, 151841, 153921, 156001, 158081, 160161, 162241, 164321, 166401, 168481, 170561, 172641, 174721, 176801, 178881, 180961, 183041, 185121, 187201, 189281, 191361, 193441, 195521, 197601, 199681, 201761, 203841, 205921, 208001, 210081, 212161, 214241, 216321, 218401 )
+ // CHECK-COUNT-3: ( 87361, 89441, 91521, 93601, 95681, 97761, 99841, 101921, 104001, 106081, 108161, 110241, 112321, 114401, 116481, 118561, 120641, 122721, 124801, 126881, 128961, 131041, 133121, 135201, 137281, 139361, 141441, 143521, 145601, 147681, 149761, 151841, 153921, 156001, 158081, 160161, 162241, 164321, 166401, 168481, 170561, 172641, 174721, 176801, 178881, 180961, 183041, 185121, 187201, 189281, 191361, 193441, 195521, 197601, 199681, 201761, 203841, 205921, 208001, 210081, 212161, 214241, 216321, 218401 )
//
%pb0 = vector.transfer_read %0[%c0], %f0 : tensor<?xf64>, vector<64xf64>
vector.print %pb0 : vector<64xf64>
@@ -125,10 +146,15 @@ module {
vector.print %pb2 : vector<64xf64>
%pb3 = vector.transfer_read %3[%c0], %f0 : tensor<?xf64>, vector<64xf64>
vector.print %pb3 : vector<64xf64>
+ %pb4 = vector.transfer_read %4[%c0], %f0 : tensor<?xf64>, vector<64xf64>
+ vector.print %pb4 : vector<64xf64>
+ %pb5 = vector.transfer_read %5[%c0], %f0 : tensor<?xf64>, vector<64xf64>
+ vector.print %pb5 : vector<64xf64>
// Release the resources.
bufferization.dealloc_tensor %Acoo : tensor<?x?xf64, #SortedCOO>
bufferization.dealloc_tensor %Acsr : tensor<?x?xf64, #CSR>
+ bufferization.dealloc_tensor %Acsc : tensor<?x?xf64, #CSC>
llvm.call @mgpuDestroySparseEnv() : () -> ()
return