[Mlir-commits] [mlir] ace69e6 - [mlir][gpu] Improve `gpu-lower-to-nvvm-pipeline` Documentation (#77062)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 5 03:51:29 PST 2024


Author: Guray Ozen
Date: 2024-01-05T12:51:25+01:00
New Revision: ace69e6b942b8fa7e610d70be2a92e801ceea481

URL: https://github.com/llvm/llvm-project/commit/ace69e6b942b8fa7e610d70be2a92e801ceea481
DIFF: https://github.com/llvm/llvm-project/commit/ace69e6b942b8fa7e610d70be2a92e801ceea481.diff

LOG: [mlir][gpu] Improve `gpu-lower-to-nvvm-pipeline` Documentation (#77062)

This PR improves the documentation for the `gpu-lower-to-nvvm-pipeline`
(as it was a remaining item for #75775)

- Renames the pipeline `gpu-lower-to-nvvm` -> `gpu-lower-to-nvvm-pipeline`
- Adds a section to the GPU dialect documentation on the website. It clarifies
the pipeline's role in lowering the primary dialects to NVVM targets.

Added: 
    mlir/test/Integration/GPU/CUDA/sm90/asd

Modified: 
    mlir/docs/Dialects/GPU.md
    mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
    mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
    mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
    mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
    mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
    mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
    mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
    mlir/test/Integration/GPU/CUDA/printf.mlir
    mlir/test/Integration/GPU/CUDA/shuffle.mlir
    mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
    mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
    mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
    mlir/test/Integration/GPU/CUDA/two-modules.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md
index 8558667ea51ab5..85255fdc5e6439 100644
--- a/mlir/docs/Dialects/GPU.md
+++ b/mlir/docs/Dialects/GPU.md
@@ -60,6 +60,50 @@ mlir-translate example-nvvm.mlir        \
   -o example.ll
 ```
 
+### Default NVVM Compilation Pipeline: gpu-lower-to-nvvm-pipeline
+
+The `gpu-lower-to-nvvm-pipeline` compilation pipeline serves as the default way
+for NVVM target compilation within MLIR. This pipeline operates by lowering
+primary dialects (arith, memref, scf, vector, gpu, and nvgpu) to the NVVM target.
+It begins by lowering GPU code region(s) to the specified NVVM compilation
+target and subsequently handles the host code.
+
+This pipeline requires explicitly parallel IR and does not perform GPU
+parallelization itself. Any transformations needed to expose parallelism must
+be applied before running this pipeline.
+
+It is designed to provide a generic solution for NVVM targets, generating NVVM
+and LLVM dialect code compatible with `mlir-cpu-runner` or the execution engine.
+
+#### Example:
+
+Here's a snippet illustrating the use of primary dialects, such as `arith` and
+`gpu`, within GPU code execution:
+
+```
+func.func @main() {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    gpu.launch 
+        blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1) 
+        threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) { 
+        gpu.printf "Hello from %d\n" %6 : index
+        gpu.terminator
+    }
+    return
+}
+```
+
+The `gpu-lower-to-nvvm-pipeline` pipeline compiles this input code to NVVM
+format, as shown by the command below. It provides customization options like
+specifying SM capability, PTX version, and optimization level. Once compiled,
+the resulting IR is ready for execution using `mlir-cpu-runner`. Alternatively,
+it can be translated into LLVM IR (e.g. with `mlir-translate`).
+
+```
+mlir-opt example.mlir -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+```
+
 ### Module serialization
 Attributes implementing the GPU Target Attribute Interface handle the
 serialization process and are called Target attributes. These attributes can be

diff  --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 7128ffff2b748d..caa0901bb49434 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -9,9 +9,65 @@
 #ifndef MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
 #define MLIR_DIALECT_GPU_PIPELINES_PASSES_H_
 
+#include "mlir/Pass/PassOptions.h"
+
 namespace mlir {
 namespace gpu {
+
+/// Options for the gpu to nvvm pipeline.
+struct GPUToNVVMPipelineOptions
+    : public PassPipelineOptions<GPUToNVVMPipelineOptions> {
+  PassOptions::Option<int64_t> indexBitWidth{
+      *this, "index-bitwidth",
+      llvm::cl::desc("Bitwidth of the index type for the host (warning this "
+                     "should be 64 until the GPU layering is fixed)"),
+      llvm::cl::init(64)};
+  PassOptions::Option<std::string> cubinTriple{
+      *this, "cubin-triple",
+      llvm::cl::desc("Triple to use to serialize to cubin."),
+      llvm::cl::init("nvptx64-nvidia-cuda")};
+  PassOptions::Option<std::string> cubinChip{
+      *this, "cubin-chip", llvm::cl::desc("Chip to use to serialize to cubin."),
+      llvm::cl::init("sm_50")};
+  PassOptions::Option<std::string> cubinFeatures{
+      *this, "cubin-features",
+      llvm::cl::desc("Features to use to serialize to cubin."),
+      llvm::cl::init("+ptx60")};
+  PassOptions::Option<std::string> cubinFormat{
+      *this, "cubin-format",
+      llvm::cl::desc("Compilation format to use to serialize to cubin."),
+      llvm::cl::init("fatbin")};
+  PassOptions::Option<int> optLevel{
+      *this, "opt-level",
+      llvm::cl::desc("Optimization level for NVVM compilation"),
+      llvm::cl::init(2)};
+  PassOptions::Option<bool> kernelUseBarePtrCallConv{
+      *this, "kernel-bare-ptr-calling-convention",
+      llvm::cl::desc(
+          "Whether to use the bareptr calling convention on the kernel "
+          "(warning this should be false until the GPU layering is fixed)"),
+      llvm::cl::init(false)};
+  PassOptions::Option<bool> hostUseBarePtrCallConv{
+      *this, "host-bare-ptr-calling-convention",
+      llvm::cl::desc(
+          "Whether to use the bareptr calling convention on the host (warning "
+          "this should be false until the GPU layering is fixed)"),
+      llvm::cl::init(false)};
+};
+
+//===----------------------------------------------------------------------===//
+// Building and Registering.
+//===----------------------------------------------------------------------===//
+
+/// Adds the GPU to NVVM pipeline to the given pass manager. Transforms main
+/// dialects into NVVM targets. Begins with GPU code regions, then handles host
+/// code.
+void buildLowerToNVVMPassPipeline(OpPassManager &pm,
+                                  const GPUToNVVMPipelineOptions &options);
+
+/// Registers the `gpu-lower-to-nvvm-pipeline` pipeline for the `gpu` dialect.
 void registerGPUToNVVMPipeline();
+
 } // namespace gpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index 5bee234e932a69..0b4739214bf2f1 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -40,54 +40,14 @@ using namespace mlir;
 
 #if MLIR_CUDA_CONVERSIONS_ENABLED
 namespace {
-struct GPUToNVVMPipelineOptions
-    : public PassPipelineOptions<GPUToNVVMPipelineOptions> {
-  PassOptions::Option<int64_t> indexBitWidth{
-      *this, "index-bitwidth",
-      llvm::cl::desc("Bitwidth of the index type for the host (warning this "
-                     "should be 64 until the GPU layering is fixed)"),
-      llvm::cl::init(64)};
-  PassOptions::Option<std::string> cubinTriple{
-      *this, "cubin-triple",
-      llvm::cl::desc("Triple to use to serialize to cubin."),
-      llvm::cl::init("nvptx64-nvidia-cuda")};
-  PassOptions::Option<std::string> cubinChip{
-      *this, "cubin-chip", llvm::cl::desc("Chip to use to serialize to cubin."),
-      llvm::cl::init("sm_50")};
-  PassOptions::Option<std::string> cubinFeatures{
-      *this, "cubin-features",
-      llvm::cl::desc("Features to use to serialize to cubin."),
-      llvm::cl::init("+ptx60")};
-  PassOptions::Option<std::string> cubinFormat{
-      *this, "cubin-format",
-      llvm::cl::desc("Compilation format to use to serialize to cubin."),
-      llvm::cl::init("fatbin")};
-  PassOptions::Option<int> optLevel{
-      *this, "opt-level",
-      llvm::cl::desc("Optimization level for NVVM compilation"),
-      llvm::cl::init(2)};
-  PassOptions::Option<bool> kernelUseBarePtrCallConv{
-      *this, "kernel-bare-ptr-calling-convention",
-      llvm::cl::desc(
-          "Whether to use the bareptr calling convention on the kernel "
-          "(warning this should be false until the GPU layering is fixed)"),
-      llvm::cl::init(false)};
-  PassOptions::Option<bool> hostUseBarePtrCallConv{
-      *this, "host-bare-ptr-calling-convention",
-      llvm::cl::desc(
-          "Whether to use the bareptr calling convention on the host (warning "
-          "this should be false until the GPU layering is fixed)"),
-      llvm::cl::init(false)};
-};
 
 //===----------------------------------------------------------------------===//
 // Common pipeline
 //===----------------------------------------------------------------------===//
-void buildCommonPassPipeline(OpPassManager &pm,
-                             const GPUToNVVMPipelineOptions &options) {
+void buildCommonPassPipeline(
+    OpPassManager &pm, const mlir::gpu::GPUToNVVMPipelineOptions &options) {
   pm.addPass(createConvertNVGPUToNVVMPass());
   pm.addPass(createGpuKernelOutliningPass());
-  pm.addPass(createConvertLinalgToLoopsPass());
   pm.addPass(createConvertVectorToSCFPass());
   pm.addPass(createConvertSCFToCFPass());
   pm.addPass(createConvertNVVMToLLVMPass());
@@ -114,7 +74,7 @@ void buildCommonPassPipeline(OpPassManager &pm,
 // GPUModule-specific stuff.
 //===----------------------------------------------------------------------===//
 void buildGpuPassPipeline(OpPassManager &pm,
-                          const GPUToNVVMPipelineOptions &options) {
+                          const mlir::gpu::GPUToNVVMPipelineOptions &options) {
   pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
   ConvertGpuOpsToNVVMOpsOptions opt;
   opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
@@ -129,7 +89,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
 // Host Post-GPU pipeline
 //===----------------------------------------------------------------------===//
 void buildHostPostPipeline(OpPassManager &pm,
-                           const GPUToNVVMPipelineOptions &options) {
+                           const mlir::gpu::GPUToNVVMPipelineOptions &options) {
   GpuToLLVMConversionPassOptions opt;
   opt.hostBarePtrCallConv = options.hostUseBarePtrCallConv;
   opt.kernelBarePtrCallConv = options.kernelUseBarePtrCallConv;
@@ -143,36 +103,28 @@ void buildHostPostPipeline(OpPassManager &pm,
   pm.addPass(createReconcileUnrealizedCastsPass());
 }
 
-void buildLowerToNVVMPassPipeline(OpPassManager &pm,
-                                  const GPUToNVVMPipelineOptions &options) {
-  //===----------------------------------------------------------------------===//
-  // Common pipeline
-  //===----------------------------------------------------------------------===//
+} // namespace
+
+void mlir::gpu::buildLowerToNVVMPassPipeline(
+    OpPassManager &pm, const GPUToNVVMPipelineOptions &options) {
+  // Common pipelines
   buildCommonPassPipeline(pm, options);
 
-  //===----------------------------------------------------------------------===//
-  // GPUModule-specific stuff.
-  //===----------------------------------------------------------------------===//
+  // GPUModule-specific stuff
   buildGpuPassPipeline(pm, options);
 
-  //===----------------------------------------------------------------------===//
-  // Host post-GPUModule-specific stuff.
-  //===----------------------------------------------------------------------===//
+  // Host post-GPUModule-specific stuff
   buildHostPostPipeline(pm, options);
 }
-} // namespace
 
-namespace mlir {
-namespace gpu {
-void registerGPUToNVVMPipeline() {
+void mlir::gpu::registerGPUToNVVMPipeline() {
   PassPipelineRegistration<GPUToNVVMPipelineOptions>(
-      "gpu-lower-to-nvvm",
-      "The default pipeline lowers main dialects (arith, linalg, memref, scf, "
+      "gpu-lower-to-nvvm-pipeline",
+      "The default pipeline lowers main dialects (arith, memref, scf, "
       "vector, gpu, and nvgpu) to NVVM. It starts by lowering GPU code to the "
       "specified compilation target (default is fatbin) then lowers the host "
       "code.",
       buildLowerToNVVMPassPipeline);
 }
-} // namespace gpu
-} // namespace mlir
+
 #endif // MLIR_CUDA_CONVERSIONS_ENABLED

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
index 42348e39832ade..0cc5d8645bb364 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  | mlir-opt -gpu-lower-to-nvvm -debug-only=serialize-to-isa \
+// RUN:  | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=serialize-to-isa \
 // RUN:  2>&1 | FileCheck %s
 
 // CHECK: Generated by LLVM NVPTX Back-End

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
index 62d0d9e1cac984..5a624e64342974 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -4,7 +4,7 @@
 // RUN: mlir-opt \
 // RUN: --pass-pipeline="builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm,affine-expand-index-ops,lower-affine,convert-arith-to-llvm),convert-vector-to-llvm,canonicalize,cse)" \
 // RUN: %s \
-// RUN: | mlir-opt --gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt --gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_c_runner_utils \

diff  --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
index 94a57d7c266819..378e5b39415b5c 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize |\
 // RUN: mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if |\
 // RUN: mlir-opt -lower-affine -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm \
-// RUN:  -convert-arith-to-llvm -gpu-lower-to-nvvm | \
+// RUN:  -convert-arith-to-llvm -gpu-lower-to-nvvm-pipeline | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \

diff  --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
index 896051ab5dd7eb..7e9234901ffa1a 100644
--- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
+++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
@@ -2,7 +2,7 @@
 // everything on the same thread.
 // RUN: mlir-opt %s -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -gpu-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm-pipeline | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
@@ -13,7 +13,7 @@
 // RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \
 // RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -gpu-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm-pipeline | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \
@@ -23,7 +23,7 @@
 // RUN: mlir-opt %s  -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
 // RUN:   -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
 // RUN: mlir-opt -convert-vector-to-scf -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
-// RUN:  -gpu-lower-to-nvvm | \
+// RUN:  -gpu-lower-to-nvvm-pipeline | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_cuda_runtime \
 // RUN:   -shared-libs=%mlir_c_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
index d4bd51aab03535..8379710ebbbb77 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s \
 // RUN:  -transform-interpreter \
 // RUN:  -test-transform-dialect-erase-schedule \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
index 3e5f291db8e744..afed0ef667a277 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -11,7 +11,7 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter \
 // RUN:   -test-transform-dialect-erase-schedule \
-// RUN:   -gpu-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
+// RUN:   -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-features=+ptx76 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index bbeddd5bb2285f..958da79ee1668f 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
index d5950eae2543a6..6b5b635c853454 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32-bare-ptr.mlir
@@ -3,7 +3,7 @@
 // Similar to the wmma-matmul-f32 but but with the memref bare pointer lowering convention.
 // This test also uses gpu.memcpy operations (instead of gpu.host_register).
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="host-bare-ptr-calling-convention=1 kernel-bare-ptr-calling-convention=1 cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="host-bare-ptr-calling-convention=1 kernel-bare-ptr-calling-convention=1 cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --entry-point-result=void \

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
index c75f9c1b5649b1..7fbe3e1c881911 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_70 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
index fe999e0aa575b1..9e10aab0f3812a 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \
@@ -8,7 +8,7 @@
 
 // Same as above but with the memref bare pointer lowering convention.
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="kernel-bare-ptr-calling-convention=1 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="kernel-bare-ptr-calling-convention=1 cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
index dcd503c7bd806c..c2ea7919cc3f1e 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
index 8236550feb1113..db649cbeb1943e 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
index 6f965c225e2d89..60323cee952a04 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
index 340db39f5d28f8..1501160e98a170 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
index b4fc32ff9b838a..8e683f360f10c0 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
index f43a095584d69c..b1cae5b3f971a8 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
index 7f5b38b34c8995..41024a003b1833 100644
--- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
+++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
index a894030d430807..512f4902e5ec30 100644
--- a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
+++ b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir
index 9555a77f45f11f..99ea1208e9c5e7 100644
--- a/mlir/test/Integration/GPU/CUDA/printf.mlir
+++ b/mlir/test/Integration/GPU/CUDA/printf.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
index 4e5bb3e8f5ca64..cd11592c2dceb2 100644
--- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/asd b/mlir/test/Integration/GPU/CUDA/sm90/asd
new file mode 100644
index 00000000000000..353d8e7c16b741
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/asd
@@ -0,0 +1,207 @@
+module attributes {gpu.container_module} {
+  llvm.mlir.global private constant @vector_print_str_0(dense<[73, 110, 99, 111, 114, 114, 101, 99, 116, 32, 82, 101, 115, 117, 108, 116, 115, 32, 58, 10, 0]> : tensor<21xi8>) {addr_space = 0 : i32} : !llvm.array<21 x i8>
+  llvm.func @printNewline()
+  llvm.func @printI64(i64)
+  llvm.func @printString(!llvm.ptr)
+  llvm.mlir.global private constant @vector_print_str(dense<[67, 111, 114, 114, 101, 99, 116, 32, 82, 101, 115, 117, 108, 116, 115, 32, 58, 10, 0]> : tensor<19xi8>) {addr_space = 0 : i32} : !llvm.array<19 x i8>
+  llvm.func @malloc(i64) -> !llvm.ptr
+  llvm.mlir.global private @__mbarrier() {addr_space = 3 : i32, alignment = 8 : i64} : !llvm.array<2 x i64>
+  llvm.func @printMemrefF32(i64, !llvm.ptr) attributes {sym_visibility = "private"}
+  llvm.mlir.global private @dynamicShmem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x f16>
+  llvm.mlir.global private @accShmem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x f32>
+  llvm.func @main() {
+    %0 = llvm.mlir.constant(2 : index) : i64
+    %1 = llvm.mlir.constant(0 : i8) : i8
+    %2 = llvm.mlir.constant(64 : index) : i64
+    %3 = llvm.mlir.constant(65536 : i32) : i32
+    %4 = llvm.mlir.constant(16 : index) : i64
+    %5 = llvm.mlir.constant(8 : index) : i64
+    %6 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+    %7 = llvm.mlir.constant(6 : i32) : i64
+    %8 = llvm.mlir.constant(5 : i32) : i64
+    %9 = llvm.mlir.constant(0 : i32) : i64
+    %10 = llvm.mlir.constant(3 : i32) : i64
+    %11 = llvm.mlir.constant(1 : i32) : i32
+    %12 = llvm.mlir.constant(0 : i32) : i32
+    %13 = llvm.mlir.constant(9.99999993E-9 : f32) : f32
+    %14 = llvm.mlir.constant(1 : index) : i64
+    %15 = llvm.mlir.constant(0 : index) : i64
+    %16 = llvm.mlir.constant(128 : index) : i64
+    %17 = llvm.mlir.zero : !llvm.ptr
+    %18 = llvm.getelementptr %17[16384] : (!llvm.ptr) -> !llvm.ptr, f16
+    %19 = llvm.ptrtoint %18 : !llvm.ptr to i64
+    %20 = llvm.call @malloc(%19) : (i64) -> !llvm.ptr
+    %21 = llvm.call @malloc(%19) : (i64) -> !llvm.ptr
+    %22 = llvm.getelementptr %17[16384] : (!llvm.ptr) -> !llvm.ptr, f32
+    %23 = llvm.ptrtoint %22 : !llvm.ptr to i64
+    %24 = llvm.call @malloc(%23) : (i64) -> !llvm.ptr
+    %25 = llvm.call @malloc(%23) : (i64) -> !llvm.ptr
+    llvm.br ^bb1(%15 : i64)
+  ^bb1(%26: i64):  // 2 preds: ^bb0, ^bb5
+    %27 = llvm.icmp "slt" %26, %16 : i64
+    llvm.cond_br %27, ^bb2, ^bb6
+  ^bb2:  // pred: ^bb1
+    llvm.br ^bb3(%15 : i64)
+  ^bb3(%28: i64):  // 2 preds: ^bb2, ^bb4
+    %29 = llvm.icmp "slt" %28, %16 : i64
+    llvm.cond_br %29, ^bb4, ^bb5
+  ^bb4:  // pred: ^bb3
+    %30 = llvm.mul %26, %16  : i64
+    %31 = llvm.add %30, %28  : i64
+    %32 = llvm.udiv %31, %5  : i64
+    %33 = llvm.urem %32, %4  : i64
+    %34 = llvm.trunc %33 : i64 to i32
+    %35 = llvm.sitofp %34 : i32 to f16
+    %36 = llvm.getelementptr %21[%31] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+    llvm.store %35, %36 : f16, !llvm.ptr
+    %37 = llvm.mul %28, %2  : i64
+    %38 = llvm.add %37, %26  : i64
+    %39 = llvm.udiv %38, %5  : i64
+    %40 = llvm.urem %39, %4  : i64
+    %41 = llvm.trunc %40 : i64 to i32
+    %42 = llvm.sitofp %41 : i32 to f16
+    %43 = llvm.mul %28, %16  : i64
+    %44 = llvm.add %43, %26  : i64
+    %45 = llvm.getelementptr %20[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+    llvm.store %42, %45 : f16, !llvm.ptr
+    %46 = llvm.getelementptr %24[%31] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    llvm.store %6, %46 : f32, !llvm.ptr
+    %47 = llvm.getelementptr %25[%31] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    llvm.store %6, %47 : f32, !llvm.ptr
+    %48 = llvm.add %28, %14  : i64
+    llvm.br ^bb3(%48 : i64)
+  ^bb5:  // pred: ^bb3
+    %49 = llvm.add %26, %14  : i64
+    llvm.br ^bb1(%49 : i64)
+  ^bb6:  // pred: ^bb1
+    %50 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
+    %51 = llvm.call @mgpuMemAlloc(%19, %50, %1) : (i64, !llvm.ptr, i8) -> !llvm.ptr
+    %52 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %53 = llvm.insertvalue %51, %52[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %54 = llvm.insertvalue %51, %53[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %55 = llvm.insertvalue %15, %54[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %56 = llvm.insertvalue %16, %55[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %57 = llvm.insertvalue %16, %56[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %58 = llvm.insertvalue %16, %57[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %59 = llvm.insertvalue %14, %58[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %60 = llvm.call @mgpuMemAlloc(%19, %50, %1) : (i64, !llvm.ptr, i8) -> !llvm.ptr
+    %61 = llvm.insertvalue %60, %52[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %62 = llvm.insertvalue %60, %61[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %63 = llvm.insertvalue %15, %62[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %64 = llvm.insertvalue %16, %63[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %65 = llvm.insertvalue %16, %64[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %66 = llvm.insertvalue %16, %65[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %67 = llvm.insertvalue %14, %66[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    %68 = llvm.call @mgpuMemAlloc(%23, %50, %1) : (i64, !llvm.ptr, i8) -> !llvm.ptr
+    llvm.call @mgpuMemcpy(%51, %20, %19, %50) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> ()
+    llvm.call @mgpuMemcpy(%60, %21, %19, %50) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> ()
+    %69 = llvm.alloca %14 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
+    llvm.store %59, %69 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
+    %70 = llvm.alloca %14 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
+    llvm.store %67, %70 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
+    %71 = llvm.alloca %8 x i64 : (i64) -> !llvm.ptr
+    llvm.store %16, %71 : i64, !llvm.ptr
+    %72 = llvm.getelementptr %71[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+    llvm.store %2, %72 : i64, !llvm.ptr
+    %73 = llvm.call @mgpuTensorMapEncodeTiledMemref(%0, %69, %7, %9, %10, %9, %9, %71) : (i64, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr) -> !llvm.ptr
+    %74 = llvm.alloca %8 x i64 : (i64) -> !llvm.ptr
+    llvm.store %2, %74 : i64, !llvm.ptr
+    %75 = llvm.getelementptr %74[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+    llvm.store %2, %75 : i64, !llvm.ptr
+    %76 = llvm.call @mgpuTensorMapEncodeTiledMemref(%0, %70, %7, %9, %10, %9, %9, %74) : (i64, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr) -> !llvm.ptr
+    gpu.launch_func  @main_kernel::@main_kernel blocks in (%14, %14, %14) threads in (%16, %14, %14) : i64 dynamic_shared_memory_size %3 args(%68 : !llvm.ptr, %68 : !llvm.ptr, %15 : i64, %16 : i64, %16 : i64, %16 : i64, %14 : i64, %73 : !llvm.ptr, %76 : !llvm.ptr)
+    llvm.call @mgpuMemcpy(%24, %68, %23, %50) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> ()
+    llvm.br ^bb7(%15 : i64)
+  ^bb7(%77: i64):  // 2 preds: ^bb6, ^bb14
+    %78 = llvm.icmp "slt" %77, %16 : i64
+    llvm.cond_br %78, ^bb8, ^bb15
+  ^bb8:  // pred: ^bb7
+    llvm.br ^bb9(%15 : i64)
+  ^bb9(%79: i64):  // 2 preds: ^bb8, ^bb13
+    %80 = llvm.icmp "slt" %79, %16 : i64
+    llvm.cond_br %80, ^bb10, ^bb14
+  ^bb10:  // pred: ^bb9
+    llvm.br ^bb11(%15 : i64)
+  ^bb11(%81: i64):  // 2 preds: ^bb10, ^bb12
+    %82 = llvm.icmp "slt" %81, %16 : i64
+    llvm.cond_br %82, ^bb12, ^bb13
+  ^bb12:  // pred: ^bb11
+    %83 = llvm.mul %77, %16  : i64
+    %84 = llvm.add %83, %81  : i64
+    %85 = llvm.getelementptr %20[%84] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+    %86 = llvm.load %85 : !llvm.ptr -> f16
+    %87 = llvm.mul %81, %16  : i64
+    %88 = llvm.add %87, %79  : i64
+    %89 = llvm.getelementptr %21[%88] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+    %90 = llvm.load %89 : !llvm.ptr -> f16
+    %91 = llvm.add %83, %79  : i64
+    %92 = llvm.getelementptr %25[%91] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %93 = llvm.load %92 : !llvm.ptr -> f32
+    %94 = llvm.fpext %86 : f16 to f32
+    %95 = llvm.fpext %90 : f16 to f32
+    %96 = llvm.fmul %94, %95  : f32
+    %97 = llvm.fadd %93, %96  : f32
+    llvm.store %97, %92 : f32, !llvm.ptr
+    %98 = llvm.add %81, %14  : i64
+    llvm.br ^bb11(%98 : i64)
+  ^bb13:  // pred: ^bb11
+    %99 = llvm.add %79, %14  : i64
+    llvm.br ^bb9(%99 : i64)
+  ^bb14:  // pred: ^bb9
+    %100 = llvm.add %77, %14  : i64
+    llvm.br ^bb7(%100 : i64)
+  ^bb15:  // pred: ^bb7
+    llvm.br ^bb16(%15, %12, %12 : i64, i32, i32)
+  ^bb16(%101: i64, %102: i32, %103: i32):  // 2 preds: ^bb15, ^bb24
+    %104 = llvm.icmp "slt" %101, %16 : i64
+    llvm.cond_br %104, ^bb17, ^bb25
+  ^bb17:  // pred: ^bb16
+    llvm.br ^bb18(%15, %102, %103 : i64, i32, i32)
+  ^bb18(%105: i64, %106: i32, %107: i32):  // 2 preds: ^bb17, ^bb23
+    %108 = llvm.icmp "slt" %105, %16 : i64
+    llvm.cond_br %108, ^bb19, ^bb24
+  ^bb19:  // pred: ^bb18
+    %109 = llvm.mul %101, %16  : i64
+    %110 = llvm.add %109, %105  : i64
+    %111 = llvm.getelementptr %25[%110] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %112 = llvm.load %111 : !llvm.ptr -> f32
+    %113 = llvm.getelementptr %24[%110] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %114 = llvm.load %113 : !llvm.ptr -> f32
+    %115 = llvm.fsub %112, %114  : f32
+    %116 = llvm.intr.fabs(%115)  : (f32) -> f32
+    %117 = llvm.fcmp "ult" %13, %116 : f32
+    llvm.cond_br %117, ^bb20, ^bb21
+  ^bb20:  // pred: ^bb19
+    %118 = llvm.add %106, %11  : i32
+    llvm.br ^bb22(%118, %107 : i32, i32)
+  ^bb21:  // pred: ^bb19
+    %119 = llvm.add %107, %11  : i32
+    llvm.br ^bb22(%106, %119 : i32, i32)
+  ^bb22(%120: i32, %121: i32):  // 2 preds: ^bb20, ^bb21
+    llvm.br ^bb23
+  ^bb23:  // pred: ^bb22
+    %122 = llvm.add %105, %14  : i64
+    llvm.br ^bb18(%122, %120, %121 : i64, i32, i32)
+  ^bb24:  // pred: ^bb18
+    %123 = llvm.add %101, %14  : i64
+    llvm.br ^bb16(%123, %106, %107 : i64, i32, i32)
+  ^bb25:  // pred: ^bb16
+    %124 = llvm.mlir.addressof @vector_print_str : !llvm.ptr
+    llvm.call @printString(%124) : (!llvm.ptr) -> ()
+    %125 = llvm.sext %103 : i32 to i64
+    llvm.call @printI64(%125) : (i64) -> ()
+    llvm.call @printNewline() : () -> ()
+    %126 = llvm.mlir.addressof @vector_print_str_0 : !llvm.ptr
+    llvm.call @printString(%126) : (!llvm.ptr) -> ()
+    %127 = llvm.sext %102 : i32 to i64
+    llvm.call @printI64(%127) : (i64) -> ()
+    llvm.call @printNewline() : () -> ()
+    llvm.return
+  }
+  gpu.binary @main_kernel  [#gpu.object<#nvvm.target<O = 3, chip = "sm_90a", features = "+ptx80">, "P\EDU\BA\01\00\10\00\A83\00\00\00\00\00\00\02\00\01\01@\00\00\00p$\00\00\00\00\00\00\00\00\00\00\00\00\00\00\07\00\01\00Z\00\00\00\00\00\00\00\00\00\00\00\11\00\10\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\7FELF\02\01\013\07\00\00\00\00\00\00\00\02\00\BE\00{\00\00\00\00\00\00\00\00\00\00\00X#\00\00\00\00\00\00X \00\00\00\00\00\00Z\0DZ\00@\008\00\05\00@\00\0C\00\01\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.uft.entry\00.nv.info\00.text.main_kernel\00.nv.info.main_kernel\00.nv.shared.main_kernel\00.rel.text.main_kernel\00.rela.text.main_kernel\00.debug_frame\00.rel.debug_frame\00.rela.debug_frame\00.nv.constant0.main_kernel\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.uft.entry\00.nv.info\00main_kernel\00.text.main_kernel\00.nv.info.main_kernel\00.nv.shared.main_kernel\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00.rel.text.main_kernel\00.rela.text.main_kernel\00$__dynamicShmem__31\00$____mbarrier__33\00$__accShmem__35\00.debug_frame\00.rel.debug_frame\00.rela.debug_frame\00.nv.constant0.main_kernel\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00L\00\00\00\03\00\09\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00s\00\00\00\03\00\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\02\01\00\00\03\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\12\10\09\00\00\00\00\00\00\00\00\00\80\18\00\00\00\00\00\002\01\00\00\03\00\0B\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\FF\FF\FF\FF$\00\00\00\00\00\00\00\FF\FF\FF\FF\FF\FF\FF\FF\03\00\04|\FF\FF\FF\FF\0F\0C\81\80\80(\00\08\FF\81\80(\08\81\80\80(\00\00\00\FF\FF\FF\FF,\00\00\00\00\00\00\00\00\00\00\00\
00\00\00\00\00\00\00\00\00\00\00\00\80\18\00\00\00\00\00\00\04T\00\00\00\0C\81\80\80(\00\04\A0\05\00\00\00\00\00\00\04/\08\00\05\00\00\00\9A\00\00\00\04#\08\00\05\00\00\00\00\00\00\00\04\12\08\00\05\00\00\00\00\00\00\00\04\11\08\00\05\00\00\00\00\00\00\00\047\04\00{\00\00\00\04\17\0C\00\00\00\00\00\08\00@\00\00\F0!\00\04\17\0C\00\00\00\00\00\07\008\00\00\F0!\00\04\17\0C\00\00\00\00\00\06\000\00\00\F0!\00\04\17\0C\00\00\00\00\00\05\00(\00\00\F0!\00\04\17\0C\00\00\00\00\00\04\00 \00\00\F0!\00\04\17\0C\00\00\00\00\00\03\00\18\00\00\F0!\00\04\17\0C\00\00\00\00\00\02\00\10\00\00\F0!\00\04\17\0C\00\00\00\00\00\01\00\08\00\00\F0!\00\04\17\0C\00\00\00\00\00\00\00\00\00\00\F0!\00\03\1B\FF\00\0490\00\C0\00\00\00\FF\00\00\00\00\00\00\00\00\01\09\00\00\01\00\00\FF\00\00\00\08\00\00\00\00\01\09\00p\09\00\00\00\00\00\00\00\00\00\00\0A\01?\00\038\02\00\04\1C\0C\00@\12\00\00\80\17\00\00\D0\17\00\00\04\1E\04\00\00\00\00\00\03\19H\00\04\0A\08\00\06\00\00\00\10\02H\00\00\00\00\00D\00\00\00\00\00\00\00\02\00\00\00\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\82{\01\FF\00\0A\00\00\00\08\00\00\00$\0E\00\C3y\04\00\00\00\00\00\00\88\00\00\00b\0E\00\19y\06\00\00\00\00\00\00!\00\00\00\A2\0E\00\82x\08\00\00\04\00\00\00\00\00\00\00\E2\0F\00\82x\0A\00\FE\FF\1F\00\00\00\00\00\00\E2\0F\00\90x\09\08\08\00\00\00?\E0\FF\0F\00\E2\0F\00Ey\00\00\A0\04\00\00\00\00\80\03\00\E2\0F\00\82x\0B\00\00\F8\FF\7F\00\00\00\00\00\E2\0F\00\B9z\06\00\00\92\00\00\00\0A\00\00\00\E2\0F\00\B9z\0C\00\00\94\00\00\00\0A\00\00\00\E2\0F\00\96x\09\04T\06\00\00\09\00\00\08\00\E4/\00\96x\08\04T\06\00\00\08\00\00\08\00\E2\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\E2\0F\00\0Cr\00\06\FF
\00\00\00pR\F0\03\00\E2O\00\C6s\00\00\00\00\00\00\00\00\00\00\00n\0E\00\B2u?\09\0A\00\00\00\00\01\00\08\00b\02\00\18y\00\00\00\00\00\00\00\00\00\00\00\E2\0F\00\B2u?\09\0A\08\00\00\00\01\00\08\00b\02\00\B9y\00\06\00\00\00\00\00\00\04\08\00\E2\03\00\B9y\00\0C\00\00\00\00\00\00\04\08\00\E4\03\00G\09\EC\00\00\00\00\00\00\00\80\03\00\EA/\00$t\00\FF\00\80\00\00\FF\00\8E\07\00\E2\0F\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\82|\0A\00?\00\00\00\00\00\00\08\00\E4\0F\00\A7y\FF\FF\00\00\00\00\09\00\00\08\00\F4\03\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82|\0B\00\0A\00\00\00\00\00\00\08\00\E4\0F\00\B4u\00\06\08\00\00\00\00\80\00\08\00\F4\05\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E8\00\FD\FF\FF\FF\FF\FF\93\03\00\EAO\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\90x\10\08\00\80\00\00?\E0\FF\0F\00\D8\0F\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82|\11\00\09\00\00\00\00\00\00\08\00\E2\0F\00\82|\12\00\0A\00\00\00\00\00\00\08\00\E2\0F\00\82|\13\00\0A\00\00\00\00\00\00\08\00\E4\0F\00\B4u\00\0C\10\00\00\00\00\80\00\08\00\F0\05\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E0\00\FD\FF\FF\FF\FF\FF\93\03\00\EAO\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\90x\10\08\00\A0\00\00?\E0\FF\0F\00\E2\0F\00\82x\12\00@\00\00\00\00\00\00\00\00\D6\0F\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82|\11\00\09\00\00\00\00\00\00\08\00\E2\0F\00\82|\13\00\0A\00\00\00\00\00\00\08\00\E4\0F\00\B4u\00\0C\10\00\00\00\00\80\00\08\00\F2\05\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E4\00\FD\FF\FF\FF\FF\FF\93\03\00\EAO\00\A7y\FF\FF\00\08\00\00\09\00\00\08\00\E2\05\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\90x\10\08\00@\00\00?\E0\FF\0F\00\E4\0F\00\90x\11\09\08\00\00\00?\E0\FF\0F\00\D8\0F\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82x\12\00@\00\00\00\00\00\00\00\00\E2\0F\00\82|\13\00\0A\00\00\00\00\00\00\08\00\E4\0F\00\B4u\00
\06\10\00\00\00\00\80\00\08\00\F2\07\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E4\00\FD\FF\FF\FF\FF\FF\93\03\00\EA\8F\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\90x\10\08\00\C0\00\00?\E0\FF\0F\00\D8\0F\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82|\12\00\0A\00\00\00\00\00\00\08\00\E2\0F\00\82x\13\00@\00\00\00\00\00\00\00\00\E4\0F\00\B4u\00\0C\10\00\00\00\00\80\00\08\00\F2\07\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E4\00\FD\FF\FF\FF\FF\FF\93\03\00\EA\8F\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\90x\10\08\00\E0\00\00?\E0\FF\0F\00\D8\0F\00/\08?\00\00\00\00\00\00\00\82\03\00\E2\0F\00\82x\12\00@\00\00\00\00\00\00\00\00\E2\0F\00\82x\13\00@\00\00\00\00\00\00\00\00\E4\0F\00\B4u\00\0C\10\00\00\00\00\80\00\08\00\F2\07\00\1C\18\00\00\00\00\00\00p\E1\F0\00\00\C4\0F\00\1Cx\00\00\00\00\00\00p\E1\F2\03\00\D6\0F\00G\09\E4\00\FD\FF\FF\FF\FF\FF\93\03\00\EA\8F\00Ay\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\05xX\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xZ\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\\\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x^\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x`\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xb\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05xd\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xf\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xh\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xj\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xl\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xn\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05xp\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xr\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xt\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xv\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xx\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xz\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x|\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x~\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\80\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\82\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\0
5x\84\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\86\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x\88\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\8A\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\8C\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\8E\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\90\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\92\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x\94\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\96\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\18\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\1A\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\1C\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\1E\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x \00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x\22\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x$\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x&\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x(\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x*\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x,\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x.\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x0\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x2\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x4\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x6\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05x8\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x:\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x<\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x>\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05x@\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xB\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05xD\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xF\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xH\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xJ\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xL\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xN\00\00\00\00\00\00\FF\01\00\00\C4\0F\00\05xP\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xR\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xT\00\00\00\00\00\00\FF\01\00\00\E4\0F\00\05xV\00\00\00\00\00\00\FF\01\00\00\E2\0F\00\9Cx\00\00\00\00\00\00p\F0\F0\03\00\E2\0F\00\82|\05\00?\00\00\00\00\00\00\08\00\E2\0F\00\82|\04\00?\00\00\00\00\00\00\08\00\
D2\0F\00\1Cx\00\00\00\00\00\00\08\F0\F0\03\00\E2\0F\00\91r\06\05\09\00\00\00?\18\8E\0F\00\D8\0F\00$~\00\FF\06\00\00\00\FF\00\8E\0F\00\C8o\00\A7u\00\00\FF\00\00\00\7F\01\02\08\00b\02\00\1Ay\00\00\00\90\00\00\00\00\00\00\00\C8\0F\00]\99\00\00\81\96\98\00\00\00\90\03\00\EA\0F\00\A7\95\00\00\FF\00\00\00\7F\00\02\08\00$\0E\00G\99\E8\00\FC\FF\FF\FF\FF\FF\83\03\00\EA\1F\00\91r\07\05\08\00\00\00?p\80\0F\00\E2\0F\00\C5y\00\00\00\00\00\00\00\00\00\00\00\E2\0F\00\82x\19\00@\00\00@\00\00\00\00\00\E2\0F\00\82x\1B\00@\00\00@\00\00\00\00\00\E2\0F\00\91r\0A\05?\00\00\00\04t\0F\08\00\E4\0F\00\90x\05\07\00\80\00\00?\E0\F1\0F\00\E4\0F\00\99x\07\07\04\00\00\00\0A\12\00\08\00\E4\0F\00\90r\06?\0A\00\00\00?\E4\7F\08\00\C8\0F\00\99x\05\05\04\00\00\00\06\12\00\08\00\E4\0F\00\92x\06\07\FF?\00\00?\C0\8E\0F\00\E4\0F\00\92x\05\05\FF?\00\00?\C0\8E\0F\00\E4\0F\00\92x\18\06\00\00\00\04?\FC\8E\0F\00\E4\0F\00\92x\1A\05\00\00\00\02?\FC\8E\0F\00\E4\0F\00\90x\14\06\02\00\00\04?\E0\F1\0F\00\C4\0F\00\90x\16\05\80\00\00\02?\E0\F3\0F\00\E4\0F\00\90x\15?@\00\00@?\E4\7F\08\00\E4\0F\00\90x\17?@\00\00@?\E4\FF\08\00\E2\0F\00\F0yX\18\00\00\E0\01X\08p\08\00\E2\0F\00\90x\10\06\04\00\00\04?\E0\F1\0F\00\E4\0F\00\90x\12\05\00\01\00\02?\E0\F3\0F\00\E4\0F\00\90x\11?@\00\00@?\E4\7F\08\00\C4\0F\00\90x\13?@\00\00@?\E4\FF\08\00\E4\0F\00\90x\0C\06\06\00\00\04?\E0\F1\0F\00\E4\0F\00\90x\0E\05\80\01\00\02?\E0\F3\0F\00\E4\0F\00\90x\0D?@\00\00@?\E4\7F\08\00\E4\0F\00\90x\18\06\00\02\00\04?\E0\F1\0F\00\E4\0F\00\90x\0F?@\00\00@?\E4\FF\08\00\C4\0F\00\90x\19?@\00\00@?\E4\7F\08\00\E2\0F\00\82x\05\00\01\00\00\00\00\00\00\00\00\E2\0F\00\F0yX\14\00\00\E0\01X\08p\08\00\E2\0F\00\90x\14\06\02\02\00\04?\E0\F1\0F\00\C8\0F\00\90x\15?@\00\00@?\E4\7F\08\00\CE\0F\00\F0yX\10\00\00\E0\01X\08p\08\00\E2\0F\00\90x\10\06\04\02\00\04?\E0\F1\0F\00\C8\0F\00\90x\11?@\00\00@?\E4\7F\08\00\CE\0F\00\F0yX\0C\00\00\E0\01X\08p\08\00\E2\0F\00\90x\0C\06\06\02\00\04?\E0\F1\0F\00\C8\0F\00\90x\0D?@\00\00@?\E4\7F\08\00\E4\0F\00\9Cx\00\00\00\00\00\00p\E8\F0\03\00\CA\0F\
00\F0y\18\18\00\00\E0\01\18\08p\08\00\D8\0F\00\F0y\18\14\00\00\E0\01\18\08p\08\00\D8\0F\00\F0y\18\10\00\00\E0\01\18\08p\08\00\D8\0F\00\F0y\18\0C\00\00\E0\01\18\08\00\08\00\E6\0F\00\C5y\00\00\00\80\00\00\00\01\01\00\00\E4\0F\00G\09,\00\FC\FF\FF\FF\FF\FF\83\03\00\EA\0F\00\C3y\07\00\00\00\00\00\00\88\00\00\00b\0E\00\19x\00\FF\02\00\00\00\06\16\01\00\00\E2\0F\10$x\02\06\02\00\00\00\FF\00\8E\07\00\E2\0F\00\19x\03\FF\01\00\00\00\06\16\01\00\00\E2\0F\00\82x\06\00\00\04\00\00\00\00\00\00\00\E2\0F\00\1Ax\00\00\03\00\00\00\00\00\00\00\00\E2\0F\00\90x\06\06 \00\00\00?\E0\FF\0F\00\E2\0F\00\12x\02\02\06\00\00\00\FF\C0\8E\07\00\E2\0F\00\C5y\00\00\00\80\00\00\00\00\01\00\00\E4\0F\00\12x\00\00\F0\FF\FF\7F\03\F8\8E\07\00\C4\0F\00\0Cx\00\06\FF\0F\00\00p@\F0\03\00\E4\0F\00\12x\04\00\08\00\00\00\FF\FC\8E\07\00\E2\0F\04$x\03\00\80\00\00\00\FF\00\8E\07\00\E4\0F\006x\05\00@\00\00\00\00\00\00\00\00\E4\0F\006x\00\00H\00\00\00\00\00\00\00\00\E2\0F\00\12r\08\03\02\00\00\00\FF\FC\8E\07\00\E2\0F\08$x\03\04\80\00\00\00\FF\00\8E\07\00\E4\0F\00$x\05\05\80\00\00\00\FF\00\8E\07\00\E2\0F\00\96x\06\07T\06\00\00\06\00\00\08\00\E2/\00$x\07\00\80\00\00\00\FF\00\8E\07\00\E2\0F\00\12r\03\03\02\00\00\00\FF\FC\8E\07\00\C4\0F\00\12r\05\05\02\00\00\00\FF\FC\8E\07\00\E4\0F\08\12r\00\07\02\00\00\00\FF\FC\8E\07\00\E4\0F\00\11|\04\08\06\00\00\00\FF\10\8E\0F\00\E4\0F\00\11|\03\03\06\00\00\00\FF\10\8E\0F\00\E4\0F\00\11|\02\05\06\00\00\00\FF\10\8E\0F\00\E2\0F\00\88s\00\04X\00\00\00\00\0A\00\00\00\E2\03\00\11|\00\00\06\00\00\00\FF\10\8E\0F\00\C6\0F\00\88s\00\04\\ \00\00\00\0A\00\00\00\E8\03\00\88s\00\04`@\00\00\00\0A\00\00\00\E8\03\00\88s\00\04d`\00\00\00\0A\00\00\00\E8\03\00\88s\00\04h\80\00\00\00\0A\00\00\00\E8\03\00\88s\00\04l\A0\00\00\00\0A\00\00\00\E8\03\00\88s\00\04p\C0\00\00\00\0A\00\00\00\E8\03\00\88s\00\04t\E0\00\00\00\0A\00\00\00\E8\03\00\88s\00\04x\00\01\00\00\0A\00\00\00\E8\03\00\88s\00\04| 
\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\80@\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\84`\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\88\80\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\8C\A0\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\90\C0\01\00\00\0A\00\00\00\E8\03\00\88s\00\04\94\E0\01\00\00\0A\00\00\00\E8\03\00\88s\00\03Z\00\00\00\00\0A\00\00\00\E8\03\00\88s\00\03^ \00\00\00\0A\00\00\00\E8\03\00\88s\00\03b@\00\00\00\0A\00\00\00\E8\03\00\88s\00\03f`\00\00\00\0A\00\00\00\E8\03\00\88s\00\03j\80\00\00\00\0A\00\00\00\E8\03\00\88s\00\03n\A0\00\00\00\0A\00\00\00\E8\03\00\88s\00\03r\C0\00\00\00\0A\00\00\00\E8\03\00\88s\00\03v\E0\00\00\00\0A\00\00\00\E8\03\00\88s\00\03z\00\01\00\00\0A\00\00\00\E8\03\00\88s\00\03~ \01\00\00\0A\00\00\00\E8\03\00\88s\00\03\82@\01\00\00\0A\00\00\00\E8\03\00\88s\00\03\86`\01\00\00\0A\00\00\00\E8\03\00\88s\00\03\8A\80\01\00\00\0A\00\00\00\E8\03\00\88s\00\03\8E\A0\01\00\00\0A\00\00\00\E8\03\00\88s\00\03\92\C0\01\00\00\0A\00\00\00\E8\03\00\88s\00\03\96\E0\01\00\00\0A\00\00\00\E8\03\00\88s\00\02\18\00\00\00\00\0A\00\00\00\E8\03\00\88s\00\02\1C \00\00\00\0A\00\00\00\E8\03\00\88s\00\02 @\00\00\00\0A\00\00\00\E8\03\00\88s\00\02$`\00\00\00\0A\00\00\00\E8\03\00\88s\00\02(\80\00\00\00\0A\00\00\00\E8\03\00\88s\00\02,\A0\00\00\00\0A\00\00\00\E8\03\00\88s\00\020\C0\00\00\00\0A\00\00\00\E8\03\00\88s\00\024\E0\00\00\00\0A\00\00\00\E8\03\00\88s\00\028\00\01\00\00\0A\00\00\00\E8\03\00\88s\00\02< \01\00\00\0A\00\00\00\E8\03\00\88s\00\02@@\01\00\00\0A\00\00\00\E8\03\00\88s\00\02D`\01\00\00\0A\00\00\00\E8\03\00\88s\00\02H\80\01\00\00\0A\00\00\00\E8\03\00\88s\00\02L\A0\01\00\00\0A\00\00\00\E8\03\00\88s\00\02P\C0\01\00\00\0A\00\00\00\E8\03\00\88s\00\02T\E0\01\00\00\0A\00\00\00\E8\03\00\88s\00\00\1A\00\00\00\00\0A\00\00\00\E8\03\00\88s\00\00\1E 
\00\00\00\0A\00\00\00\E8\03\00\88s\00\00\22@\00\00\00\0A\00\00\00\E8\03\00\88s\00\00&`\00\00\00\0A\00\00\00\E8\03\00\88s\00\00*\80\00\00\00\0A\00\00\00\E8\03\00\88s\00\00.\A0\00\00\00\0A\00\00\00\E8\03\00\88s\00\002\C0\00\00\00\0A\00\00\00\E8\03\00\88s\00\006\E0\00\00\00\0A\00\00\00\E8\03\00\88s\00\00:\00\01\00\00\0A\00\00\00\E8\03\00\88s\00\00> \01\00\00\0A\00\00\00\E8\03\00\88s\00\00B@\01\00\00\0A\00\00\00\E8\03\00\88s\00\00F`\01\00\00\0A\00\00\00\E8\03\00\88s\00\00J\80\01\00\00\0A\00\00\00\E8\03\00\88s\00\00N\A0\01\00\00\0A\00\00\00\E8\03\00\88s\00\00R\C0\01\00\00\0A\00\00\00\E8\03\00\88s\00\00V\E0\01\00\00\0A\00\00\00\E2\03\00M\09\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\19x\03\FF\1F\00\00\00\06\14\01\00\00\E2/\00\B9z\04\00\00\86\00\00\00\0A\00\00\00\E2\0F\00\12x\02\06\1F\00\00\00\FF\C0\8E\07\00\E2\0F\04Ey\00\00\C0\01\00\00\00\00\80\03\00\E2\0F\00\19x\07\06\05\00\00\00\03\12\00\00\00\E4\0F\10\19x\00\FF\05\00\00\00\03\16\01\00\00\E2\0F\00%x\02\02\10\00\00\00\FF\00\8E\07\00\E2\0F\00\10x\16\07\FC\FF\FF\FF\FF\E0\F1\07\00\C6\0F\00$x\05\07\00\02\00\00\FF\00\8E\07\00\E2\0F\04\0Cx\00\16|\00\00\00p`\F2\03\00\E4\0F\00\10x\17\00\FF\FF\FF\FF\FF\E4\7F\00\00\E4\0F\00\19x\07\07\09\00\00\00\00\02\01\00\00\E4\0F\00\0Cr\00\17\FF\00\00\00\10a\F2\03\00\E4\0F\00\12r\00\05\02\00\00\00\FF\FC\8E\07\00\E4\0F\00\12r\02\07\03\00\00\00\FF\FC\8E\07\00\C4\0F\00\10|\14\00\04\00\00\00\FF\E0\F5\0F\00\E2\0F\006|\00\00\06\00\00\00\00\00\00\08\00\E2\0F\00\1Cx\00\00\00\00\00\00p\F0\F0\03\00\E4\0F\00\10|\15\02\05\00\00\00\FF\E4\7F\09\00\E2\0F\00\B9z\04\00\00\82\00\00\00\0A\00\00\00\E4\0F\00G\99(\00\00\00\00\00\00\00\80\03\00\F0\0F\00\84y\04\00\00\00\00\00\00\0C\00\00\00\A2\02\00$r\02\FF\FF\00\00\00\14\00\8E\07\00\E2\0F\00\10x\16\16\04\00\00\00\FF\E0\F3\07\00\E2\0F\00$r\03\FF\FF\00\00\00\15\00\8E\07\00\E2\0F\00\10x\14\14\00\08\00\00\FF\E0\F5\07\00\E4\0F\00\1Cx\00\00\00\00\00\00p\E1\F0\03\00\E2\0F\00$r\17\FF\FF\00\00\00\17\06\8E\00\00\E4\0F\006x\00\00\00\08\00\00\00\00\00\00\00\E4/\00$r\15\FF\FF\00\0
0\00\15\06\0E\01\00\E2\0F\00\86y\00\02\04\00\00\00\04\1D\10\0C\00\EEC\00Ay\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\10x\02\16|\00\00\00\FF\E1\F7\07\00\E2/\00Ey\00\00\A0\01\00\00\00\00\80\03\00\E2\0F\00\0Cx\00\16|\00\00\00p`\F4\03\00\E4\0F\00\0Cx\00\02\0C\00\00\00p0\F2\03\00\E2\0F\00$r\02\FF\FF\00\00\00\17\0E\8E\01\00\E2\0F\00\0Cr\00\17\FF\00\00\00 a\F4\03\00\C8\0F\00\0Cr\00\02\FF\00\00\00\105r\01\00\DA\0F\00G\19L\00\00\00\00\00\00\00\80\03\00\EA\0F\00\1Cx\00\00\00\00\00\00p\E1\F0\03\00\DA\0F\00\84y\04\00\00\00\00\00\00\0C\00\00\00b.\00$r\02\FF\FF\00\00\00\14\00\8E\07\00\E2\0F\00\10x\16\16\10\00\00\00\FF\E0\F3\07\00\E2\0F\00$r\03\FF\FF\00\00\00\15\00\8E\07\00\E2\0F\00\84y\08\00\00\00\08\00\00\0C\00\00\00\A4\0E\00\10x\14\02\00 \00\00\FF\E0\F5\07\00\E2\0F\00$r\17\FF\FF\00\00\00\17\06\8E\00\00\E2\0F\00\84y\0C\00\00\00\10\00\00\0C\00\00\00\E2\0E\00\0Cx\00\16p\00\00\00p`\F2\03\00\E4\0F\00$r\15\FF\FF\00\00\00\03\06\0E\01\00\E2\0F\00\84y\10\00\00\00\18\00\00\0C\00\00\00b\09\00\0Cr\00\17\FF\00\00\00\10a\F2\03\00\E2\0F\006x\00\00\00 \00\00\00\00\00\00\00\C4\0F\01\86y\00\02\04\00\00\00\04\1D\10\0C\00\E8#\00\86y\00\02\08\00\08\00\04\1D\10\0C\00\E8C\00\86y\00\02\0C\00\10\00\04\1D\10\0C\00\E8\83\00\86y\00\02\10\00\18\00\04\1D\10\0C\00\E2\03\02G\99\B8\00\FC\FF\FF\FF\FF\FF\83\03\00\EA\0F\00Ay\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\10x\02\16|\00\00\00\FF\E1\F7\07\00\E2/\00Ey\00\000\01\00\00\00\00\80\03\00\E2\0F\00\0Cx\00\16|\00\00\00p`\F4\03\00\E4\0F\00\0Cx\00\02\04\00\00\00p0\F2\03\00\E2\0F\00$r\02\FF\FF\00\00\00\17\0E\8E\01\00\E2\0F\00\0Cr\00\17\FF\00\00\00 
a\F4\03\00\C8\0F\00\0Cr\00\02\FF\00\00\00\105r\01\00\DA\0F\00G\190\00\00\00\00\00\00\00\80\03\00\EA\0F\00\84y\04\00\00\00\00\00\00\0C\00\00\00b\0E\00$r\02\FF\FF\00\00\00\14\00\8E\07\00\E2\0F\00\10x\16\16\08\00\00\00\FF\E0\F3\07\00\E2\0F\00$r\03\FF\FF\00\00\00\15\00\8E\07\00\E2\0F\00\84y\08\00\00\00\08\00\00\0C\00\00\00\E2\04\00\10x\14\14\00\10\00\00\FF\E0\F5\07\00\E4\0F\00\1Cx\00\00\00\00\00\00p\E1\F0\03\00\E2\0F\00$r\17\FF\FF\00\00\00\17\06\8E\00\00\E4\0F\00$r\15\FF\FF\00\00\00\15\06\0E\01\00\C4\0F\006x\00\00\00\10\00\00\00\00\00\00\00\E2O\00\86y\00\02\04\00\00\00\04\1D\10\0C\00\E8#\00\86y\00\02\08\00\08\00\04\1D\10\0C\00\E6\83\00Ay\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\0Cx\00\16|\00\00\00p\10\F2\03\00\C8\0F\00\0Cr\00\17\FF\00\00\00\10\15p\00\00\DA\0F\00M\89\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00\84y\04\00\00\00\00\00\00\0C\00\00\00\22.\00$r\02\FF\FF\00\00\00\14\00\8E\07\00\E4\0F\00$r\03\FF\FF\00\00\00\15\00\8E\07\00\CA\0F\00\86y\00\02\04\00\00\00\04\1D\10\0C\00\E2\1F\00My\00\00\00\00\00\00\00\00\80\03\00\EA\0F\00Gy\FC\00\FC\FF\FF\FF\FF\FF\83\03\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\18y\00\00\00\00\00\00\00\00\00\00\00\C0\0F\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\
00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\F5\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0B\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00`\01\00\00\00\00\00\00L\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\13\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\B0\02\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\02\00\00\00\05
\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\AB\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00X\03\00\00\00\00\00\00h\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\007\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C0\03\00\00\00\00\00\000\00\00\00\00\00\00\00\03\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00R\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\03\00\00\00\00\00\00\FC\00\00\00\00\00\00\00\03\00\00\00\09\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\94\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\03\00\00\00\09\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\C9\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\04\00\00\00\00\00\00\18\00\00\00\00\00\00\00\03\00\00\00\04\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00@\00\00\00\01\00\00\00\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\80\18\00\00\00\00\00\00\03\00\00\00\05\00\00\00\80\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00g\00\00\00\08\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\1E\00\00\00\00\00\00!\04\00\00\00\00\00\00\00\00\00\00\09\00\00\00\10\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\DB\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\1E\00\00\00\00\00\00X\02\00\00\00\00\00\00\00\00\00\00\09\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\06\00\00\00\04\00\00\00X#\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\18\01\00\00\00\00\00\00\18\01\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\04\00\00\00X#\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\18\01\00\00\00\00\00\00\18\01\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\05\00\00\00\80\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\80\1
8\00\00\00\00\00\00\80\18\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\06\00\00\00\00\1E\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00!\04\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\04\00\00\00\00\1E\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00X\02\00\00\00\00\00\00X\02\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\01\01H\00\00\00\B0\0E\00\00\00\00\00\00\AD\0E\00\00@\00\00\00\00\00\08\00Z\00\00\00\00\00\00\00\00\00\00\00\11 \10\00\00\00\00\00\00\00\00\00\00\00\00\00\F7=\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F2\22\0A\0A\0A\0A.version 8.0\0A.target sm_90a\0A.address_size 64\0A\01\00\F8\19.visible .entry main_kernel(\0A.param .u64\19\00\11_\17\00?_0,!\00\0C\1F1!\00\0D\1F2!\00\0D\1F3!\00\0D\1F4!\00\0D\1F5!\00\0D\1F6!\00\0D\1F7!\00\0D\F3\088\0A)\0A{\0A.reg .pred %p<8>;\12\00\95b32 %r<47\12\00\10f\12\00ff<1669&\00\F0\0364 %rd<83>;\0A\0A\09.shaJ\00\FF\0B.align 16 .b8 dynamicShmem&\00\00\118%\00\EF__mbarrier[16]M\00\07$acI\00\22ld\E0\00\22.u\85\00_21, [\E7\00\00\1F]+\00\00\1F0+\00\02\1F7+\00\00/17+\00\02\911];\0Amov.u'\01\F0\051, %tid.x;\0Asetp.ne.s\19\002p2,\1E\00\130.\00\03T\00'8,\F0\00\02\1B\00\02p\01d2, 1;\0A\1A\00S.init\07\01\01k\01\11[=\00\10]U\00\832;\0Aadd.sR\00$9,Y\00\1F8@\00\0D\149@\00\F1\05\0A\09prefetch.tensormap#\00!20\81\01\0F \00\07\111 \00\07\D2\00;78,\E7\01\06\D4\00\116\FF\00\F0\08@%p2 bra $L__BB0_2;\0AcvtA\01\03E\00\133\C3\00\08<\00\814, 32768o\00\04\D6\00\00\07\00\C8ve.expect_tx\E2\00 _,\A2\00\113\E3\00\1046\00\0Ad\00\225,\AE\00\01\1A\00\C3p.async.bulk\E2\003.2dN\00\F4\02::cluster.global.\7F\00\F0\05::complete_tx::bytest\0035],=\010, {\F6\00c%r6} ]\90\00\00q\00\01\B0\01\01\D7\00\01\AE\01\175\DC\00\0F\8C\00<\149\8C\00\1F1\8C\00\0F#13\8D\00X40960}\01o14, 64\9F\00@*13\A0\00/14\A1\00\09222,\1E\02\0F\06\02\1D&22\07\02\07\F0\00\04}\01O1638\DE\00B\1A9\0A\02\0A\DE\00,22\DF\00\05\80\01O9152\90\00 at .23\0E\02/14\90\00\07\137\90\00O5734 
\01A*27\90\00/14\91\00\00\05&\04\10:&\04Da.to_\00\02\C9\031d1,\06\00\127F\043s64\BF\00\12ds\05\02\C8\02\02\B3\060154\83\05#f0\01\00\09\8C\05%79\9E\05\03\05\07X7, -1\08\03o33, 10<\00\00\104F\02\09<\00\154M\00\05{\00#2,\83\00\0B\18\00\1F3\18\00\04\1F4\18\00\04\1F5\18\00\04\1F6\18\00\04\1F7\18\00\04\1F8\18\00\04\1F9\18\00\03/50\18\00\04\1F1\18\00\04\0F\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F5\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F6\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04\1F7\F0\00\04/78\18\00\04\0F\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F8\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\04\1F9\F0\00\03?600\18\00\04\1F1\18\00\04\1F2\18\00\04\1F3\18\00\04\1F4\18\00\04\1F5\18\00\04\1F6\18\00\04\1F7\18\00\04\1F8\18\00\04\0F\F0\00\04\1F1\F0\00\04\1F1\F0\00\04\1F1\F0\00\04/13\18\00\04\0F\F0\00\04\1F1\F0\00\04\1F1\F0\00\04/17\18\00\04\0F\F0\00\04/19\18\00\03/20\18\00\04\0F\F0\00\04/22\18\00\04\0F\F0\00\04\1F2\F0\00\04\1F2\F0\00\04\1F2\F0\00\04/27\18\00\04\0F\F0\00\04\1F2\F0\00\04\1F3\F0\00\04\1F3\F0\00\04\1F3\F0\00\04/33\18\00\04\0F\F0\00\04\1F3\F0\00\04\1F3\F0\00\04\1F3\F0\00\04\1F3\F0\00\04\1F3\F0\00\04\1F4\F0\00\04/41\18\00\04\0F\F0\00\04\0F`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\1F6`\09\04\186`\09\04\A5\0C(3:\04\0C\00\C9\00\93p7;\0Ashl.b)\0C\00e\00\02l\0C\193\B8\11\01.\02\02\B8\11\01'\00\0D\14\11\121\19\00\01\FB\0E\09\AD\13\F6\00P1; \0ALAB_WAIT: 
\F4\11\80try_wait\C7\128ity\17\0F!P1b\0D!31\18\0F\01\94\00\8033; \0A at P1\9B\11\C4.uni DONE; \0A\0F\00\04b\000; \0A\1A\00x: \0A}\0A\0A\09\DB\00\159\DB\00)14\DC\00\01\C2\02\02\1B\12\01(\00T;\0Abfe;\0D\01\C7\02\02$\00\124:\00$orS\00\01\0B\05\02#\00\FA\03461168629337240371\F8\12%53G\00\03L\11\08d\00$4,$\00\0Dd\00\01W\05\02#\00\07d\00\8038849280n\01\C1wgmma.fence.\FE\0E\01\87\14\22ed\1A\13\09\8C\01\16p\F7\13\002\0E2p, \F3\13\02C\00Bmma_@\0F\08G\00\B0.m64n128k16\\\02 .f\08\00B16 {]\02\1A,=\0E*3,\1D\0E*5,\FD\0D*7,\DD\0D)9,\BD\0D\00 \01\08\9D\0D\017\01\07}\0D\04\0D\04\04]\0D\04\ED\03\04=\0D\04\CD\03\04\1D\0D\04\AD\03\04\FD\0C\04\8D\03\04\DD\0C\04m\03\04\BD\0C\04M\03\04\9D\0C\04\8D\0C\04}\0C\04m\0C\04]\0C\04M\0C\04=\0C\04-\0C\04\1D\0C\04\0D\0C\04\FD\0B\00~\02\08\DD\0B\04\CD\0B\04\BD\0B\04\AD\0B\04\9D\0B\04\8D\0B\04}\0B\04m\0B\04]\0B\04M\0B\04=\0B\04-\0B\04\1D\0B\04\0D\0B\04\FD\0A\04\ED\0A\04\DD\0A\04\CD\0A\04\BD\0A\129P\00\04\9D\0A\04\8D\0A\04}\0A\04m\0AW604},\02\03\00\A5\02\02M\02\00P\02\01S\02$}\0A\AD\11\03\C8\02\1F1,\03\08\0A\90\03/32\F3\02\06.40a\04\0F\D5\02\FF\FF@\03\AB\02\00\87\02\0F\D5\02\0D\1F3\D5\02\08\1A6\D5\02\1F4\D5\02\06?536\D5\02\FF\FFR$3,\87\02\0F\D5\02\0D\1F5\D5\02\08\1A8\D5\02\1F6\D5\02\06?664\D5\02\FF\FFR$5,\87\02\0F\D5\02\0D\1F7\D5\02\05?422\AA\02O\04v\13\04f\13\04V\13\04F\13\046\13\04&\13\04\16\13\04\06\13\04\F6\12\04\E6\12\04\D6\12\04\C6\12\04\B6\12\04\A6\12\04\96\12\04\86\12\04v\12\04f\12\04V\12\04F\12\046\12\04&\12\04\16\12\04\06\12\04\F6\11\04\E6\11\04\D6\11\04\C6\11\04\B6\11\04\A6\11\04\96\11\04\86\11\04v\11\04f\11\04V\11\136F\11\046\11\04&\11\04\16\11\04\06\11\04\F6\10\04\E6\10\04\D6\10\04\C6\10\04\B6\10\04\A6\10\04\96\10\04\86\10\04v\10\04f\10\04V\10\04F\10\046\10\04&\10\04\16\10\04\06\10\04\F6\0F\04\E6\0F\04\D6\0F\04\C6\0F\04\B6\0F\04\A6\0F\04\96\0F3668\AA\02\1F7)\0B\16\1F9\AA\02\08\0F)\08O\0F\AA\02\FF\F0\1F9\FE\0A\15\01\0A\01\0F\AA\02\05\0F\A8\0DO\0F\AA\02\FF\EF/41\D3\0A\15\01\FA\00\0F\AA\02\04/30\AA\02\FF\FFR\1F3\A8\0A\0A\04\C4\15\CFcommit_gro
up\CB\15\00\02%\00/wa#\00\02\00\E0#*\0A\09\1E$=%p4C$\00\8E\00\115\E6(\191\E6(\1D3\D2\17\114-\00\01\E9%\0Fz\00\0B\01A\03\01\CF\16\00{\18\11r\FD\01\01?\00\112F\18\01_\18\02\1A\00\146\1A\00d1;\0Aand\17\00#7,\1D\00\106.\00\14rH\00\1F8.\00\03#9,\1D\00\A221474836321\17\01\1F\00\01S\02\04\84\00\139\D5\00#64T\00\22d5\19\00\B57;\0Amul.wide\1A\00#6,?\008128\82\17\01\1D\02\04 \003d55\D2\00\03\1D\00$8,$\00\192l\01\00C\02\06\03,\07\8C\04\01T\02\02 \00\00\07\00E8;\0As\FB* v2<\04! ['\00!],=\04\07\E5\0E\1F},\00\065+32/\00\05\F4\0E\0F/\00\08%64/\00\05\03\0F\0F/\00\08$96/\00\06\12\0F\0F/\00\0851280\00\05\22\0F\0F0\00\09\06\ED\00\062\0F\0F0\00\09\159\EE\00\06B\0F\0F0\00\08%22\EF\00\06R\0F\0F0\00\09\155\F0\00\06b\0F\0F0\00\09\158\F0\00\06r\0F\0F\AD\01\0A\05\F0\00\06\82\0F\0F0\00\09\155\F0\00\06\92\0F\0F0\00\09\158\F0\00\06\A2\0F\0F0\00\08%41\F0\00\06\B2\0F\0F0\00\09\154\F0\00\06\C2\0F\0F0\00\09\148\F0\00\07\D2\0F\18}\B5\03\141|\03\1E8\99\03\01\13\02\01#\00\0B\99\03\01\8E\05\04 \00\0C\99\03\01\A3\05\02$\00\0A\8A\1B\01\B4\05\05\80\03/63\B7\00\05\05C\01\06U\12\0F\E3\00\06\184\80\03\05d\12\0F/\00\08\06^\00\06s\12\0F/\00\08\07\80\03\05\82\12\0F/\00\08\08\80\03\05\92\12\0F0\00\09\07\80\03\05\A2\12\0F0\00\09\07\80\03\05\B2\12\0F0\00\08\07\80\03\06\C2\12\0F0\00\09\07\80\03\05\D2\12\0F0\00\09\07\80\03\05\E2\12\0F\AD\01\0A\05\F0\00\06\F2\12\0F0\00\09\07\80\03\05\02\13\0F0\00\09\06\80\03\06\12\13\0F0\00\08\08\80\03\05\22\13\0F0\00\09\07\80\03\052\13\0F0\00\09\07\80\03\06B\13&;\0AL-\01y\09\01\81\03\00H/\05\82\03\03\AA,\01e\05\01$\00\0C\82\03&6, 
\00\0D\82\03$7,$\00\0B\82\03\198\82\03\1F7\B9\00\05\04\E5\00\07?\0B\0F\E5\00\06\168\82\03\07N\0B\0F/\00\08\05\82\03\07]\0B\0F/\00\08\05\82\03\07l\0B\0F/\00\08\06\82\03\07|\0B\0F0\00\09\05\82\03\07\8C\0B\0F0\00\09\05\82\03\07\9C\0B\0F0\00\08\06\82\03\07\AC\0B\0F0\00\09\05\82\03\07\BC\0B\0F0\00\09\05\82\03\07\CC\0B\0F\AD\01\0A\05\F0\00\06\DC\0B\0F0\00\09\05\82\03\07\EC\0B\0F0\00\09\05\82\03\07\FC\0B\0F0\00\08\06\82\03\07\0C\0C\0F0\00\09\05\82\03\07\1C\0C\0F0\00\09\06\82\03\06,\0C\0A\82\03\143\82\03/72\82\03\00#9,$\00\0B\82\03\01\1A\18\04 \00\0C\82\03\01m\05\02$\00\0A\82\03\01@\18\05\82\03/71\B9\00\04\157u\01\06\B1\0E\1F},\00\06\07\82\03\06\C0\0E\0F/\00\08\07\82\03\05\CF\0E\0F/\00\08\07\82\03\05\DE\0E\0F/\00\08\08\82\03\05\EE\0E\0F0\00\09\07\82\03\05\FE\0E\0F0\00\09\06\82\03\06\0E\0F\0F0\00\08\08\82\03\05\1E\0F\0F0\00\09\07\82\03\05.\0F\0F0\00\09\07\82\03\05>\0F\0F\AD\01\0A\05\F0\00\06N\0F\0F0\00\09\06\82\03\06^\0F\0F0\00\09\07\82\03\05n\0F\0F0\00\08\08\82\03\05~\0F\0F0\00\09\07\82\03\05\8E\0F\0F0\00\09\07\82\03\06\9E\0F\03\E9\11\11g\F9\0ED %p5s\0E2409&\0F\195&\0F\137\A1\0E\02\E4\0D\01\1D\0C\01\C93\195_\03\01O\1B\02\1F\00*-4\92\03\154\1A\00\189\D5\0E\154y\00.31\02\0B\01\05\09\01$\00)16\FD\03(6,U\00\0B\85\00&1,\7F4\1A7\8D \01\00\1C\06\00\04\166\22( 6:Y:\04B\01\114B\01\005\01\135\07\063538\08\00\02\A5\01A540}\AC'Ad80]~\01\03\FE4%v4~\01#81z\01\0FE\00\09\0FA\01\01\00\07\00\1F4\D5\00\01\00\07\00O2048\D5\00\01$80\1C\00\01\DB\01\11l\FF\0FE %p6S\001124\DB\01\196\DB\01\07\05\01\C07:\0Aret;\0A\0A}\0A\00\00\00\00">]
+  llvm.func @mgpuTensorMapEncodeTiledMemref(i64, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr) -> !llvm.ptr
+  llvm.func @mgpuStreamCreate() -> !llvm.ptr
+  llvm.func @mgpuMemAlloc(i64, !llvm.ptr, i8) -> !llvm.ptr
+  llvm.func @mgpuMemcpy(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr)
+}
+

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
index bca3cb1f9a1e07..025282ec0d688f 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index c8dc45ab861d16..35ca0ee8677cca 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -convert-linalg-to-loops \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index bc3437b6545d71..5a10bbba26d8cf 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  -convert-linalg-to-loops \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index 65f301968669aa..9c5aacf96b0d69 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index fdbb188c28a9c7..536e71d260f568 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index ed58504cfdb106..aee265e3faf175 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:  -gpu-lower-to-nvvm="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
 // RUN:  | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \

diff  --git a/mlir/test/Integration/GPU/CUDA/two-modules.mlir b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
index f68359d78c0475..db4b365dd85d33 100644
--- a/mlir/test/Integration/GPU/CUDA/two-modules.mlir
+++ b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%mlir_cuda_runtime \
 // RUN:   --shared-libs=%mlir_runner_utils \


        


More information about the Mlir-commits mailing list