[Mlir-commits] [mlir] [mlir][sparse][gpu] cleanup GPUDataTransferStrategy (PR #71611)

Aart Bik llvmlistbot at llvm.org
Tue Nov 7 16:40:03 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/71611

Rationale:
The flag seemed to make very little difference between zero-copy and pinned DMA. In addition, host registration is not truly the right mechanism for zero-copy. So we are simplifying the setup for now, until we have a better definition of what to test.

https://github.com/llvm/llvm-project/issues/64316

>From 26db6b65ad20bb64a137f7c58f454ab62b652fec Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 7 Nov 2023 16:36:59 -0800
Subject: [PATCH] [mlir][sparse][gpu] cleanup GPUDataTransferStrategy

Rationale:
The flag seemed to make very little difference between zero-copy
and pinned DMA. In addition, host registration is not truly the
right mechanism for zero-copy. So we are simplifying the setup
for now, until we have a better definition of what to test.

https://github.com/llvm/llvm-project/issues/64316
---
 .../Dialect/SparseTensor/Pipelines/Passes.h   |  20 +-
 .../Dialect/SparseTensor/Transforms/Passes.h  |  20 +-
 .../Dialect/SparseTensor/Transforms/Passes.td |  13 -
 .../Transforms/SparseGPUCodegen.cpp           | 227 +++++-------------
 .../SparseTensor/GPU/CUDA/dump-ptx.mlir       |   5 +-
 .../sparse-matmul-2-4-lib-from-linalg.mlir    |  25 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir   |   6 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir |  23 +-
 .../GPU/CUDA/sparse-gemm-lib.mlir             |  25 +-
 .../GPU/CUDA/sparse-matmul-lib.mlir           |  12 +-
 .../GPU/CUDA/sparse-matvec-lib.mlir           |  10 +-
 .../GPU/CUDA/sparse-sampled-matmul-lib.mlir   |  16 +-
 .../GPU/CUDA/sparse-sddmm-lib.mlir            |  10 +-
 13 files changed, 118 insertions(+), 294 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 57d8ffb3566f8eb..4de83034b0386d1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -52,21 +52,6 @@ struct SparseCompilerOptions
               mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
               "any-storage-any-loop",
               "Enable sparse parallelization for any storage and loop."))};
-  PassOptions::Option<mlir::GPUDataTransferStrategy> gpuDataTransfer{
-      *this, "gpu-data-transfer-strategy",
-      ::llvm::cl::desc(
-          "Set the data transfer strategy between the host and the GPUs"),
-      ::llvm::cl::init(mlir::GPUDataTransferStrategy::kRegularDMA),
-      llvm::cl::values(
-          clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA, "regular-dma",
-                     "Default option: malloc on host without additional "
-                     "options or care and then use DMA to copy the data"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
-                     "Based on the default option, pin the host memory to "
-                     "accelerate the data transfer"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
-                     "Use zero-copy to perform the data transfer from the host "
-                     "to the GPU"))};
 
   PassOptions::Option<bool> enableIndexReduction{
       *this, "enable-index-reduction",
@@ -166,9 +151,8 @@ struct SparseCompilerOptions
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, gpuDataTransfer,
-                                 enableIndexReduction, enableGPULibgen,
-                                 enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, enableIndexReduction,
+                                 enableGPULibgen, enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index a8d4d752dff8882..9c9387c4d0d5c56 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,10 +47,6 @@ enum class ReinterpretMapScope {
   kExceptGeneric, // reinterprets operation other than linalg.generic
 };
 
-/// Defines data movement strategy between host and device for GPU.
-// TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
-enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
-
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -78,18 +74,14 @@ std::unique_ptr<Pass> createPreSparsificationRewritePass();
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p,
-                        GPUDataTransferStrategy t, bool idxReduc,
+  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
                         bool gpuLibgen, bool enableRT)
-      : parallelizationStrategy(p), gpuDataTransferStrategy(t),
-        enableIndexReduction(idxReduc), enableGPULibgen(gpuLibgen),
-        enableRuntimeLibrary(enableRT) {}
+      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
+        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              GPUDataTransferStrategy::kRegularDMA, false,
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
                               false, true) {}
   SparseParallelizationStrategy parallelizationStrategy;
-  GPUDataTransferStrategy gpuDataTransferStrategy;
   bool enableIndexReduction;
   bool enableGPULibgen;
   bool enableRuntimeLibrary;
@@ -201,8 +193,8 @@ std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                       unsigned numThreads);
 
-void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT,
-                                     GPUDataTransferStrategy gpuDataTransfer);
+void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                     bool enableRT);
 
 std::unique_ptr<Pass> createSparseGPUCodegenPass();
 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 485b44a1d6a86c5..bf4c33ac61e96e0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -134,19 +134,6 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
              clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                         "any-storage-any-loop",
                         "Enable sparse parallelization for any storage and loop."))}]>,
-    Option<"gpuDataTransfer", "gpu-data-transfer-strategy", "mlir::GPUDataTransferStrategy",
-            "mlir::GPUDataTransferStrategy::kRegularDMA",
-            "Set the data transfer strategy", [{llvm::cl::values(
-               clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA,
-                     "regular-dma",
-                     "Default option: malloc on host without additional "
-                     "options or care and then use DMA to copy the data"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
-                     "Based on the default option, pin the host memory to "
-                     "accelerate the data transfer"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
-                     "Use zero-copy to perform the data transfer from the host "
-                     "to the GPU"))}]>,
     Option<"enableGPULibgen", "enable-gpu-libgen", "bool",
            "false",
            "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 7f32dd1449076ff..c241d02f0f852d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is a prototype GPU codegenerator for the sparsifier.
+// This is a prototype GPU codegenerator for the sparse compiler.
 // The objective is to eventually use the right combination of
 // direct code generation and libary calls into vendor-specific
 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
@@ -535,18 +535,14 @@ static Operation *genSpMat(OpBuilder &builder, Location loc,
 }
 
 /// Match and rewrite SpMV kernel.
-static LogicalResult
-rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-            GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value x = op.getOperand(1);
   Value y = op.getOperand(2); // we have y = Ax
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // Only admissible sparse matrix format and dense vectors (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
@@ -563,29 +559,14 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
   Value memV = genToValues(rewriter, loc, a);
-  Value memX, memY;
-  Value castR, castC, castV, castX, castY;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    memX = genTensorToMemref(rewriter, loc, x);
-    memY = genTensorToMemref(rewriter, loc, y);
-    castR = genHostRegisterMemref(rewriter, loc, memR);
-    if (memC)
-      castC = genHostRegisterMemref(rewriter, loc, memC);
-    castV = genHostRegisterMemref(rewriter, loc, memV);
-    castX = genHostRegisterMemref(rewriter, loc, memX);
-    castY = genHostRegisterMemref(rewriter, loc, memY);
-  }
-
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    memX = genTensorToMemref(rewriter, loc, x);
-  Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    memY = genTensorToMemref(rewriter, loc, y);
+  Value memX = genTensorToMemref(rewriter, loc, x);
+  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
+  Value memY = genTensorToMemref(rewriter, loc, y);
   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -638,21 +619,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, vecX, token);
+  token = genDeallocMemRef(rewriter, loc, vecX, token);
   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
   token = genDeallocMemRef(rewriter, loc, vecY, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castR);
-    if (memC)
-      genHostUnregisterMemref(rewriter, loc, castC);
-    genHostUnregisterMemref(rewriter, loc, castV);
-    genHostUnregisterMemref(rewriter, loc, castX);
-    genHostUnregisterMemref(rewriter, loc, castY);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
@@ -660,18 +632,14 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 /// Match and rewrite SpMM kernel.
-static LogicalResult
-rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-            GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // Only admissible sparse matrix format and dense matrices (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
@@ -682,35 +650,21 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 
   // Start sparse kernel and copy data from host to device.
   //   a : memR/memC/memV -> rowA,colA,valA
-  //   b : bufB           -> matA
+  //   b : bufB           -> matB
   //   c : bufC           -> matC
   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
   Value memV = genToValues(rewriter, loc, a);
-  Value bufB, bufC;
-  Value castR, castC, castV, castB, castBufC;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    bufB = genTensorToMemref(rewriter, loc, b);
-    bufC = genTensorToMemref(rewriter, loc, c);
-    castR = genHostRegisterMemref(rewriter, loc, memR);
-    if (memC)
-      castC = genHostRegisterMemref(rewriter, loc, memC);
-    castV = genHostRegisterMemref(rewriter, loc, memV);
-    castB = genHostRegisterMemref(rewriter, loc, bufB);
-    castBufC = genHostRegisterMemref(rewriter, loc, bufC);
-  }
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    bufB = genTensorToMemref(rewriter, loc, b);
-  Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    bufC = genTensorToMemref(rewriter, loc, c);
+  Value bufB = genTensorToMemref(rewriter, loc, b);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value bufC = genTensorToMemref(rewriter, loc, c);
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -766,21 +720,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castR);
-    if (memC)
-      genHostUnregisterMemref(rewriter, loc, castC);
-    genHostUnregisterMemref(rewriter, loc, castV);
-    genHostUnregisterMemref(rewriter, loc, castB);
-    genHostUnregisterMemref(rewriter, loc, castC);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
@@ -788,9 +733,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 // Match and rewrite SpGEMM kernel.
-static LogicalResult
-rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-              GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
+                                   linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
@@ -816,10 +760,10 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
   Value amemV = genToValues(rewriter, loc, a);
   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
-  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
   Value bmemV = genToValues(rewriter, loc, b);
   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -966,40 +910,27 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 // Match and rewrite 2:4 SpMM kernel.
-static LogicalResult
-rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
-                GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
+                                     linalg::GenericOp op) {
   Location loc = op.getLoc();
   Value A = op.getOperand(0);
   Value B = op.getOperand(1);
   Value C = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // All input should be dense tensors.
   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
     return failure();
 
-  Value matA, matB;
+  // Start sparse kernel and copy data from host to device.
+  //   a : bufA -> matA
+  //   b : bufB -> matB
+  //   c : bufC -> matC
   Value bufA = genTensorToMemref(rewriter, loc, A);
-  if (!isZeroCopy)
-    matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
   Value bufB = genTensorToMemref(rewriter, loc, B);
-  if (!isZeroCopy)
-    matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
   Value bufC = genTensorToMemref(rewriter, loc, C);
-  Value castA, castB, castC;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    castA = genHostRegisterMemref(rewriter, loc, bufA);
-    castB = genHostRegisterMemref(rewriter, loc, bufB);
-    castC = genHostRegisterMemref(rewriter, loc, bufC);
-  }
-  if (isZeroCopy) {
-    matA = bufA;
-    matB = bufB;
-  }
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -1039,27 +970,25 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
       /*computeType=*/dmatCType);
   token = bufferComp.getAsyncToken();
 
-  Value bufferSz = bufferComp.getResult(0);
-  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
-  Value buffer = buf.getResult(0);
-  token = buf.getAsyncToken();
-
+  // Allocate buffers on host.
+  Value bufferSz1 = bufferComp.getResult(0);
+  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
+  Value buffer1 = buf1.getResult(0);
+  token = buf1.getAsyncToken();
   Value bufferSz2 = bufferComp.getResult(1);
   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
   Value buffer2 = buf2.getResult(0);
   token = buf2.getAsyncToken();
-
   Value bufferSz3 = bufferComp.getResult(2);
   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
   Value buffer3 = buf3.getResult(0);
   token = buf3.getAsyncToken();
 
-  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
-
   // Perform the SpMM.
+  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
   auto spmmComp = rewriter.create<gpu::SpMMOp>(
       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
-      SmallVector<Value>{buffer, buffer2, buffer3});
+      SmallVector<Value>{buffer1, buffer2, buffer3});
   token = spmmComp.getAsyncToken();
 
   // Copy data back to host and free all the resources.
@@ -1070,23 +999,16 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
               .getAsyncToken();
   SmallVector<Value> newDynamicSizes;
-  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, buffer1, token);
   token = genDeallocMemRef(rewriter, loc, buffer2, token);
   token = genDeallocMemRef(rewriter, loc, buffer3, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matA, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castA);
-    genHostUnregisterMemref(rewriter, loc, castB);
-    genHostUnregisterMemref(rewriter, loc, castC);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
@@ -1094,18 +1016,14 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
 }
 
 /// Match and rewrite SDDMM kernel.
-static LogicalResult
-rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-             GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
+                                  linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2);
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
@@ -1118,35 +1036,19 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   // The SDDMM does the in-place operation.
   // Start sparse kernel and copy data from host to device.
   //   a : bufA           -> matA
-  //   b : bufB           -> matA
+  //   b : bufB           -> matB
   //   c : memR/memC/memV -> rowC,colC,valC
   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
-  Value matA, matB;
   Value bufA = genTensorToMemref(rewriter, loc, a);
-  if (!isZeroCopy)
-    matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
   Value bufB = genTensorToMemref(rewriter, loc, b);
-  if (!isZeroCopy)
-    matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
   Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
   Value memV = genToValues(rewriter, loc, c);
-  Value castB, castA, castR, castC, castV;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    castB = genHostRegisterMemref(rewriter, loc, bufB);
-    castA = genHostRegisterMemref(rewriter, loc, bufA);
-    castR = genHostRegisterMemref(rewriter, loc, memR);
-    if (memC)
-      castC = genHostRegisterMemref(rewriter, loc, memC);
-    castV = genHostRegisterMemref(rewriter, loc, memV);
-  }
-  if (isZeroCopy) {
-    matA = bufA;
-    matB = bufB;
-  }
   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
@@ -1196,10 +1098,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
               .getAsyncToken();
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  if (!isZeroCopy) {
-    token = genDeallocMemRef(rewriter, loc, matA, token);
-    token = genDeallocMemRef(rewriter, loc, matB, token);
-  }
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genDeallocMemRef(rewriter, loc, rowC, token);
   if (colC)
     token = genDeallocMemRef(rewriter, loc, colC, token);
@@ -1208,14 +1108,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castB);
-    genHostUnregisterMemref(rewriter, loc, castA);
-    genHostUnregisterMemref(rewriter, loc, castR);
-    if (memC)
-      genHostUnregisterMemref(rewriter, loc, castC);
-    genHostUnregisterMemref(rewriter, loc, castV);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
@@ -1227,7 +1119,7 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 //===----------------------------------------------------------------------===//
 
 /// Proof-of-concept rewriter. This rule generates a GPU implementation
-/// for each outermost forall loop generated by the sparsifier.
+/// for each outermost forall loop generated by the sparse compiler.
 /// TODO: right now works with parallelization-strategy=dense-outer-loop
 ///       but give this its own flags in the future
 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
@@ -1239,7 +1131,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                 PatternRewriter &rewriter) const override {
     // Reject inadmissible loop form.
-    // Essentially only accept a loop, generated by the sparsifier,
+    // Essentially only accept a loop, generated by the sparse compiler,
     // of the form
     //   forall (i = 0; i < N; i++)
     // so that cyclic scheduling over the threads is easy.
@@ -1333,8 +1225,8 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
 
-  LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
-      : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}
+  LinalgOpRewriter(MLIRContext *context, bool rt)
+      : OpRewritePattern(context), enableRT(rt) {}
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1359,7 +1251,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isReductionIterator(iteratorTypes[1]) &&
         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
-      return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
+      return rewriteSpMV(rewriter, op, enableRT);
     }
 
     // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
@@ -1369,10 +1261,10 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         linalg::isReductionIterator(iteratorTypes[2]) &&
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
-        return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
+        return rewriteSpGEMM(rewriter, op, enableRT);
       if (op->getAttr("DENSE24"))
-        return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
-      return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
+        return rewrite2To4SpMM(rewriter, op);
+      return rewriteSpMM(rewriter, op, enableRT);
     }
 
     // Recognize a SDDMM kernel.
@@ -1382,7 +1274,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         linalg::isReductionIterator(iteratorTypes[2]) &&
         maps == infer({{i, k}, {k, j}, {i, j}}) &&
         matchSumReductionOfMulUnary(op)) {
-      return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
+      return rewriteSDDMM(rewriter, op, enableRT);
     }
 
     return failure();
@@ -1390,7 +1282,6 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
 
 private:
   bool enableRT;
-  GPUDataTransferStrategy gpuDataTransferStrategy;
 };
 
 } // namespace
@@ -1410,9 +1301,7 @@ void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
 }
 
-void mlir::populateSparseGPULibgenPatterns(
-    RewritePatternSet &patterns, bool enableRT,
-    GPUDataTransferStrategy gpuDataTransfer) {
-  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT,
-                                 gpuDataTransfer);
+void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                           bool enableRT) {
+  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
index 0cb06b7bf1d2001..4483d18231e80ed 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir
@@ -1,11 +1,10 @@
 // RUN: mlir-opt %s \
-// RUN: | mlir-opt -test-lower-to-nvvm -debug-only=serialize-to-isa \
-// RUN: 2>&1 | FileCheck %s
+// RUN:  | mlir-opt -test-lower-to-nvvm -debug-only=serialize-to-isa \
+// RUN:  2>&1 | FileCheck %s
 
 // CHECK: Generated by LLVM NVPTX Back-End
 // CHECK: .visible .func kernel_a()
 // CHECK: ret;
-
 gpu.module @bar {
   llvm.func @kernel_a()
     attributes  { gpu.kernel } {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
index 67c8ce8dfa3004f..ccf855f3cf8aad7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
@@ -1,18 +1,20 @@
-//
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
 // DEFINE: %{run} = mlir-cpu-runner \
-// DEFINE: --shared-libs=%mlir_cuda_runtime \
-// DEFINE: --shared-libs=%mlir_c_runner_utils \
-// DEFINE: --e main --entry-point-result=void \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
-
-//  RUN:  %{compile}" | %{run}
-//  RUN:  %{compile} gpu-data-transfer-strategy=pinned-dma" | %{run}
-//  Tracker #64316
-//  RUNNOT: %{compile} gpu-data-transfer-strategy=zero-copy" | %{run}
+//
+// with RT lib:
+//
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
+//
+// without RT lib:
+//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
 
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
@@ -34,7 +36,6 @@ module {
     } -> tensor<16x16xf16>
     return %0 : tensor<16x16xf16>
   }
-  
 
   //
   // This test performs a matrix multiplication
@@ -195,7 +196,7 @@ module {
       %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
       vector.print %pc0 : vector<16xf16>
     }
-    
+
     llvm.call @mgpuDestroySparseLtEnv() : () -> ()
     return
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 1c9e6a956c0b284..daf29d5290bab0b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -1,4 +1,3 @@
-//
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
 // DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
@@ -9,11 +8,8 @@
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
 // DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
-
+//
 // RUN: %{compile} | %{run}
-// RUN: %{compile} --sparse-compiler="gpu-data-transfer-strategy=pinned-dma" | %{run}
-// RUNNOT: %{compile} --sparse-compiler="gpu-data-transfer-strategy=zero-copy" | %{run}
-
 
 module {
   llvm.func @mgpuCreateSparseLtEnv()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index 8917ab1e5a70d71..4f308fba3b1496e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -1,13 +1,20 @@
-//
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// RUN: mlir-opt --sparse-compiler="enable-runtime-library=false enable-gpu-libgen=true gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format" \
-// RUN:          %s \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
+// with RT lib:
+//
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
+//
+// without RT lib:
+//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
index 258ea13e60c07c6..6c53de4a0568691 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
@@ -1,25 +1,20 @@
-//
 // NOTE: this test requires gpu-sm80
 //
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
 // with RT lib:
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format"  \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
 //
 // without RT lib:
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format"  \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=false" | %{run}
 
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
index 9e397e0ad5b5dc2..5b44400d8362bcb 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
@@ -1,28 +1,20 @@
-//
 // NOTE: this test requires gpu-sm80
 //
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:    --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
 // DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
-//
 // with RT lib (SoA COO):
 //
-// RUN:  %{compile} enable-runtime-library=true" | %{run}
-// RUN:  %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
 //
 // without RT lib (AoS COO): note, may fall back to CPU
 //
 // RUN: %{compile} enable-runtime-library=false" | %{run}
-// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
 
 #SortedCOO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
index b569806b4028a8a..3b85ca3a275b782 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
@@ -1,4 +1,3 @@
-//
 // NOTE: this test requires gpu-sm80
 //
 // DEFINE: %{compile} = mlir-opt %s \
@@ -12,17 +11,10 @@
 // with RT lib (SoA COO):
 //
 // RUN: %{compile} enable-runtime-library=true"  | %{run}
-// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy"  | %{run}
 //
 // without RT lib (AoS COO): note, may fall back to CPU
 //
-// RUN: %{compile} enable-runtime-library=false"  | %{run}
-// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy"  | %{run}
-//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
 
 #SortedCOO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
index e79696ac4c047ca..a86cea783dd6f2d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
@@ -1,4 +1,3 @@
-//
 // NOTE: this test requires gpu-sm80
 //
 // DEFINE: %{compile} = mlir-opt %s \
@@ -8,23 +7,16 @@
 // DEFINE:   mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
-// DEFINE:   --e entry --entry-point-result=void \
+// DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
 // with RT lib:
 //
-// RUN:  %{compile} enable-runtime-library=true" | %{run}
-// RUN:  %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
-// TODO: Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
 //
 // without RT lib:
 //
-// RUN:  %{compile} enable-runtime-library=false" | %{run}
-// RUN:  %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
-// TODO:  Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
-//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
 
 !Filename = !llvm.ptr
 
@@ -85,7 +77,7 @@ module {
   //
   // Main driver.
   //
-  func.func @entry() {
+  func.func @main() {
     llvm.call @mgpuCreateSparseEnv() : () -> ()
     %d0 = arith.constant 0.0 : f32
     %c0 = arith.constant 0 : index
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir
index c1062dd4ee3e938..735dc8cb4bb3611 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir
@@ -1,4 +1,3 @@
-//
 // NOTE: this test requires gpu-sm80
 //
 // DEFINE: %{compile} = mlir-opt %s \
@@ -8,18 +7,17 @@
 // DEFINE:   mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
-// DEFINE:   --e entry --entry-point-result=void \
+// DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
 // with RT lib:
 //
-// RUN:  %{compile} enable-runtime-library=true" | %{run}
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
 //
 // without RT lib:
 //
 // TODO: make this work
-// R_UN:  %{compile} enable-runtime-library=false" | %{run}
-//
+// R_U_N: %{compile} enable-runtime-library=false" | %{run}
 
 !Filename = !llvm.ptr
 
@@ -117,7 +115,7 @@ module {
   //
   // Main driver.
   //
-  func.func @entry() {
+  func.func @main() {
     llvm.call @mgpuCreateSparseEnv() : () -> ()
     %d0 = arith.constant 0.0 : f32
     %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list