[Mlir-commits] [mlir] 76a80a0 - [mlir][sparse][gpu] sparsifier GPU libgen for SpGEMM in cuSparse
Aart Bik
llvmlistbot at llvm.org
Thu Aug 10 14:52:25 PDT 2023
Author: Aart Bik
Date: 2023-08-10T14:52:16-07:00
New Revision: 76a80a080872350d70fc3b3d57b9db8bee54e1df
URL: https://github.com/llvm/llvm-project/commit/76a80a080872350d70fc3b3d57b9db8bee54e1df
DIFF: https://github.com/llvm/llvm-project/commit/76a80a080872350d70fc3b3d57b9db8bee54e1df.diff
LOG: [mlir][sparse][gpu] sparsifier GPU libgen for SpGEMM in cuSparse
With a working end-to-end integration test
Reviewed By: K-Wu
Differential Revision: https://reviews.llvm.org/D157652
Added:
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index f9c35d8b14d2e0..98a61b19fc55ea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -151,13 +151,25 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
token, dynamicSizes, ValueRange());
}
+// Allocates a typed buffer on the host with given size.
+static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
+ Value size) {
+ const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+ return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
+}
+
+// Allocates a typed buffer on the device with given size.
+static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
+ Value size, Value token) {
+ const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+ return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+ token, size, ValueRange());
+}
+
// Allocates a void buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
Value token) {
- const auto memTp =
- MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
- return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
- token, size, ValueRange());
+ return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}
/// Deallocates memory from the device.
@@ -198,7 +210,6 @@ static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
-/// TODO: the output assumption may be a bit too brittle
static Value genParametersIn(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers,
@@ -571,6 +582,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
token = genDeallocMemRef(rewriter, loc, vecY, token);
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
genHostUnregisterMemref(rewriter, loc, castR);
if (memC)
@@ -579,7 +591,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
genHostUnregisterMemref(rewriter, loc, castX);
genHostUnregisterMemref(rewriter, loc, castY);
}
- tokens.clear();
// Done.
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
@@ -630,7 +641,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
castB = genHostRegisterMemref(rewriter, loc, bufB);
castBufC = genHostRegisterMemref(rewriter, loc, bufC);
}
-
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -702,6 +712,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
token = genDeallocMemRef(rewriter, loc, matC, token);
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
genHostUnregisterMemref(rewriter, loc, castR);
if (memC)
@@ -710,14 +721,179 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
genHostUnregisterMemref(rewriter, loc, castB);
genHostUnregisterMemref(rewriter, loc, castC);
}
- tokens.clear();
// Done.
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
return success();
}
-// Match and rewrite 2:4 SpMM kernels.
+// Match and rewrite SpGEMM kernel.
+static LogicalResult
+rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+ GPUDataTransferStrategy gpuDataTransferStrategy) {
+ Location loc = op.getLoc();
+ Value a = op.getOperand(0);
+ Value b = op.getOperand(1);
+ Value c = op.getOperand(2); // we have C = AB
+ SmallVector<Value> tokens;
+
+ // Only CSR <- CSR x CSR supported.
+ bool isCOO = false;
+ SparseTensorType aTp = getSparseTensorType(a);
+ SparseTensorType bTp = getSparseTensorType(b);
+ SparseTensorType cTp = getSparseTensorType(c);
+ if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
+ return failure();
+
+ // Start sparse kernel and copy data from host to device.
+ // a : amemR/amemC/amemV -> rowA,colA,valA
+ // b : bmemR/bmemC/bmemV -> rowB,colB,valB
+ // c : materializes
+ auto dnCType = cTp.getElementType();
+ Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
+ Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
+ Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+ Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+ Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+ Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+ Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+ Value amemV = genToValues(rewriter, loc, a);
+ Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
+ Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+ Value bmemV = genToValues(rewriter, loc, b);
+ Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
+ Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
+ Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
+ Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
+ Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
+ Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
+ genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
+
+ // Create sparse environment and sparse matrix handles.
+ Type indexTp = rewriter.getIndexType();
+ Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+ Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
+ Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+ Value token = genFirstWait(rewriter, loc);
+ Operation *spGenA =
+ genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
+ rowA, colA, valA, isCOO, enableRT);
+ Value spMatA = spGenA->getResult(0);
+ token = spGenA->getResult(1);
+ Operation *spGenB =
+ genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
+ rowB, colB, valB, isCOO, enableRT);
+ Value spMatB = spGenB->getResult(0);
+ token = spGenB->getResult(1);
+
+ // Sparse matrix C materializes (also assumes beta == 0).
+ Value zero = constantIndex(rewriter, loc, 0);
+ Value one = constantIndex(rewriter, loc, 1);
+ Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
+ auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
+ Value rowC = e1.getResult(0);
+ token = e1.getAsyncToken();
+ auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
+ Value colC = e2.getResult(0);
+ token = e2.getAsyncToken();
+ auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
+ Value valC = e3.getResult(0);
+ token = e3.getAsyncToken();
+ Operation *spGenC =
+ genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
+ rowC, colC, valC, isCOO, enableRT);
+ Value spMatC = spGenC->getResult(0);
+ token = spGenC->getResult(1);
+
+ // Precompute buffersizes for SpGEMM.
+ Operation *descOp =
+ rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
+ Value desc = descOp->getResult(0);
+ token = descOp->getResult(1);
+ Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+ loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+ valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+ Value bufferSz1 = work1->getResult(0);
+ token = work1->getResult(1);
+ auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
+ Value buffer1 = buf1.getResult(0);
+ token = buf1.getAsyncToken();
+ Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+ loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+ bufferSz1, buffer1,
+ gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+ token = work2->getResult(1);
+
+ // Compute step.
+ Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+ loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+ valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+ Value bufferSz2 = compute1->getResult(0);
+ token = compute1->getResult(1);
+ auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
+ Value buffer2 = buf2.getResult(0);
+ token = buf2.getAsyncToken();
+ Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+ loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+ bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+ token = compute2->getResult(1);
+
+ // Get sizes.
+ Operation *sizes = rewriter.create<gpu::SpGEMMGetSizeOp>(
+ loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
+ Value nnz = sizes->getResult(2);
+ token = sizes->getResult(3);
+ auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
+ colC = a2.getResult(0);
+ token = a2.getAsyncToken();
+ auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
+ valC = a3.getResult(0);
+ token = a3.getAsyncToken();
+
+ // Update C with new pointers and copy final product back into C.
+ Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
+ loc, tokenTp, token, spMatC, rowC, colC, valC);
+ token = update->getResult(0);
+ Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
+ loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
+ token = copy->getResult(0);
+
+ // Allocate buffers on host.
+ Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
+ Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
+ Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
+
+ // Copy data back to host and free all the resources.
+ token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
+ .getAsyncToken();
+ token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+ .getAsyncToken();
+ token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
+ .getAsyncToken();
+ token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
+ .getAsyncToken();
+ token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
+ token = genCopyMemRef(rewriter, loc, colH, colC, token);
+ token = genCopyMemRef(rewriter, loc, valH, valC, token);
+ tokens.push_back(token);
+ genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
+
+ // Done.
+ Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
+ Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
+ Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
+ rewriter.replaceOpWithNewOp<PackOp>(op, c.getType(), vt, ValueRange{rt, ct});
+ return success();
+}
+
+// Match and rewrite 2:4 SpMM kernel.
static LogicalResult
rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
GPUDataTransferStrategy gpuDataTransferStrategy) {
@@ -748,7 +924,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
castB = genHostRegisterMemref(rewriter, loc, bufB);
castC = genHostRegisterMemref(rewriter, loc, bufC);
}
-
if (isZeroCopy) {
matA = bufA;
matB = bufB;
@@ -756,10 +931,11 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
+
+ // Create sparse environment and sparse matrix/dense vector handles.
Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
-
Type indexTp = rewriter.getIndexType();
Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
@@ -768,7 +944,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
loc, spMatHandleTp, tokenTp, token, szm, szk,
gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
-
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -781,7 +956,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
SmallVector<Value>{szm, szn});
Value dnC = dmatC.getResult(0);
token = dmatC.getAsyncToken();
-
auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
// Precompute buffersize for SpMM.
@@ -791,8 +965,8 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
/*computeType=*/dmatCType);
-
token = bufferComp.getAsyncToken();
+
Value bufferSz = bufferComp.getResult(0);
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
Value buffer = buf.getResult(0);
@@ -824,11 +998,9 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
.getAsyncToken();
SmallVector<Value> newDynamicSizes;
-
token = genDeallocMemRef(rewriter, loc, buffer, token);
token = genDeallocMemRef(rewriter, loc, buffer2, token);
token = genDeallocMemRef(rewriter, loc, buffer3, token);
-
if (!isZeroCopy)
token = genDeallocMemRef(rewriter, loc, matA, token);
if (!isZeroCopy)
@@ -837,12 +1009,14 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
token = genDeallocMemRef(rewriter, loc, matC, token);
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
genHostUnregisterMemref(rewriter, loc, castA);
genHostUnregisterMemref(rewriter, loc, castB);
genHostUnregisterMemref(rewriter, loc, castC);
}
- tokens.clear();
+
+ // Done.
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
return success();
}
@@ -889,7 +1063,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
Value memV = genToValues(rewriter, loc, c);
-
Value castB, castA, castR, castC, castV;
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
castB = genHostRegisterMemref(rewriter, loc, bufB);
@@ -899,7 +1072,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
castC = genHostRegisterMemref(rewriter, loc, memC);
castV = genHostRegisterMemref(rewriter, loc, memV);
}
-
if (isZeroCopy) {
matA = bufA;
matB = bufB;
@@ -930,8 +1102,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
rowC, colC, valC, isCOO, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
-
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
+
// Precompute buffersize for SDDMM.
auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
@@ -965,6 +1137,7 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
token = genDeallocMemRef(rewriter, loc, valC, token);
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
genHostUnregisterMemref(rewriter, loc, castB);
genHostUnregisterMemref(rewriter, loc, castA);
@@ -973,7 +1146,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
genHostUnregisterMemref(rewriter, loc, castC);
genHostUnregisterMemref(rewriter, loc, castV);
}
- tokens.clear();
// Done.
rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
@@ -986,7 +1158,7 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparse compiler.
-/// TODO: right works with parallelization-strategy=dense-outer-loop
+/// TODO: right now works with parallelization-strategy=dense-outer-loop
/// but give this its own flags in the future
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
@@ -1109,29 +1281,27 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
AffineExpr i, j, k;
bindDims(getContext(), i, j, k);
- // TODO: more robust patterns, tranposed versions, more kernels...
- // TODO: identify alpha and beta and pass them to the CUDA calls
+ // TODO: more robust patterns, transposed versions, more kernels,
+ // identify alpha and beta and pass them to the CUDA calls.
// Recognize a SpMV kernel.
if (numLoops == 2 && numTensors == 3 &&
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isReductionIterator(iteratorTypes[1]) &&
- // TODO: add transposed {i, j}
maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
}
- // Recognize a SpMM kernel.
+ // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
if (numLoops == 3 && numTensors == 3 &&
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isParallelIterator(iteratorTypes[1]) &&
linalg::isReductionIterator(iteratorTypes[2]) &&
- // TODO: add transposed {i, k}, {k, j}
- // TODO: maybe add transposed {i, j} in future
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+ if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
+ return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
if (op->getAttr("DENSE24"))
return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
-
return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
}
@@ -1140,8 +1310,6 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isParallelIterator(iteratorTypes[1]) &&
linalg::isReductionIterator(iteratorTypes[2]) &&
- // TODO: add transposed {i, k}, {k, j}
- // TODO: maybe add transposed {i, j} in future
maps == infer({{i, k}, {k, j}, {i, j}}) &&
matchSumReductionOfMulUnary(op)) {
return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
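
For reference, the sequence of gpu dialect ops built by rewriteSpGEMM above (SpGEMMCreateDescrOp, SpGEMMWorkEstimationOrComputeOp, SpGEMMGetSizeOp, SetCsrPointersOp, SpGEMMCopyOp, SpGEMMDestroyDescrOp) mirrors cuSPARSE's two-phase generic SpGEMM API. The sketch below is not part of this patch: it is a hypothetical host-side helper written directly against the public cuSPARSE API, assuming f32 values and 32-bit indices as in the test, that shows the call order the generated ops ultimately correspond to. Descriptor construction for A and B, stream handling, and error checking are omitted.

// Hypothetical host-side helper (not part of this patch): the cuSPARSE call
// order that the gpu ops emitted by rewriteSpGEMM map to.
#include <cusparse.h>
#include <cuda_runtime.h>
#include <cstdint>

// Computes C = A * B for CSR matrices with f32 values and 32-bit indices.
// dRowPtrC must be a device buffer of m+1 entries; the column and value
// buffers for C are allocated here once the nonzero count is known.
void spgemmCsrF32(cusparseHandle_t handle, cusparseSpMatDescr_t matA,
                  cusparseSpMatDescr_t matB, cusparseSpMatDescr_t matC,
                  int32_t *dRowPtrC, int32_t **dColIndC, float **dValC,
                  int64_t *nnzC) {
  const float alpha = 1.0f, beta = 0.0f; // rewriteSpGEMM assumes beta == 0.
  const cusparseOperation_t op = CUSPARSE_OPERATION_NON_TRANSPOSE;

  cusparseSpGEMMDescr_t desc;
  cusparseSpGEMM_createDescr(&desc); // SpGEMMCreateDescrOp

  // Work estimation: the first call queries the buffer size, the second call
  // runs with the buffer (the two WORK_ESTIMATION ops above).
  size_t bufSz1 = 0;
  void *buf1 = nullptr;
  cusparseSpGEMM_workEstimation(handle, op, op, &alpha, matA, matB, &beta,
                                matC, CUDA_R_32F, CUSPARSE_SPGEMM_DEFAULT,
                                desc, &bufSz1, nullptr);
  cudaMalloc(&buf1, bufSz1);
  cusparseSpGEMM_workEstimation(handle, op, op, &alpha, matA, matB, &beta,
                                matC, CUDA_R_32F, CUSPARSE_SPGEMM_DEFAULT,
                                desc, &bufSz1, buf1);

  // Compute: same size-query/run pattern (the two COMPUTE ops above).
  size_t bufSz2 = 0;
  void *buf2 = nullptr;
  cusparseSpGEMM_compute(handle, op, op, &alpha, matA, matB, &beta, matC,
                         CUDA_R_32F, CUSPARSE_SPGEMM_DEFAULT, desc, &bufSz2,
                         nullptr);
  cudaMalloc(&buf2, bufSz2);
  cusparseSpGEMM_compute(handle, op, op, &alpha, matA, matB, &beta, matC,
                         CUDA_R_32F, CUSPARSE_SPGEMM_DEFAULT, desc, &bufSz2,
                         buf2);

  // Query the size of C, allocate its coordinate/value buffers, and install
  // them in the descriptor (SpGEMMGetSizeOp + SetCsrPointersOp).
  int64_t rows = 0, cols = 0;
  cusparseSpMatGetSize(matC, &rows, &cols, nnzC);
  cudaMalloc(reinterpret_cast<void **>(dColIndC), *nnzC * sizeof(int32_t));
  cudaMalloc(reinterpret_cast<void **>(dValC), *nnzC * sizeof(float));
  cusparseCsrSetPointers(matC, dRowPtrC, *dColIndC, *dValC);

  // Copy the final product into C (SpGEMMCopyOp) and clean up.
  cusparseSpGEMM_copy(handle, op, op, &alpha, matA, matB, &beta, matC,
                      CUDA_R_32F, CUSPARSE_SPGEMM_DEFAULT, desc);
  cusparseSpGEMM_destroyDescr(desc); // SpGEMMDestroyDescrOp
  cudaFree(buf1);
  cudaFree(buf2);
}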
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
new file mode 100644
index 00000000000000..a39fdd8dc0ac6a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
@@ -0,0 +1,81 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// without RT lib:
+//
+// RUN: mlir-opt %s \
+// RUN: --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+ lvlTypes = [ "dense", "compressed" ],
+ posWidth = 32,
+ crdWidth = 32
+}>
+
+module {
+ llvm.func @mgpuCreateSparseEnv()
+ llvm.func @mgpuDestroySparseEnv()
+
+ // Computes C = A x B with A,B,C sparse CSR.
+ func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
+ %B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
+ %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
+ %C = linalg.matmul
+ ins(%A, %B: tensor<8x8xf32, #CSR>,
+ tensor<8x8xf32, #CSR>)
+ outs(%init: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+ return %C: tensor<8x8xf32, #CSR>
+ }
+
+ //
+ // Main driver.
+ //
+ func.func @main() {
+ llvm.call @mgpuCreateSparseEnv(): () -> ()
+
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ %t = arith.constant dense<[
+ [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [ 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [ 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0],
+ [ 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0],
+ [ 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 9.0],
+ [ 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 11.0, 12.0],
+ [ 0.0, 13.0, 14.0, 0.0, 0.0, 0.0, 15.0, 16.0]
+ ]> : tensor<8x8xf32>
+ %Acsr = sparse_tensor.convert %t : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+
+ %Ccsr = call @matmulCSR(%Acsr, %Acsr) : (tensor<8x8xf32, #CSR>,
+ tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+
+ //
+ // Verify computed result (expected output, with only 20 nonzeros).
+ //
+ // CHECK: ( ( 1, 39, 52, 0, 0, 0, 45, 51 ),
+ // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 16, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 0, 25, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 0, 0, 36, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 117, 158, 0, 0, 0, 135, 144 ),
+ // CHECK-SAME: ( 0, 156, 318, 0, 0, 0, 301, 324 ),
+ // CHECK-SAME: ( 0, 208, 430, 0, 0, 0, 405, 436 ) )
+ // CHECK-NEXT: 20
+ %d = sparse_tensor.convert %Ccsr : tensor<8x8xf32, #CSR> to tensor<8x8xf32>
+ %v = vector.transfer_read %d[%c0, %c0], %f0: tensor<8x8xf32>, vector<8x8xf32>
+ vector.print %v : vector<8x8xf32>
+ %nnz = sparse_tensor.number_of_entries %Ccsr : tensor<8x8xf32, #CSR>
+ vector.print %nnz : index
+
+ llvm.call @mgpuDestroySparseEnv(): () -> ()
+ return
+ }
+}
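
As a sanity check on the FileCheck expectations above: the result is simply the dense product of the test matrix with itself, and the 20 stored entries follow directly from that product. The standalone reference computation below is not part of this patch; it only recomputes the CHECK values and the nonzero count on the host.

// Standalone reference check (not part of this patch): recomputes C = A x A
// for the 8x8 test matrix above and counts the nonzeros (expected: 20).
#include <cstdio>

int main() {
  const float a[8][8] = {
      {1, 0, 2, 0, 0, 0, 0, 3},    {0, 0, 0, 0, 0, 0, 0, 0},
      {0, 0, 4, 0, 0, 0, 0, 0},    {0, 0, 0, 5, 0, 0, 0, 0},
      {0, 0, 0, 0, 6, 0, 0, 0},    {0, 7, 8, 0, 0, 0, 0, 9},
      {0, 0, 10, 0, 0, 0, 11, 12}, {0, 13, 14, 0, 0, 0, 15, 16}};
  int nnz = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      float c = 0;
      for (int k = 0; k < 8; ++k)
        c += a[i][k] * a[k][j];
      if (c != 0)
        ++nnz;
      std::printf("%6.0f%c", c, j == 7 ? '\n' : ' ');
    }
  }
  std::printf("nnz = %d\n", nnz); // prints 20, matching CHECK-NEXT: 20
  return 0;
}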