[Mlir-commits] [mlir] [mlir][sparse][gpu] cleanup GPUDataTransferStrategy (PR #71611)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 16:40:33 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

Rationale:
The flag seemed to make very little difference between the zero-copy and pinned-DMA strategies. In addition, registering host memory is not truly the right mechanism for zero-copy. So we are simplifying the setup for now, until we have a better definition of what to test.

https://github.com/llvm/llvm-project/issues/64316

---

Patch is 42.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71611.diff


13 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (+2-18) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+6-14) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (-13) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+58-169) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/dump-ptx.mlir (+2-3) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir (+13-12) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir (+1-5) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir (+15-8) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir (+10-15) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir (+2-10) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir (+1-9) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir (+4-12) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir (+4-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 57d8ffb3566f8eb..4de83034b0386d1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -52,21 +52,6 @@ struct SparseCompilerOptions
               mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
               "any-storage-any-loop",
               "Enable sparse parallelization for any storage and loop."))};
-  PassOptions::Option<mlir::GPUDataTransferStrategy> gpuDataTransfer{
-      *this, "gpu-data-transfer-strategy",
-      ::llvm::cl::desc(
-          "Set the data transfer strategy between the host and the GPUs"),
-      ::llvm::cl::init(mlir::GPUDataTransferStrategy::kRegularDMA),
-      llvm::cl::values(
-          clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA, "regular-dma",
-                     "Default option: malloc on host without additional "
-                     "options or care and then use DMA to copy the data"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
-                     "Based on the default option, pin the host memory to "
-                     "accelerate the data transfer"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
-                     "Use zero-copy to perform the data transfer from the host "
-                     "to the GPU"))};
 
   PassOptions::Option<bool> enableIndexReduction{
       *this, "enable-index-reduction",
@@ -166,9 +151,8 @@ struct SparseCompilerOptions
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, gpuDataTransfer,
-                                 enableIndexReduction, enableGPULibgen,
-                                 enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, enableIndexReduction,
+                                 enableGPULibgen, enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index a8d4d752dff8882..9c9387c4d0d5c56 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,10 +47,6 @@ enum class ReinterpretMapScope {
   kExceptGeneric, // reinterprets operation other than linalg.generic
 };
 
-/// Defines data movement strategy between host and device for GPU.
-// TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
-enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
-
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -78,18 +74,14 @@ std::unique_ptr<Pass> createPreSparsificationRewritePass();
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p,
-                        GPUDataTransferStrategy t, bool idxReduc,
+  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
                         bool gpuLibgen, bool enableRT)
-      : parallelizationStrategy(p), gpuDataTransferStrategy(t),
-        enableIndexReduction(idxReduc), enableGPULibgen(gpuLibgen),
-        enableRuntimeLibrary(enableRT) {}
+      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
+        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              GPUDataTransferStrategy::kRegularDMA, false,
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
                               false, true) {}
   SparseParallelizationStrategy parallelizationStrategy;
-  GPUDataTransferStrategy gpuDataTransferStrategy;
   bool enableIndexReduction;
   bool enableGPULibgen;
   bool enableRuntimeLibrary;
@@ -201,8 +193,8 @@ std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                       unsigned numThreads);
 
-void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT,
-                                     GPUDataTransferStrategy gpuDataTransfer);
+void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                     bool enableRT);
 
 std::unique_ptr<Pass> createSparseGPUCodegenPass();
 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 485b44a1d6a86c5..bf4c33ac61e96e0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -134,19 +134,6 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
              clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                         "any-storage-any-loop",
                         "Enable sparse parallelization for any storage and loop."))}]>,
-    Option<"gpuDataTransfer", "gpu-data-transfer-strategy", "mlir::GPUDataTransferStrategy",
-            "mlir::GPUDataTransferStrategy::kRegularDMA",
-            "Set the data transfer strategy", [{llvm::cl::values(
-               clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA,
-                     "regular-dma",
-                     "Default option: malloc on host without additional "
-                     "options or care and then use DMA to copy the data"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
-                     "Based on the default option, pin the host memory to "
-                     "accelerate the data transfer"),
-          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
-                     "Use zero-copy to perform the data transfer from the host "
-                     "to the GPU"))}]>,
     Option<"enableGPULibgen", "enable-gpu-libgen", "bool",
            "false",
            "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 7f32dd1449076ff..c241d02f0f852d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is a prototype GPU codegenerator for the sparsifier.
+// This is a prototype GPU codegenerator for the sparse compiler.
 // The objective is to eventually use the right combination of
 // direct code generation and libary calls into vendor-specific
 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
@@ -535,18 +535,14 @@ static Operation *genSpMat(OpBuilder &builder, Location loc,
 }
 
 /// Match and rewrite SpMV kernel.
-static LogicalResult
-rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-            GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value x = op.getOperand(1);
   Value y = op.getOperand(2); // we have y = Ax
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // Only admissible sparse matrix format and dense vectors (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
@@ -563,29 +559,14 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
   Value memV = genToValues(rewriter, loc, a);
-  Value memX, memY;
-  Value castR, castC, castV, castX, castY;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    memX = genTensorToMemref(rewriter, loc, x);
-    memY = genTensorToMemref(rewriter, loc, y);
-    castR = genHostRegisterMemref(rewriter, loc, memR);
-    if (memC)
-      castC = genHostRegisterMemref(rewriter, loc, memC);
-    castV = genHostRegisterMemref(rewriter, loc, memV);
-    castX = genHostRegisterMemref(rewriter, loc, memX);
-    castY = genHostRegisterMemref(rewriter, loc, memY);
-  }
-
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    memX = genTensorToMemref(rewriter, loc, x);
-  Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    memY = genTensorToMemref(rewriter, loc, y);
+  Value memX = genTensorToMemref(rewriter, loc, x);
+  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
+  Value memY = genTensorToMemref(rewriter, loc, y);
   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -638,21 +619,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, vecX, token);
+  token = genDeallocMemRef(rewriter, loc, vecX, token);
   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
   token = genDeallocMemRef(rewriter, loc, vecY, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castR);
-    if (memC)
-      genHostUnregisterMemref(rewriter, loc, castC);
-    genHostUnregisterMemref(rewriter, loc, castV);
-    genHostUnregisterMemref(rewriter, loc, castX);
-    genHostUnregisterMemref(rewriter, loc, castY);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
@@ -660,18 +632,14 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 /// Match and rewrite SpMM kernel.
-static LogicalResult
-rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-            GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // Only admissible sparse matrix format and dense matrices (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
@@ -682,35 +650,21 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 
   // Start sparse kernel and copy data from host to device.
   //   a : memR/memC/memV -> rowA,colA,valA
-  //   b : bufB           -> matA
+  //   b : bufB           -> matB
   //   c : bufC           -> matC
   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
   Value memV = genToValues(rewriter, loc, a);
-  Value bufB, bufC;
-  Value castR, castC, castV, castB, castBufC;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    bufB = genTensorToMemref(rewriter, loc, b);
-    bufC = genTensorToMemref(rewriter, loc, c);
-    castR = genHostRegisterMemref(rewriter, loc, memR);
-    if (memC)
-      castC = genHostRegisterMemref(rewriter, loc, memC);
-    castV = genHostRegisterMemref(rewriter, loc, memV);
-    castB = genHostRegisterMemref(rewriter, loc, bufB);
-    castBufC = genHostRegisterMemref(rewriter, loc, bufC);
-  }
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    bufB = genTensorToMemref(rewriter, loc, b);
-  Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
-  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
-    bufC = genTensorToMemref(rewriter, loc, c);
+  Value bufB = genTensorToMemref(rewriter, loc, b);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value bufC = genTensorToMemref(rewriter, loc, c);
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -766,21 +720,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    genHostUnregisterMemref(rewriter, loc, castR);
-    if (memC)
-      genHostUnregisterMemref(rewriter, loc, castC);
-    genHostUnregisterMemref(rewriter, loc, castV);
-    genHostUnregisterMemref(rewriter, loc, castB);
-    genHostUnregisterMemref(rewriter, loc, castC);
-  }
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
@@ -788,9 +733,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 // Match and rewrite SpGEMM kernel.
-static LogicalResult
-rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
-              GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
+                                   linalg::GenericOp op, bool enableRT) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
@@ -816,10 +760,10 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
-  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
   Value amemV = genToValues(rewriter, loc, a);
   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
-  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
   Value bmemV = genToValues(rewriter, loc, b);
   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -966,40 +910,27 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 }
 
 // Match and rewrite 2:4 SpMM kernel.
-static LogicalResult
-rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
-                GPUDataTransferStrategy gpuDataTransferStrategy) {
+static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
+                                     linalg::GenericOp op) {
   Location loc = op.getLoc();
   Value A = op.getOperand(0);
   Value B = op.getOperand(1);
   Value C = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
-  bool isZeroCopy =
-      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
-
   // All input should be dense tensors.
   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
     return failure();
 
-  Value matA, matB;
+  // Start sparse kernel and copy data from host to device.
+  //   a : bufA -> matA
+  //   b : bufB -> matB
+  //   c : bufC -> matC
   Value bufA = genTensorToMemref(rewriter, loc, A);
-  if (!isZeroCopy)
-    matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
   Value bufB = genTensorToMemref(rewriter, loc, B);
-  if (!isZeroCopy)
-    matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
   Value bufC = genTensorToMemref(rewriter, loc, C);
-  Value castA, castB, castC;
-  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
-    castA = genHostRegisterMemref(rewriter, loc, bufA);
-    castB = genHostRegisterMemref(rewriter, loc, bufB);
-    castC = genHostRegisterMemref(rewriter, loc, bufC);
-  }
-  if (isZeroCopy) {
-    matA = bufA;
-    matB = bufB;
-  }
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -1039,27 +970,25 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
       /*computeType=*/dmatCType);
   token = bufferComp.getAsyncToken();
 
-  Value bufferSz = bufferComp.getResult(0);
-  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
-  Value buffer = buf.getResult(0);
-  token = buf.getAsyncToken();
-
+  // Allocate buffers on host.
+  Value bufferSz1 = bufferComp.getResult(0);
+  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
+  Value buffer1 = buf1.getResult(0);
+  token = buf1.getAsyncToken();
   Value bufferSz2 = bufferComp.getResult(1);
   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
   Value buffer2 = buf2.getResult(0);
   token = buf2.getAsyncToken();
-
   Value bufferSz3 = bufferComp.getResult(2);
   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
   Value buffer3 = buf3.getResult(0);
   token = buf3.getAsyncToken();
 
-  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
-
   // Perform the SpMM.
+  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
   auto spmmComp = rewriter.create<gpu::SpMMOp>(
       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
-      SmallVector<Value>{buffer, buffer2, buffer3});
+      SmallVector<Value>{buffer1, buffer2, buffer3});
   token = spmmComp.getAsyncToken();
 
   // Copy data back to host and free all the resources.
@@ -1070,23 +999,16 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
               .getAsyncToken();
   SmallVector<Value> newDynamicSizes;
-  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, buffer1, token);
   token = genDeallocMemRef(rewriter, loc, buffer2, token);
   token = genDeallocMemRef(rewriter, loc, buffer3, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matA, token);
-  if (!isZeroCopy)
-    token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDea...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71611


More information about the Mlir-commits mailing list