[Mlir-commits] [mlir] 19466eb - [mlir][sparse][gpu] a first prototype sparse GPU code generator
Aart Bik
llvmlistbot at llvm.org
Wed Apr 5 11:32:15 PDT 2023
Author: Aart Bik
Date: 2023-04-05T11:32:06-07:00
New Revision: 19466ebc7ff8f51e2ce2c69949823a0c3e2fb660
URL: https://github.com/llvm/llvm-project/commit/19466ebc7ff8f51e2ce2c69949823a0c3e2fb660
DIFF: https://github.com/llvm/llvm-project/commit/19466ebc7ff8f51e2ce2c69949823a0c3e2fb660.diff
LOG: [mlir][sparse][gpu] a first prototype sparse GPU code generator
This adds a proof-of-concept GPU code generator
to the sparse compiler pipeline, currently only capable
of generating CUDA threads for outermost parallel loops.
The objective, obviously, is to grow this concept
into a full-blown GPU code generator, capable of the
right combination of direct code generation and exploiting
idiomatic kernels or vendor-specific libraries (think cuSPARSE).
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D147483
Added:
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 8ef6381489e81..c69dfb77f6cbe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -203,6 +203,12 @@ std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32);
+/// Adds the rewrite patterns that outline admissible parallel loops produced
+/// by sparsification into GPU kernels, each launched with `numThreads`
+/// threads.
+void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
+ unsigned numThreads);
+
+/// Creates the sparse GPU codegen pass with its default options.
+std::unique_ptr<Pass> createSparseGPUCodegenPass();
+/// Creates the sparse GPU codegen pass using `numThreads` GPU threads.
+std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c4cbe0771127b..91126f497b423 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -310,6 +310,26 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
];
}
+// Outlines the admissible parallel loops produced by sparsification into
+// GPU kernels (gpu.func inside a gpu.module) and replaces each loop by a
+// gpu.launch_func on the host.
+def SparseGPUCodegen : Pass<"sparse-gpu-codegen", "ModuleOp"> {
+ let summary = "Generates GPU code during sparsification";
+ let description = [{
+ Enables sparse compiler to use GPU acceleration.
+ }];
+ let constructor = "mlir::createSparseGPUCodegenPass()";
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "bufferization::BufferizationDialect",
+ "gpu::GPUDialect",
+ "linalg::LinalgDialect",
+ "memref::MemRefDialect",
+ "scf::SCFDialect",
+ "sparse_tensor::SparseTensorDialect",
+ ];
+ // numThreads becomes the x-dimension of the block size of every launch
+ // (the grid is a single block). NOTE(review): the option is declared as
+ // int32_t while createSparseGPUCodegenPass(unsigned) and the pattern
+ // populator take an unsigned count; negative values are not rejected.
+ let options = [
+ Option<"numThreads", "num_threads", "int32_t", "1024", "Sets the number of GPU threads">,
+ ];
+}
+
def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
let summary = "Lower sparse storage specifer to llvm structure";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 8107f2472537b..6133c8b5174b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenUtils.cpp
LoopEmitter.cpp
SparseBufferRewriting.cpp
+ SparseGPUCodegen.cpp
SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
new file mode 100644
index 0000000000000..28b5f72c19c65
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -0,0 +1,247 @@
+//===- SparseGPUCodegen.cpp - Generates GPU code (using CUDA) -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a prototype GPU code generator for the sparse compiler.
+// The objective is to eventually use the right combination of
+// direct code generation and library calls into vendor-specific
+// highly optimized sparse libraries (e.g. cuSPARSE for CUDA).
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+#include "LoopEmitter.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+
+#include <string>
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Tags `topModule` with the GPU dialect's container-module unit attribute,
+/// so that it may hold GPU modules and kernel launches.
+static void markAsGPUContainer(ModuleOp topModule) {
+  auto containerAttrName = gpu::GPUDialect::getContainerModuleAttrName();
+  auto unit = UnitAttr::get(topModule->getContext());
+  topModule->setAttr(containerAttrName, unit);
+}
+
+/// Marks `topModule` as a GPU container and creates a new gpu.module called
+/// `name` at the very start of its body. Note that this moves the builder's
+/// insertion point into `topModule`.
+static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule,
+                                     StringRef name) {
+  markAsGPUContainer(topModule);
+  Block &topBody = topModule.getBodyRegion().front();
+  builder.setInsertionPointToStart(&topBody);
+  Location topLoc = topModule->getLoc();
+  return builder.create<gpu::GPUModuleOp>(topLoc, name);
+}
+
+/// Creates a gpu.func named `name` at the start of `gpuModule` and tags it
+/// as a kernel. Its signature takes one input per value in `args` (same
+/// types, same order) and has no results; the body is left for the caller.
+static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
+                                 StringRef name, SmallVectorImpl<Value> &args) {
+  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
+  SmallVector<Type> inputTypes;
+  inputTypes.reserve(args.size());
+  for (Value a : args)
+    inputTypes.push_back(a.getType());
+  MLIRContext *ctx = gpuModule->getContext();
+  FunctionType kernelTp = FunctionType::get(ctx, inputTypes, {});
+  auto kernel =
+      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), name, kernelTp);
+  auto kernelAttrName = gpu::GPUDialect::getKernelFuncAttrName();
+  kernel->setAttr(kernelAttrName, builder.getUnitAttr());
+  return kernel;
+}
+
+/// Emits a gpu.launch_func of `gpuFunc` at the builder's insertion point.
+/// The launch uses a single 1x1x1 grid of one block with `numThreads`
+/// threads along x, passes no dynamic shared memory size, and forwards
+/// `args` as the kernel operands.
+static void genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
+                             SmallVectorImpl<Value> &args,
+                             unsigned numThreads) {
+  Location loc = gpuFunc->getLoc();
+  // Null value: no dynamic shared memory size operand is supplied.
+  Value noDynSharedMem = TypedValue<::mlir::IntegerType>{};
+  Value c1 = constantIndex(builder, loc, 1);
+  Value threadsPerBlock = constantIndex(builder, loc, numThreads);
+  gpu::KernelDim3 gridDims{c1, c1, c1};
+  gpu::KernelDim3 blockDims{threadsPerBlock, c1, c1};
+  builder.create<gpu::LaunchFuncOp>(loc, gpuFunc, gridDims, blockDims,
+                                    noDynSharedMem, args);
+}
+
+/// Maps the provided ranked host buffer into the device address space by
+/// casting it to an unranked memref and emitting gpu.host_register on it.
+/// Writes from the host are guaranteed to be visible to device kernels
+/// that are launched afterwards. Writes from the device are guaranteed
+/// to be visible on the host after synchronizing with the device kernel
+/// completion.
+/// Returns `mem` itself (unchanged) so calls can be chained.
+static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
+                                   Value mem) {
+  auto rankedTp = mem.getType().cast<MemRefType>();
+  auto unrankedTp =
+      UnrankedMemRefType::get(rankedTp.getElementType(), /*memorySpace=*/0);
+  Value unranked = builder.create<memref::CastOp>(loc, unrankedTp, mem);
+  builder.create<gpu::HostRegisterOp>(loc, unranked);
+  return mem;
+}
+
+/// Constructs the body of the GPU kernel `gpuFunc` from the loop body of
+/// `forallOp`.
+///
+/// Constants used by the loop are re-materialized inside the kernel; the
+/// scalars and then the buffers are remapped, in that order, to the kernel's
+/// block arguments. The single parallel loop dimension is distributed
+/// cyclically over all launched threads by an scf.for whose body is a clone
+/// of the parallel loop body.
+static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
+                       scf::ParallelOp forallOp,
+                       SmallVectorImpl<Value> &constants,
+                       SmallVectorImpl<Value> &scalars,
+                       SmallVectorImpl<Value> &buffers) {
+  Location loc = gpuFunc->getLoc();
+  Block &block = gpuFunc.getBody().front();
+  rewriter.setInsertionPointToStart(&block);
+
+  // Re-generate the constants, recapture all arguments.
+  unsigned arg = 0;
+  IRMapping irMap;
+  for (Value c : constants)
+    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
+  for (Value s : scalars)
+    irMap.map(s, block.getArgument(arg++));
+  for (Value b : buffers)
+    irMap.map(b, block.getArgument(arg++));
+
+  // Assume 1-dimensional grid/block configuration (only x dimension),
+  // so that:
+  //   row = blockIdx.x * blockDim.x + threadIdx.x
+  //   inc = blockDim.x * gridDim.x
+  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
+  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
+  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
+  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
+  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
+  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
+
+  // Construct the iteration over the computational space that
+  // accounts for the fact that the total number of threads and
+  // the amount of work to be done usually do not match precisely.
+  //   for (r = row; r < N; r += inc) {
+  //     <loop-body>
+  //   }
+  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
+  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
+  // The scf.for builder already populated the loop region with a body block
+  // (induction variable plus terminator). scf.for only admits a single
+  // block, so erase that block before cloning the parallel loop body in its
+  // place; otherwise the loop would be left holding a second, dead block,
+  // which is invalid IR for any driver that does not simplify regions.
+  rewriter.eraseBlock(forOp.getBody());
+  rewriter.cloneRegionBefore(forallOp.getLoopBody(), forOp.getLoopBody(),
+                             forOp.getLoopBody().begin(), irMap);
+
+  // Done.
+  rewriter.setInsertionPointAfter(forOp);
+  rewriter.create<gpu::ReturnOp>(loc);
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules.
+//===----------------------------------------------------------------------===//
+
+/// Proof-of-concept rewriter. This rule generates a CUDA implementation
+/// for each admissible host-side forall loop generated by the sparse
+/// compiler: the loop body is outlined into a GPU kernel and the loop is
+/// replaced by buffer registrations plus a kernel launch. All kernels are
+/// collected in one gpu.module, each under a unique name.
+//
+// TODO: right now this only works with
+// parallelization-strategy=dense-outer-loop, but it should get its
+// own flags in the future
+//
+struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
+  // Only this constructor is provided: inheriting the base constructors
+  // would allow building the pattern with `numThreads` left uninitialized.
+  ForallRewriter(MLIRContext *context, unsigned nT)
+      : OpRewritePattern(context), numThreads(nT) {}
+
+  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
+                                PatternRewriter &rewriter) const override {
+    // Reject inadmissible loop form.
+    // Essentially only accept a loop, generated by the sparse compiler,
+    // of the form
+    //   forall (i = 0; i < N; i++)
+    // so that cyclic scheduling over the threads is easy.
+    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
+        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
+        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
+        !matchPattern(forallOp.getStep()[0], m_One()))
+      return failure();
+    // Skip loops that already live inside a GPU module (e.g. an inner
+    // parallel loop cloned into a previously generated kernel): outlining
+    // them would place gpu.host_register and gpu.launch_func ops inside
+    // device code.
+    if (forallOp->getParentOfType<gpu::GPUModuleOp>())
+      return failure();
+    // Collect every value that is computed outside the parallel loop.
+    SetVector<Value> invariants; // stable iteration!
+    forallOp->walk([&](Operation *op) {
+      // Collect all values of admissible ops.
+      for (OpOperand &o : op->getOpOperands()) {
+        Value val = o.get();
+        if (!isNestedIn(val.getParentBlock(), forallOp))
+          invariants.insert(val);
+      }
+    });
+    // Outline the outside values as proper parameters. Fail when sharing
+    // value between host and device is not straightforward.
+    SmallVector<Value> constants;
+    SmallVector<Value> scalars;
+    SmallVector<Value> buffers;
+    for (Value val : invariants) {
+      Type tp = val.getType();
+      if (val.getDefiningOp<arith::ConstantOp>())
+        constants.push_back(val);
+      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
+        scalars.push_back(val);
+      else if (isa<MemRefType>(tp))
+        buffers.push_back(val);
+      else
+        return failure(); // don't know how to share
+    }
+    // Prepare the outlined arguments, register buffers.
+    Location loc = forallOp->getLoc();
+    SmallVector<Value> args;
+    for (Value s : scalars)
+      args.push_back(s);
+    for (Value b : buffers)
+      args.push_back(genHostRegisterMemref(rewriter, loc, b));
+    auto saveIp = rewriter.saveInsertionPoint();
+    // Set up the GPU module, reusing one created by an earlier rewrite so
+    // that all kernels share a single module.
+    StringRef gpuModuleName = "sparsekernels";
+    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
+    auto gpuModule = topModule.lookupSymbol<gpu::GPUModuleOp>(gpuModuleName);
+    if (!gpuModule)
+      gpuModule = genGPUModule(rewriter, topModule, gpuModuleName);
+    // Pick a kernel name that is unique within that module: the first
+    // kernel is "kernel", later ones are "kernel1", "kernel2", ...
+    std::string kernelName = "kernel";
+    for (unsigned i = 1; gpuModule.lookupSymbol(kernelName); i++)
+      kernelName = "kernel" + std::to_string(i);
+    auto gpuFunc = genGPUFunc(rewriter, gpuModule, kernelName, args);
+    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
+    // Generate code that launches the kernel.
+    rewriter.restoreInsertionPoint(saveIp);
+    genLaunchGPUFunc(rewriter, gpuFunc, args, numThreads);
+    rewriter.eraseOp(forallOp);
+    return success();
+  }
+
+private:
+  // Helper method to see if block appears in given loop.
+  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
+    for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
+      if (o == forallOp)
+        return true;
+    }
+    return false;
+  }
+
+  // Number of threads (x-dimension of the single block) per kernel launch.
+  unsigned numThreads;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public method for populating GPU rewriting rules.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
+                                            unsigned numThreads) {
+  // A single rule for now: outline admissible forall loops into kernels
+  // launched with `numThreads` threads.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ForallRewriter>(ctx, numThreads);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f39ead190c54a..cd56fbd5099dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -28,6 +29,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
+#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -281,6 +283,21 @@ struct SparseVectorizationPass
}
};
+/// Module pass that greedily applies the sparse GPU codegen rewrite
+/// patterns, forwarding the `numThreads` option to the pattern populator.
+struct SparseGPUCodegenPass
+    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
+
+  SparseGPUCodegenPass() = default;
+  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
+  SparseGPUCodegenPass(unsigned nT) { this->numThreads = nT; }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSparseGPUCodegenPatterns(patterns, numThreads);
+    // The greedy driver's convergence result is discarded.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
struct StorageSpecifierToLLVMPass
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
@@ -406,6 +423,14 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}
+/// Creates the sparse GPU codegen pass with its default options.
+std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
+  auto pass = std::make_unique<SparseGPUCodegenPass>();
+  return pass;
+}
+
+/// Creates the sparse GPU codegen pass, overriding the number of threads.
+std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads) {
+  auto pass = std::make_unique<SparseGPUCodegenPass>(numThreads);
+  return pass;
+}
+
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
return std::make_unique<StorageSpecifierToLLVMPass>();
}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
new file mode 100644
index 0000000000000..e42bbb0924ac2
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN: --pre-sparsification-rewrite \
+// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN: --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+//
+// Compute matrix matrix C = AB
+//
+// CHECK-LABEL: gpu.func @kernel(
+// CHECK-SAME: %[[VAL_0:.*0]]: index,
+// CHECK-SAME: %[[VAL_1:.*1]]: index,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref<?x?xf64>,
+// CHECK-SAME: %[[VAL_6:.*6]]: memref<?x?xf64>) kernel {
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = gpu.block_id x
+// CHECK: %[[VAL_10:.*]] = gpu.block_dim x
+// CHECK: %[[VAL_11:.*]] = gpu.thread_id x
+// CHECK: %[[VAL_12:.*]] = gpu.grid_dim x
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_10]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_11]] : index
+// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_10]], %[[VAL_12]] : index
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_1]] step %[[VAL_15]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_7]] : index
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_7]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_0]] step %[[VAL_7]] {
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_21]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK: %[[VAL_26:.*]] = arith.mulf %[[VAL_22]], %[[VAL_25]] : f64
+// CHECK: %[[VAL_27:.*]] = arith.addf %[[VAL_24]], %[[VAL_26]] : f64
+// CHECK: memref.store %[[VAL_27]], %[[VAL_5]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: }
+// CHECK: gpu.return
+// CHECK: }
+//
+//
+// CHECK-LABEL: func.func @matmul
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.launch_func @sparsekernels::@kernel blocks
+//
+// Input: CSR sparse matrix A times dense matrix B, accumulated into C_in.
+// After sparsification with dense-outer-loop parallelization, the outer
+// row loop is the one expected to be outlined into the GPU kernel.
+func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
+ %C_out = linalg.matmul
+ ins(%A, %B: tensor<?x?xf64, #CSR>, tensor<?x?xf64>)
+ outs(%C_in: tensor<?x?xf64>) -> tensor<?x?xf64>
+ return %C_out : tensor<?x?xf64>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
new file mode 100644
index 0000000000000..96b7f9dd31299
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN: --pre-sparsification-rewrite \
+// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN: --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+//
+// Compute matrix vector y = Ax
+//
+//
+// CHECK: gpu.func @kernel(
+// CHECK-SAME: %[[VAL_0:.*0]]: index,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref<?xf64>) kernel {
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = gpu.block_id x
+// CHECK: %[[VAL_8:.*]] = gpu.block_dim x
+// CHECK: %[[VAL_9:.*]] = gpu.thread_id x
+// CHECK: %[[VAL_10:.*]] = gpu.grid_dim x
+// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_7]], %[[VAL_8]] : index
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_9]] : index
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_0]] step %[[VAL_13]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : index
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_6]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f64) {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_22]]] : memref<?xf64>
+// CHECK: %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f64
+// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_21]], %[[VAL_25]] : f64
+// CHECK: scf.yield %[[VAL_26]] : f64
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.store %[[VAL_27:.*]], %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xf64>
+// CHECK: }
+// CHECK: gpu.return
+// CHECK: }
+//
+// CHECK-LABEL: func.func @matvec
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.host_register
+// CHECK: gpu.launch_func @sparsekernels::@kernel blocks
+//
+// Input: CSR sparse matrix A times dense vector x, accumulated into y_in.
+// The inner reduction stays a sequential scf.for inside the kernel, while
+// the outer row loop is distributed over the GPU threads.
+func.func @matvec(%A: tensor<?x?xf64, #CSR>, %x: tensor<?xf64>, %y_in: tensor<?xf64>) -> tensor<?xf64> {
+ %y_out = linalg.matvec
+ ins(%A, %x: tensor<?x?xf64, #CSR>, tensor<?xf64>)
+ outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
+ return %y_out : tensor<?xf64>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 53dc23bea3ea4..7cc26bf5f5fa5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2343,6 +2343,7 @@ cc_library(
":DialectUtils",
":FuncDialect",
":FuncTransforms",
+ ":GPUDialect",
":IR",
":LLVMCommonConversion",
":LLVMDialect",
More information about the Mlir-commits
mailing list