[Mlir-commits] [mlir] 5618d2b - [mlir][sparse] Add option enable-buffer-initialization to initialize the memory buffers for sparse tensors to support debugging.
llvmlistbot at llvm.org
Tue Nov 8 09:54:37 PST 2022
Author: bixia1
Date: 2022-11-08T09:54:33-08:00
New Revision: 5618d2bea965e127f8d72a2e8fce1b444ce986bf
URL: https://github.com/llvm/llvm-project/commit/5618d2bea965e127f8d72a2e8fce1b444ce986bf
DIFF: https://github.com/llvm/llvm-project/commit/5618d2bea965e127f8d72a2e8fce1b444ce986bf.diff
LOG: [mlir][sparse] Add option enable-buffer-initialization to initialize the memory buffers for sparse tensors to support debugging.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137592
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 97030f58f2a85..5d51c5bd04403 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -63,6 +63,10 @@ struct SparseCompilerOptions
*this, "test-bufferization-analysis-only",
desc("Run only the inplacability analysis"), init(false)};
+ PassOptions::Option<bool> enableBufferInitialization{
+ *this, "enable-buffer-initialization",
+ desc("Enable zero-initialization of memory buffers"), init(false)};
+
/// Projects out the options for `createSparsificationPass`.
SparsificationOptions sparsificationOptions() const {
return SparsificationOptions(parallelization);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 5e301c4e3de30..8d704dc65b723 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -153,8 +153,10 @@ std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
-void populateSparseBufferRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseBufferRewritePass();
+void populateSparseBufferRewriting(RewritePatternSet &patterns,
+ bool enableBufferInitialization);
+std::unique_ptr<Pass>
+createSparseBufferRewritePass(bool enableBufferInitialization = false);
//===----------------------------------------------------------------------===//
// Registration.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 421706e171cbe..b7c4baa07e927 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -198,6 +198,10 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
+ let options = [
+ Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+ "false", "Enable zero-initialization of the memory buffers">,
+ ];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index a16340dbc459e..82d060e5bd385 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -65,7 +65,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
options.sparseTensorConversionOptions()));
else
pm.addPass(createSparseTensorCodegenPass());
- pm.addPass(createSparseBufferRewritePass());
+ pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
pm.addPass(createDenseBufferizationPass(
getBufferizationOptions(/*analysisOnly=*/false)));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0af92a656d848..d0564cabad314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -635,6 +635,8 @@ namespace {
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
using OpRewritePattern<PushBackOp>::OpRewritePattern;
+ PushBackRewriter(MLIRContext *context, bool enableInit)
+ : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
LogicalResult matchAndRewrite(PushBackOp op,
PatternRewriter &rewriter) const override {
// Rewrite push_back(buffer, value, n) to:
@@ -705,6 +707,16 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
Value newBuffer =
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+ if (enableBufferInitialization) {
+ Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
+ Value fillValue = rewriter.create<arith::ConstantOp>(
+ loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+ Value subBuffer = rewriter.create<memref::SubViewOp>(
+ loc, newBuffer, /*offset=*/ValueRange{newSize},
+ /*size=*/ValueRange{fillSize},
+ /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+ rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
+ }
rewriter.create<scf::YieldOp>(loc, newBuffer);
// False branch.
@@ -731,6 +743,9 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
rewriter.replaceOp(op, buffer);
return success();
}
+
+private:
+ bool enableBufferInitialization;
};
/// Sparse rewriting rule for the sort operator.
@@ -777,6 +792,9 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
- patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
+ bool enableBufferInitialization) {
+ patterns.add<PushBackRewriter>(patterns.getContext(),
+ enableBufferInitialization);
+ patterns.add<SortRewriter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 4a35a7f45419c..8bc132f317383 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -215,11 +215,14 @@ struct SparseBufferRewritePass
SparseBufferRewritePass() = default;
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+ SparseBufferRewritePass(bool enableInit) {
+ enableBufferInitialization = enableInit;
+ }
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseBufferRewriting(patterns);
+ populateSparseBufferRewriting(patterns, enableBufferInitialization);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -279,6 +282,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
return std::make_unique<SparseTensorCodegenPass>();
}
-std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
- return std::make_unique<SparseBufferRewritePass>();
+std::unique_ptr<Pass>
+mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
+ return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}
More information about the Mlir-commits mailing list