[Mlir-commits] [mlir] [mlir][sparse] Extract `StorageSpecifierToLLVMPass` from bufferization pipeline (PR #68635)
Matthias Springer
llvmlistbot at llvm.org
Mon Oct 9 17:54:15 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/68635
>From 11878d6f25af22746bc0b8843910fb979c4b3fc5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 9 Oct 2023 23:31:22 +0200
Subject: [PATCH] [mlir][sparse] Extract `StorageSpecifierToLLVMPass` from bufferization pipeline
`StorageSpecifierToLLVMPass` does not have to be part of the bufferization mini pipeline. It can run after the bufferization pipeline. This is desirable because it keeps the bufferization pipeline smaller.
---
.../SparseTensor/Pipelines/SparseTensorPipelines.cpp | 2 ++
.../Transforms/SparsificationAndBufferizationPass.cpp | 6 +++---
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 54069064839eac3..7569413546c0a6e 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -42,6 +42,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
/*enableSIMDIndex32=*/options.force32BitVectorIndices));
if (options.testBufferizationAnalysisOnly)
return;
+
+ pm.addPass(createStorageSpecifierToLLVMPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
mlir::bufferization::createFinalizingBufferizePass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 6fca8f82e356626..480e18e257277de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -39,7 +39,7 @@ namespace sparse_tensor {
/// Return `true` if one of the given types is a sparse tensor type.
static bool containsSparseTensor(TypeRange types) {
for (Type t : types)
- if (getSparseTensorEncoding(t))
+ if (isa<TensorType>(t) && getSparseTensorEncoding(t))
return true;
return false;
}
@@ -97,7 +97,8 @@ class SparsificationAndBufferizationPass
return false;
});
- if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
+ if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
+ updatedOptions)))
return failure();
bufferization::removeBufferizationAttributesInModule(getOperation());
@@ -154,7 +155,6 @@ class SparsificationAndBufferizationPass
pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
enableBufferInitialization));
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
- pm.addPass(createStorageSpecifierToLLVMPass());
}
if (failed(runPipeline(pm, getOperation())))
return signalPassFailure();
More information about the Mlir-commits mailing list