[Mlir-commits] [mlir] ee42e23 - [mlir][sparse][gpu] first implementation of the GPU libgen approach

Aart Bik llvmlistbot at llvm.org
Mon May 15 08:49:48 PDT 2023


Author: Aart Bik
Date: 2023-05-15T08:49:38-07:00
New Revision: ee42e23614c789088a1528d41926d47c94e8ccdf

URL: https://github.com/llvm/llvm-project/commit/ee42e23614c789088a1528d41926d47c94e8ccdf
DIFF: https://github.com/llvm/llvm-project/commit/ee42e23614c789088a1528d41926d47c94e8ccdf.diff

LOG: [mlir][sparse][gpu] first implementation of the GPU libgen approach

The sparse compiler now has two prototype strategies for GPU acceleration:

* CUDA codegen: this converts sparsified code to CUDA threads
* CUDA libgen: this converts pre-sparsified code to cuSPARSE library calls

This revision introduces the first steps required for the second approach.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D150170

Added: 
    mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 8da020c09c1e5..febb0113cf993 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -130,9 +130,16 @@ struct SparseCompilerOptions
   PassOptions::Option<std::string> gpuFeatures{*this, "gpu-features",
                                                desc("GPU target features")};
 
+  /// This option is used to enable GPU library generation.
+  PassOptions::Option<bool> enableGPULibgen{
+      *this, "enable-gpu-libgen",
+      desc("Enables GPU acceleration by means of direct library calls (like "
+           "cuSPARSE)")};
+
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableIndexReduction);
+    return SparsificationOptions(parallelization, enableIndexReduction,
+                                 enableGPULibgen, enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createSparseTensorConversionPass`.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c69dfb77f6cbe..c2942cf7be0b4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -49,12 +49,17 @@ enum class SparseParallelizationStrategy {
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc)
-      : parallelizationStrategy(p), enableIndexReduction(idxReduc) {}
+  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
+                        bool gpuLibgen, bool enableRT)
+      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
+        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone, false) {}
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
+                              false, true) {}
   SparseParallelizationStrategy parallelizationStrategy;
   bool enableIndexReduction;
+  bool enableGPULibgen;
+  bool enableRuntimeLibrary;
 };
 
 /// Sets up sparsification rewriting rules with the given options.
@@ -206,6 +211,9 @@ std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                       unsigned numThreads);
 
+void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                     bool enableRT);
+
 std::unique_ptr<Pass> createSparseGPUCodegenPass();
 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3ea68f9da6700..962399931d933 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -73,6 +73,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
     "affine::AffineDialect",
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
+    "gpu::GPUDialect",
     "LLVM::LLVMDialect",
     "linalg::LinalgDialect",
     "memref::MemRefDialect",
@@ -100,7 +101,12 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
                         "Enable dense parallelization for any loop."),
              clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                         "any-storage-any-loop",
-                        "Enable sparse parallelization for any storage and loop."))}]>
+                        "Enable sparse parallelization for any storage and loop."))}]>,
+    Option<"enableGPULibgen", "enable-gpu-libgen", "bool",
+           "false",
+           "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
+    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+           "true", "Enable runtime library for manipulating sparse tensors">,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index bb52e08686fe5..ac199e02d95d6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -1,4 +1,4 @@
-//===- SparseGPUCodegen.cpp - Generates GPU code (using CUDA) -------------===//
+//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,9 +18,12 @@
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -140,8 +143,7 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
   SmallVector<Value> dynamicSizes;
   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
     if (shape[r] == ShapedType::kDynamic) {
-      Value dim = constantIndex(builder, loc, r);
-      Value dimOp = builder.create<memref::DimOp>(loc, mem, dim);
+      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
       dynamicSizes.push_back(dimOp);
     }
   }
@@ -149,6 +151,15 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                       token, dynamicSizes, ValueRange());
 }
 
+/// Allocates an i8 device buffer of `size` elements, ordered after `token`.
+static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
+                                   Value token) {
+  const auto memTp =
+      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
+  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+                                      token, size, ValueRange());
+}
+
 /// Deallocates memory from the device.
 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                               Value token) {
@@ -163,6 +174,26 @@ static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
       .getAsyncToken();
 }
 
+/// Returns a device copy of `b`; the copy's async token is pushed to `tokens`.
+static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
+                          SmallVectorImpl<Value> &tokens) {
+  Value firstToken = genFirstWait(builder, loc);
+  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
+  Value devMem = alloc.getResult(0);
+  Value depToken = alloc.getAsyncToken(); // copy waits on alloc
+  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
+  return devMem;
+}
+
+/// Emits bufferization.to_memref of `tensor` with the same shape/element type.
+static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
+                               Value tensor) {
+  auto tensorType = tensor.getType().cast<ShapedType>();
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
+}
+
 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
 /// assume that the first buffer is the one allocated for output. We create
 /// a set of properly chained asynchronous allocation/copy pairs to increase
@@ -186,12 +217,7 @@ static Value genParametersIn(OpBuilder &builder, Location loc,
       useHostRegistrationForOut = false;
       continue;
     }
-    Value firstToken = genFirstWait(builder, loc);
-    auto alloc = genAllocMemRef(builder, loc, b, firstToken);
-    Value devMem = alloc.getResult(0);
-    Value depToken = alloc.getAsyncToken(); // copy-after-alloc
-    args.push_back(devMem);
-    tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
+    args.push_back(genAllocCopy(builder, loc, b, tokens));
   }
   return out;
 }
@@ -272,10 +298,216 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
 }
 
 //===----------------------------------------------------------------------===//
-// Rewriting rules.
+// Library helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Detects val = a * b (mulf/muli, either order) with a, b block args 0, 1.
+static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
+  if (auto *def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
+      Value a = op.getBlock()->getArguments()[0];
+      Value b = op.getBlock()->getArguments()[1];
+      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
+             (def->getOperand(0) == b && def->getOperand(1) == a);
+    }
+  }
+  return false;
+}
+
+/// Detects a yield of x + a * b (addf/addi, either order), x = block arg 2.
+static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
+  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
+    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+      Value x = op.getBlock()->getArguments()[2];
+      return (def->getOperand(0) == x &&
+              matchMulOfArgs(op, def->getOperand(1))) ||
+             (def->getOperand(1) == x &&
+              matchMulOfArgs(op, def->getOperand(0)));
+    }
+  }
+  return false;
+}
+
+/// Sorted COO (f32/f64 vals, index/i32/i64 crds); NOTE: dim order unchecked.
+static bool isAdmissibleCOO(SparseTensorType &aTp) {
+  return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
+         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
+         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
+          aTp.getCrdWidth() == 64);
+}
+
+/// CSR test (f32/f64 vals, index/i32/i64 crds); NOTE: dim order unchecked.
+static bool isAdmissibleCSR(SparseTensorType &aTp) {
+  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
+         aTp.isUniqueLvl(1) &&
+         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
+          aTp.getCrdWidth() == 64);
+}
+
+/// Returns row data of `a`: COO coordinates (level 0 or AoS) or CSR positions.
+static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
+                               bool isCOO, bool enableRT) {
+  if (isCOO) {
+    // Library uses SoA COO, direct IR uses AoS COO.
+    if (enableRT)
+      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
+    return genToCoordinatesBuffer(builder, loc, a);
+  }
+  // CSR uses the level-1 positions as row pointers.
+  return genToPositions(builder, loc, a, 1);
+}
+
+/// Returns level-1 coordinates of `a`; null Value for AoS COO (no runtime).
+static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
+                           bool isCOO, bool enableRT) {
+  if (isCOO && !enableRT)
+    return Value(); // AoS buffer holds both coordinates
+  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/0);
+}
+
+/// Creates a COO or CSR sparse matrix handle op; results: handle, async token.
+static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
+                           Type tokenTp, Value token, Value szY, Value szX,
+                           Value nnzA, Value rowA, Value colA, Value valA,
+                           bool isCOO, bool enableRT) {
+  if (isCOO) {
+    // Library uses SoA COO, direct IR uses AoS COO.
+    if (enableRT)
+      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
+                                              szY, szX, nnzA, rowA, colA, valA);
+    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
+  }
+  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, szY,
+                                          szX, nnzA, rowA, colA, valA);
+}
+
+/// Rewrites SpMV (y = Ax) into gpu sparse ops; fails on unsupported operands.
+static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value x = op.getOperand(1);
+  Value y = op.getOperand(2); // we have y = Ax
+  SmallVector<Value> tokens;
+
+  // Only admissible sparse matrix format and dense vectors for now.
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType xTp = getSparseTensorType(x);
+  SparseTensorType yTp = getSparseTensorType(y);
+  if (xTp.hasEncoding() || yTp.hasEncoding())
+    return failure();
+  if (isAdmissibleCOO(aTp)) {
+    isCOO = true;
+    // TODO: CreateCooAoSOp was deprecated, find another way
+    if (!enableRT)
+      return failure();
+  } else if (isAdmissibleCSR(aTp)) {
+    isCOO = false;
+  } else {
+    return failure();
+  }
+
+  // Start sparse kernel and copy data from host to device.
+  //   a : memR/memC/memV -> rowA,colA,valA
+  //   x : memX           -> vecX
+  //   y : memY           -> vecY
+  Value nnzA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memV = genToValues(rewriter, loc, a);
+  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
+  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
+  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
+  Value memX = genTensorToMemref(rewriter, loc, x);
+  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
+  Value memY = genTensorToMemref(rewriter, loc, y);
+  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense vector handles.
+  Type indexTp = rewriter.getIndexType();
+  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  auto env =
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+  Value handle = env.getResult(0);
+  token = env.getAsyncToken();
+  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
+                               szX, nnzA, rowA, colA, valA, isCOO, enableRT);
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+                                                   token, vecX, szX);
+  Value dnX = dvecX.getResult(0);
+  token = dvecX.getAsyncToken();
+  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+                                                   token, vecY, szY);
+  Value dnY = dvecY.getResult(0);
+  token = dvecY.getAsyncToken();
+
+  // Query the SpMV workspace size and allocate that device buffer.
+  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
+      loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
+  Value bufferSz = bufferComp.getResult(0);
+  token = bufferComp.getAsyncToken();
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  // Perform the SpMV.
+  auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
+                                               spMatA, dnX, dnY, buffer);
+  token = spmvComp.getAsyncToken();
+
+  // Destroy handles, then copy y back to host and free all device resources.
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnX)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnY)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
+              .getAsyncToken();
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  token = genFirstWait(rewriter, loc);
+  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
+  token = genDeallocMemRef(rewriter, loc, rowA, token);
+  if (colA)
+    token = genDeallocMemRef(rewriter, loc, colA, token);
+  token = genDeallocMemRef(rewriter, loc, valA, token);
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, vecX, token);
+  token = genDeallocMemRef(rewriter, loc, vecY, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Done: replace the op by its output (init) operand.
+  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+  return success();
+}
+
+/// Match and rewrite SpMM kernel (not implemented yet: always fails).
+static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
+  return failure(); // TODO: implement
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for direct code generation.
 //===----------------------------------------------------------------------===//
 
-/// Proof-of-concept rewriter. This rule generates a CUDA implementation
+/// Proof-of-concept rewriter. This rule generates a GPU implementation
 /// for each outermost forall loop generated by the sparse compiler.
 /// TODO: right works with parallelization-strategy=dense-outer-loop
 ///       but give this its own flags in the future
@@ -373,13 +605,77 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
   unsigned numThreads;
 };
 
+//===----------------------------------------------------------------------===//
+// Rewriting rules for library recognition and code generation.
+//===----------------------------------------------------------------------===//
+
+/// Proof-of-concept rewriter. This rule recognizes certain math kernels
+/// and replaces these with corresponding calls into the sparse library.
+struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LinalgOpRewriter(MLIRContext *context, bool rt)
+      : OpRewritePattern(context), enableRT(rt) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumDpsInits() != 1)
+      return failure(); // reject multi-output
+
+    const unsigned numLoops = op.getNumLoops();
+    const unsigned numTensors = op->getNumOperands();
+    const auto iteratorTypes = op.getIteratorTypesArray();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr i, j, k;
+    bindDims(getContext(), i, j, k);
+
+    // TODO: more robust patterns, transposed versions, more kernels...
+
+    // Recognize a SpMV kernel: y(i) += A(i,j) * x(j).
+    if (numLoops == 2 && numTensors == 3 &&
+        linalg::isParallelIterator(iteratorTypes[0]) &&
+        linalg::isReductionIterator(iteratorTypes[1]) &&
+        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
+      return rewriteSpMV(rewriter, op, enableRT);
+    }
+
+    // Recognize a SpMM kernel: C(i,j) += A(i,k) * B(k,j).
+    if (numLoops == 3 && numTensors == 3 &&
+        linalg::isParallelIterator(iteratorTypes[0]) &&
+        linalg::isParallelIterator(iteratorTypes[1]) &&
+        linalg::isReductionIterator(iteratorTypes[2]) &&
+        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+      return rewriteSpMM(rewriter, op, enableRT);
+    }
+
+    return failure();
+  }
+
+private:
+  bool enableRT;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
 // Public method for populating GPU rewriting rules.
+//
+// Currently two sets of rewriting rules are made available. The first set
+// implements direct code generation, currently by means of converting the
+// outermost parallel loop into GPU threads. The second set implements
+// library recognition of a set of sparse operations. Eventually, the right
+// combination of these two approaches has to be found.
 //===----------------------------------------------------------------------===//
 
 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                             unsigned numThreads) {
   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
 }
+
+void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                           bool enableRT) {
+  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index cd56fbd5099dc..f59e6ed346c13 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -66,14 +66,20 @@ struct SparsificationPass
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
     enableIndexReduction = options.enableIndexReduction;
+    enableGPULibgen = options.enableGPULibgen;
+    enableRuntimeLibrary = options.enableRuntimeLibrary;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, enableIndexReduction);
-    // Apply sparsification and cleanup rewriting.
+    SparsificationOptions options(parallelization, enableIndexReduction,
+                                  enableGPULibgen, enableRuntimeLibrary);
+    // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
     RewritePatternSet patterns(ctx);
+    if (enableGPULibgen) {
+      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
+    }
     populateSparsificationPatterns(patterns, options);
     scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index be1846ea8da2c..fbbf6c18316fd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -98,6 +99,7 @@ class SparsificationAndBufferizationPass
 
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect>();
+    registry.insert<gpu::GPUDialect>();
     registry.insert<LLVM::LLVMDialect>();
   }
 

diff  --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir
new file mode 100644
index 0000000000000..517d52926b74e
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+module {
+
+// CHECK-LABEL:   func.func @matvec(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<?xf64>) -> tensor<?xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_11:.*]] = gpu.wait async
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_8]], %[[VAL_3]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_13:.*]], %[[VAL_14:.*]] = gpu.alloc async {{\[}}%[[VAL_11]]] (%[[VAL_12]]) : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = gpu.memcpy async {{\[}}%[[VAL_14]]] %[[VAL_13]], %[[VAL_8]] : memref<?xindex>, memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_16:.*]] = gpu.wait async
+// CHECK:           %[[VAL_17:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_18:.*]], %[[VAL_19:.*]] = gpu.alloc async {{\[}}%[[VAL_16]]] (%[[VAL_17]]) : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = gpu.memcpy async {{\[}}%[[VAL_19]]] %[[VAL_18]], %[[VAL_9]] : memref<?xindex>, memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_21:.*]] = gpu.wait async
+// CHECK:           %[[VAL_22:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_21]]] (%[[VAL_22]]) : memref<?xf64>
+// CHECK:           %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_10]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_26:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?xf64>
+// CHECK:           %[[VAL_27:.*]] = gpu.wait async
+// CHECK:           %[[VAL_28:.*]] = memref.dim %[[VAL_26]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_29:.*]], %[[VAL_30:.*]] = gpu.alloc async {{\[}}%[[VAL_27]]] (%[[VAL_28]]) : memref<?xf64>
+// CHECK:           %[[VAL_31:.*]] = gpu.memcpy async {{\[}}%[[VAL_30]]] %[[VAL_29]], %[[VAL_26]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_32:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?xf64>
+// CHECK:           %[[VAL_33:.*]] = gpu.wait async
+// CHECK:           %[[VAL_34:.*]] = memref.dim %[[VAL_32]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_35:.*]], %[[VAL_36:.*]] = gpu.alloc async {{\[}}%[[VAL_33]]] (%[[VAL_34]]) : memref<?xf64>
+// CHECK:           %[[VAL_37:.*]] = gpu.memcpy async {{\[}}%[[VAL_36]]] %[[VAL_35]], %[[VAL_32]] : memref<?xf64>, memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]], %[[VAL_31]], %[[VAL_37]]]
+// CHECK:           %[[VAL_38:.*]] = gpu.wait async
+// CHECK:           %[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_38]]]
+// CHECK:           %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_coo async {{\[}}%[[VAL_40]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_13]], %[[VAL_18]], %[[VAL_23]] : memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.create_dn_vec async {{\[}}%[[VAL_42]]] %[[VAL_29]], %[[VAL_7]] : memref<?xf64>
+// CHECK:           %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.create_dn_vec async {{\[}}%[[VAL_44]]] %[[VAL_35]], %[[VAL_6]] : memref<?xf64>
+// CHECK:           %[[VAL_47:.*]], %[[VAL_48:.*]] = gpu.spmv_buffer_size async {{\[}}%[[VAL_46]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]], %[[VAL_45]]
+// CHECK:           %[[VAL_49:.*]], %[[VAL_50:.*]] = gpu.alloc async {{\[}}%[[VAL_48]]] (%[[VAL_47]]) : memref<?xi8>
+// CHECK:           %[[VAL_51:.*]] = gpu.spmv async {{\[}}%[[VAL_50]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]], %[[VAL_45]], %[[VAL_49]] : memref<?xi8>
+// CHECK:           %[[VAL_52:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_51]]] %[[VAL_41]]
+// CHECK:           %[[VAL_53:.*]] = gpu.destroy_dn_vec async {{\[}}%[[VAL_52]]] %[[VAL_43]]
+// CHECK:           %[[VAL_54:.*]] = gpu.destroy_dn_vec async {{\[}}%[[VAL_53]]] %[[VAL_45]]
+// CHECK:           %[[VAL_55:.*]] = gpu.destroy_sparse_env async {{\[}}%[[VAL_54]]] %[[VAL_39]]
+// CHECK:           gpu.wait {{\[}}%[[VAL_55]]]
+// CHECK:           %[[VAL_56:.*]] = gpu.wait async
+// CHECK:           %[[VAL_57:.*]] = gpu.memcpy async {{\[}}%[[VAL_56]]] %[[VAL_32]], %[[VAL_35]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_13]] : memref<?xindex>
+// CHECK:           %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_58]]] %[[VAL_18]] : memref<?xindex>
+// CHECK:           %[[VAL_60:.*]] = gpu.dealloc async {{\[}}%[[VAL_59]]] %[[VAL_23]] : memref<?xf64>
+// CHECK:           %[[VAL_61:.*]] = gpu.dealloc async {{\[}}%[[VAL_60]]] %[[VAL_49]] : memref<?xi8>
+// CHECK:           %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_29]] : memref<?xf64>
+// CHECK:           %[[VAL_63:.*]] = gpu.dealloc async {{\[}}%[[VAL_62]]] %[[VAL_35]] : memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_63]]]
+// CHECK:           return %[[VAL_2]] : tensor<?xf64>
+// CHECK:         }
+func.func @matvec(%A: tensor<?x?xf64, #SortedCOO>,
+                  %x: tensor<?xf64>,
+                  %y_in: tensor<?xf64>) -> tensor<?xf64> {
+  %y_out = linalg.matvec
+    ins(%A, %x: tensor<?x?xf64, #SortedCOO>, tensor<?xf64>)
+    outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
+  return %y_out : tensor<?xf64>
+}
+
+}


        


More information about the Mlir-commits mailing list