[Mlir-commits] [mlir] 5618d2b - [mlir][sparse] Add option enable-buffer-initialization to initialize the memory buffers for sparse tensors to support debugging.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 8 09:54:37 PST 2022


Author: bixia1
Date: 2022-11-08T09:54:33-08:00
New Revision: 5618d2bea965e127f8d72a2e8fce1b444ce986bf

URL: https://github.com/llvm/llvm-project/commit/5618d2bea965e127f8d72a2e8fce1b444ce986bf
DIFF: https://github.com/llvm/llvm-project/commit/5618d2bea965e127f8d72a2e8fce1b444ce986bf.diff

LOG: [mlir][sparse] Add option enable-buffer-initialization to zero-initialize the memory buffers for sparse tensors, to support debugging.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137592

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 97030f58f2a85..5d51c5bd04403 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -63,6 +63,10 @@ struct SparseCompilerOptions
       *this, "test-bufferization-analysis-only",
       desc("Run only the inplacability analysis"), init(false)};
 
+  PassOptions::Option<bool> enableBufferInitialization{
+      *this, "enable-buffer-initialization",
+      desc("Enable zero-initialization of memory buffers"), init(false)};
+
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
     return SparsificationOptions(parallelization);

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 5e301c4e3de30..8d704dc65b723 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -153,8 +153,10 @@ std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 
-void populateSparseBufferRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseBufferRewritePass();
+void populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                   bool enableBufferInitialization);
+std::unique_ptr<Pass>
+createSparseBufferRewritePass(bool enableBufferInitialization = false);
 
 //===----------------------------------------------------------------------===//
 // Registration.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 421706e171cbe..b7c4baa07e927 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -198,6 +198,10 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+           "false", "Enable zero-initialization of the memory buffers">,
+  ];
 }
 
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index a16340dbc459e..82d060e5bd385 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -65,7 +65,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
         options.sparseTensorConversionOptions()));
   else
     pm.addPass(createSparseTensorCodegenPass());
-  pm.addPass(createSparseBufferRewritePass());
+  pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0af92a656d848..d0564cabad314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -635,6 +635,8 @@ namespace {
 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 public:
   using OpRewritePattern<PushBackOp>::OpRewritePattern;
+  PushBackRewriter(MLIRContext *context, bool enableInit)
+      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
   LogicalResult matchAndRewrite(PushBackOp op,
                                 PatternRewriter &rewriter) const override {
     // Rewrite push_back(buffer, value, n) to:
@@ -705,6 +707,16 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 
       Value newBuffer =
           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+      if (enableBufferInitialization) {
+        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
+        Value fillValue = rewriter.create<arith::ConstantOp>(
+            loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+        Value subBuffer = rewriter.create<memref::SubViewOp>(
+            loc, newBuffer, /*offset=*/ValueRange{newSize},
+            /*size=*/ValueRange{fillSize},
+            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
+      }
       rewriter.create<scf::YieldOp>(loc, newBuffer);
 
       // False branch.
@@ -731,6 +743,9 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
     rewriter.replaceOp(op, buffer);
     return success();
   }
+
+private:
+  bool enableBufferInitialization;
 };
 
 /// Sparse rewriting rule for the sort operator.
@@ -777,6 +792,9 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
-void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
-  patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                         bool enableBufferInitialization) {
+  patterns.add<PushBackRewriter>(patterns.getContext(),
+                                 enableBufferInitialization);
+  patterns.add<SortRewriter>(patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 4a35a7f45419c..8bc132f317383 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -215,11 +215,14 @@ struct SparseBufferRewritePass
 
   SparseBufferRewritePass() = default;
   SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+  SparseBufferRewritePass(bool enableInit) {
+    enableBufferInitialization = enableInit;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseBufferRewriting(patterns);
+    populateSparseBufferRewriting(patterns, enableBufferInitialization);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -279,6 +282,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
 
-std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
-  return std::make_unique<SparseBufferRewritePass>();
+std::unique_ptr<Pass>
+mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
+  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
 }


        


More information about the Mlir-commits mailing list