[Mlir-commits] [mlir] [mlir][gpu] Productize `test-lower-to-nvvm` as `gpu-lower-to-nvvm` (PR #75775)

Guray Ozen llvmlistbot at llvm.org
Mon Dec 18 02:34:37 PST 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/75775

>From 2aa9027012c56e6d4e37c2aa55ece686aa36dcc9 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 8 Dec 2023 20:17:08 +0100
Subject: [PATCH 1/2] [mlir][gpu] Move test-lower-to-nvvm to gpu-lower-to-nvvm
 Pipeline

MLIR has a `test-lower-to-nvvm` pipeline that performs the full compilation down to NVVM together with the host-side compilation. There is a need to use it from Python as well, so this PR moves the `test-lower-to-nvvm` test pipeline to a productized `gpu-lower-to-nvvm` pipeline.
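
A minimal sketch of the intended Python usage (hedged: it assumes an MLIR build with CUDA conversions enabled, Python bindings whose site initialization registers the upstream passes and pipelines, and the `PassManager.run(operation)` form, which has changed across MLIR versions):

  from mlir.ir import Context, Module
  from mlir.passmanager import PassManager

  with Context():
      # Placeholder input; a real module would contain gpu.module/gpu.launch ops.
      module = Module.parse("module {}")
      # The productized pipeline is referenced by its registered name.
      pm = PassManager.parse("builtin.module(gpu-lower-to-nvvm{cubin-chip=sm_80})")
      pm.run(module.operation)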
---
 .../mlir/Dialect/GPU/Pipelines/Passes.h       | 22 +++++++++++++++
 mlir/include/mlir/InitAllPasses.h             |  4 +++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |  1 +
 mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 12 +++++++++
 .../GPU/Pipelines/GPUToNVVMPipeline.cpp}      | 27 +++++++++----------
 .../SparseTensor/GPU/CUDA/dump-ptx.mlir       |  2 +-
 .../GPU/CUDA/sparse-mma-2-4-f16.mlir          |  2 +-
 .../GPU/CUDA/test-reduction-distribute.mlir   |  2 +-
 .../Vector/GPU/CUDA/test-warp-distribute.mlir |  6 ++---
 ...ansform-mma-sync-matmul-f16-f16-accum.mlir |  2 +-
 .../sm80/transform-mma-sync-matmul-f32.mlir   |  2 +-
 .../GPU/CUDA/TensorCore/wmma-matmul-f16.mlir  |  2 +-
 .../TensorCore/wmma-matmul-f32-bare-ptr.mlir  |  2 +-
 .../GPU/CUDA/TensorCore/wmma-matmul-f32.mlir  |  2 +-
 .../Integration/GPU/CUDA/all-reduce-and.mlir  |  4 +--
 .../GPU/CUDA/all-reduce-maxsi.mlir            |  2 +-
 .../GPU/CUDA/all-reduce-minsi.mlir            |  2 +-
 .../Integration/GPU/CUDA/all-reduce-op.mlir   |  2 +-
 .../Integration/GPU/CUDA/all-reduce-or.mlir   |  2 +-
 .../GPU/CUDA/all-reduce-region.mlir           |  2 +-
 .../Integration/GPU/CUDA/all-reduce-xor.mlir  |  2 +-
 .../Integration/GPU/CUDA/gpu-to-cubin.mlir    |  2 +-
 .../GPU/CUDA/multiple-all-reduce.mlir         |  2 +-
 mlir/test/Integration/GPU/CUDA/printf.mlir    |  2 +-
 mlir/test/Integration/GPU/CUDA/shuffle.mlir   |  2 +-
 .../GPU/CUDA/sm90/cga_cluster.mlir            |  2 +-
 .../sm90/gemm_f32_f16_f16_128x128x128.mlir    |  2 +-
 .../gemm_pred_f32_f16_f16_128x128x128.mlir    |  2 +-
 .../sm90/tma_load_128x64_swizzle128b.mlir     |  2 +-
 .../CUDA/sm90/tma_load_64x64_swizzle128b.mlir |  2 +-
 .../sm90/tma_load_64x8_8x128_noswizzle.mlir   |  2 +-
 .../Integration/GPU/CUDA/two-modules.mlir     |  2 +-
 mlir/test/lib/Dialect/GPU/CMakeLists.txt      |  1 -
 mlir/tools/mlir-opt/mlir-opt.cpp              |  3 ---
 34 files changed, 82 insertions(+), 48 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
 create mode 100644 mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
 rename mlir/{test/lib/Dialect/GPU/TestLowerToNVVM.cpp => lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp} (90%)

diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
new file mode 100644
index 00000000000000..c2ab4fb54536a4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -0,0 +1,22 @@
+//===- Passes.h - GPU NVVM pipeline entry points -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes of the GPU NVVM pipelines.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
+#define MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
+
+namespace mlir {
+namespace gpu {
+void registerGPUToNVVMPipeline();
+} // namespace gpu
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index f22980036ffcfa..311d93477d037e 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -91,6 +92,9 @@ inline void registerAllPasses() {
   bufferization::registerBufferizationPipelines();
   sparse_tensor::registerSparseTensorPipelines();
   tosa::registerTosaToLinalgPipelines();
+#if MLIR_CUDA_CONVERSIONS_ENABLED
+  gpu::registerGPUToNVVMPipeline();
+#endif
 }
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index e8b69879ad6a7e..ab6834cb262fb5 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -105,6 +105,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   )
 
 add_subdirectory(TransformOps)
+add_subdirectory(Pipelines)
 
 if(MLIR_ENABLE_CUDA_RUNNER)
   if(NOT MLIR_ENABLE_CUDA_CONVERSIONS)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
new file mode 100644
index 00000000000000..095f8fd5205172
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRGPUPipelines
+  GPUToNVVMPipeline.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
+
+  LINK_LIBS PUBLIC
+  MLIRMemRefTransforms
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/test/lib/Dialect/GPU/TestLowerToNVVM.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
similarity index 90%
rename from mlir/test/lib/Dialect/GPU/TestLowerToNVVM.cpp
rename to mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index 28f76bde0820a6..4cb8b1dc7bc2d1 100644
--- a/mlir/test/lib/Dialect/GPU/TestLowerToNVVM.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -1,4 +1,4 @@
-//===- TestLowerToNVVM.cpp - Test lowering to NVVM as a sink pass ---------===//
+//===- GPUToNVVMPipeline.cpp - Lowering to NVVM as a sink pipeline --------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -27,6 +27,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -39,8 +40,8 @@ using namespace mlir;
 
 #if MLIR_CUDA_CONVERSIONS_ENABLED
 namespace {
-struct TestLowerToNVVMOptions
-    : public PassPipelineOptions<TestLowerToNVVMOptions> {
+struct GPUToNVVMPipelineOptions
+    : public PassPipelineOptions<GPUToNVVMPipelineOptions> {
   PassOptions::Option<int64_t> indexBitWidth{
       *this, "index-bitwidth",
       llvm::cl::desc("Bitwidth of the index type for the host (warning this "
@@ -83,16 +84,14 @@ struct TestLowerToNVVMOptions
 // Common pipeline
 //===----------------------------------------------------------------------===//
 void buildCommonPassPipeline(OpPassManager &pm,
-                             const TestLowerToNVVMOptions &options) {
+                             const GPUToNVVMPipelineOptions &options) {
   pm.addPass(createConvertNVGPUToNVVMPass());
   pm.addPass(createGpuKernelOutliningPass());
   pm.addPass(createConvertLinalgToLoopsPass());
   pm.addPass(createConvertVectorToSCFPass());
   pm.addPass(createConvertSCFToCFPass());
   pm.addPass(createConvertNVVMToLLVMPass());
-  pm.addPass(createConvertVectorToLLVMPass());
   pm.addPass(createConvertMathToLLVMPass());
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
   pm.addPass(createConvertFuncToLLVMPass());
   pm.addPass(memref::createExpandStridedMetadataPass());
 
@@ -115,7 +114,7 @@ void buildCommonPassPipeline(OpPassManager &pm,
 // GPUModule-specific stuff.
 //===----------------------------------------------------------------------===//
 void buildGpuPassPipeline(OpPassManager &pm,
-                          const TestLowerToNVVMOptions &options) {
+                          const GPUToNVVMPipelineOptions &options) {
   pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
   ConvertGpuOpsToNVVMOpsOptions opt;
   opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
@@ -130,7 +129,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
 // Host Post-GPU pipeline
 //===----------------------------------------------------------------------===//
 void buildHostPostPipeline(OpPassManager &pm,
-                           const TestLowerToNVVMOptions &options) {
+                           const GPUToNVVMPipelineOptions &options) {
   GpuToLLVMConversionPassOptions opt;
   opt.hostBarePtrCallConv = options.hostUseBarePtrCallConv;
   opt.kernelBarePtrCallConv = options.kernelUseBarePtrCallConv;
@@ -145,7 +144,7 @@ void buildHostPostPipeline(OpPassManager &pm,
 }
 
 void buildLowerToNVVMPassPipeline(OpPassManager &pm,
-                                  const TestLowerToNVVMOptions &options) {
+                                  const GPUToNVVMPipelineOptions &options) {
   //===----------------------------------------------------------------------===//
   // Common pipeline
   //===----------------------------------------------------------------------===//
@@ -164,14 +163,14 @@ void buildLowerToNVVMPassPipeline(OpPassManager &pm,
 } // namespace
 
 namespace mlir {
-namespace test {
-void registerTestLowerToNVVM() {
-  PassPipelineRegistration<TestLowerToNVVMOptions>(
-      "test-lower-to-nvvm",
+namespace gpu {
+void registerGPUToNVVMPipeline() {
+  PassPipelineRegistration<GPUToNVVMPipelineOptions>(
+      "gpu-lower-to-nvvm",
       "An example of pipeline to lower the main dialects (arith, linalg, "
       "memref, scf, vector) down to NVVM.",
       buildLowerToNVVMPassPipeline);
 }
-} // namespace test
+} // namespace gpu
 } // namespace mlir
 #endif // MLIR_CUDA_CONVERSIONS_ENABLED
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
index 4483d18231e80e..42348e39832ade 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  | mlir-opt -test-lower-to-nvvm -debug-only=serialize-to-isa \
+// RUN:  | mlir-opt -gpu-lower-to-nvvm -debug-only=serialize-to-isa \
 // RUN:  2>&1 | FileCheck %s
 
 // CHECK: Generated by LLVM NVPTX Back-End
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
index e36b83e931933a..62d0d9e1cac984 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -4,7 +4,7 @@
 // RUN: mlir-opt \
 // RUN: --pass-pipeline="builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm,affine-expand-index-ops,lower-affine,convert-arith-to-llvm),convert-vector-to-llvm,canonicalize,cse)" \
 // RUN: %s \
-// RUN: | mlir-opt --test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt --gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_c_runner_utils \
diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
index 8c991493a2b017..94a57d7c266819 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize |\
 // RUN: mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if |\
 // RUN: mlir-opt -lower-affine -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm \
-// RUN:  -convert-arith-to-llvm -test-lower-to-nvvm | \
+// RUN:  -convert-arith-to-llvm -gpu-lower-to-nvvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
index f26c18c4ae3dd2..896051ab5dd7eb 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
@@ -2,7 +2,7 @@
 // everything on the same thread.
 // RUN: mlir-opt %s -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -test-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
@@ -13,7 +13,7 @@
 // RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \
 // RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -test-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
@@ -23,7 +23,7 @@
 // RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
 // RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -test-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
index c9f45ddad6ffcf..d4bd51aab03535 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s \
 // RUN:  -transform-interpreter \
 // RUN:  -test-transform-dialect-erase-schedule \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
index 367b4f32ede386..3e5f291db8e744 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -11,7 +11,7 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter \
 // RUN:   -test-transform-dialect-erase-schedule \
-// RUN:   -test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
+// RUN:   -gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index 95068974a1a07b..bbeddd5bb2285f 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
index 9ab0e59a291e07..d5950eae2543a6 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
@@ -3,7 +3,7 @@
 // Similar to the wmma-matmul-f32 but with the memref bare pointer lowering convention.
 // This test also uses gpu.memcpy operations (instead of gpu.host_register).
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="host-bare-ptr-calling-convention=1 kernel-bare-ptr-calling-convention=1 cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="host-bare-ptr-calling-convention=1 kernel-bare-ptr-calling-convention=1 cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --entry-point-result=void \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
index 41f4c1d35454d6..c75f9c1b5649b1 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
index 13a05a2766e5df..fe999e0aa575b1 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm \
+// RUN: | mlir-opt -gpu-lower-to-nvvm \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
@@ -8,7 +8,7 @@
 
 // Same as above but with the memref bare pointer lowering convention.
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="kernel-bare-ptr-calling-convention=1 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="kernel-bare-ptr-calling-convention=1 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
index d858358a2892c6..dcd503c7bd806c 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
index 1ec926d9cacb01..8236550feb1113 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
index 070679689240c1..6f965c225e2d89 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
index 107e8a407d00cf..340db39f5d28f8 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
index 4aa44b9ce5e967..b4fc32ff9b838a 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
index 717dc542cc594b..f43a095584d69c 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
index 605a717b83f3f3..7f5b38b34c8995 100644
--- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
+++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
index 3635caac43555a..a894030d430807 100644
--- a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
+++ b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir
index 01c5939b251649..9555a77f45f11f 100644
--- a/mlir/test/Integration/GPU/CUDA/printf.mlir
+++ b/mlir/test/Integration/GPU/CUDA/printf.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
index 2a7482f9cece15..4e5bb3e8f5ca64 100644
--- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
index 5beba48813480f..bca3cb1f9a1e07 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index 327607f3796e7c..c8dc45ab861d16 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index 9185bc8fefcb92..bc3437b6545d71 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index 19f88306050afb..65f301968669aa 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index a078cf3a205468..fdbb188c28a9c7 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index 081a60dded788a..ed58504cfdb106 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/two-modules.mlir b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
index f16dcd9a72272e..f68359d78c0475 100644
--- a/mlir/test/Integration/GPU/CUDA/two-modules.mlir
+++ b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 3f20e5a6ecfc4b..aa94bce275eafb 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -33,7 +33,6 @@ set(LIBS
 add_mlir_library(MLIRGPUTestPasses
   TestGpuMemoryPromotion.cpp
   TestGpuRewrite.cpp
-  TestLowerToNVVM.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b4850560..7364838a10ae78 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -201,9 +201,6 @@ void registerTestPasses() {
   mlir::test::registerTestControlFlowSink();
   mlir::test::registerTestDiagnosticsPass();
   mlir::test::registerTestDialectConversionPasses();
-#if MLIR_CUDA_CONVERSIONS_ENABLED
-  mlir::test::registerTestLowerToNVVM();
-#endif
   mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataLayoutPropagation();
   mlir::test::registerTestDataLayoutQuery();

>From 466a4ff416cfa1312e415c0850a0742653227e23 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 18 Dec 2023 11:34:22 +0100
Subject: [PATCH 2/2] Address @ftynse's review comments

---
 mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h     | 8 ++------
 mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp | 6 ++++--
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index c2ab4fb54536a4..7128ffff2b748d 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,14 +1,10 @@
-//===- Passes.h - GPU NVVM pipeline entry points -----------*- C++ -*-===//
+//===- Passes.h - GPU NVVM pipeline entry points ---------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This header file defines prototypes of the GPU NVVM pipelines.
-//
-//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
 #define MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
@@ -19,4 +15,4 @@ void registerGPUToNVVMPipeline();
 } // namespace gpu
 } // namespace mlir
 
-#endif
\ No newline at end of file
+#endif
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index 4cb8b1dc7bc2d1..676dd14199c220 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -167,8 +167,10 @@ namespace gpu {
 void registerGPUToNVVMPipeline() {
   PassPipelineRegistration<GPUToNVVMPipelineOptions>(
       "gpu-lower-to-nvvm",
-      "An example of pipeline to lower the main dialects (arith, linalg, "
-      "memref, scf, vector) down to NVVM.",
+      "The default pipeline lowers main dialects (arith, linalg, memref, scf, "
+      "vector, gpu, and nvgpu) to NVVM. It starts by lowering GPU code to the "
+      "specified compilation target (default is fatbin) then lowers the host "
+      "code.",
       buildLowerToNVVMPassPipeline);
 }
 } // namespace gpu
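
To connect the reworded description back to the Python motivation, here is a hedged end-to-end sketch of using the renamed pipeline; the option values, the empty placeholder module, and the runtime library name are illustrative assumptions, not part of this patch:

  from mlir.ir import Context, Module
  from mlir.passmanager import PassManager
  from mlir.execution_engine import ExecutionEngine

  with Context():
      module = Module.parse("module {}")  # stand-in for a real GPU program
      # Device side first: lower GPU code to the chosen target, then the host code.
      PassManager.parse(
          "builtin.module(gpu-lower-to-nvvm{cubin-chip=sm_90a cubin-features=+ptx80})"
      ).run(module.operation)
      # Host side: JIT the fully lowered module against the CUDA runtime wrappers.
      engine = ExecutionEngine(module, opt_level=3,
                               shared_libs=["libmlir_cuda_runtime.so"])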


