[Mlir-commits] [mlir] c7e24db - [mlir][sparse] Introducing options for the SparseTensorConversion pass
wren romano
llvmlistbot at llvm.org
Tue Mar 22 13:11:17 PDT 2022
Author: wren romano
Date: 2022-03-22T13:11:09-07:00
New Revision: c7e24db412b34745ddaec7feb033b0f5cb4aecdf
URL: https://github.com/llvm/llvm-project/commit/c7e24db412b34745ddaec7feb033b0f5cb4aecdf
DIFF: https://github.com/llvm/llvm-project/commit/c7e24db412b34745ddaec7feb033b0f5cb4aecdf.diff
LOG: [mlir][sparse] Introducing options for the SparseTensorConversion pass
This is work towards: https://github.com/llvm/llvm-project/issues/51652
This differential sets up the options and threads them through everywhere, but doesn't actually use them yet. The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D122054
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 86b59b67b887d..782f3b443d428 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -49,6 +49,17 @@ struct SparseCompilerOptions
vectorLength, enableSIMDIndex32);
}
+ // These options must be kept in sync with `SparseTensorConversionBase`.
+ // The flag is decoded by `sparseToSparseConversionStrategy`:
+ // 1 = kViaCOO, 2 = kDirect, and 0 (the default) or any other value = kAuto.
+ PassOptions::Option<int32_t> sparseToSparse{
+ *this, "s2s-strategy",
+ desc("Set the strategy for sparse-to-sparse conversion"), init(0)};
+
+ /// Projects out the options for `createSparseTensorConversionPass`.
+ SparseTensorConversionOptions sparseTensorConversionOptions() const {
+ return SparseTensorConversionOptions(
+ sparseToSparseConversionStrategy(sparseToSparse));
+ }
+
// These options must be kept in sync with `ConvertVectorToLLVMBase`.
// TODO(wrengr): does `indexOptimizations` differ from `enableSIMDIndex32`?
PassOptions::Option<bool> reassociateFPReductions{
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 96f9ea1bb9c94..1888b45fc8442 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -8,6 +8,12 @@
//
// This header file defines prototypes of all sparse tensor passes.
//
+// In general, this file takes the approach of keeping "mechanism" (the
+// actual steps of applying a transformation) completely separate from
+// "policy" (heuristics for when and where to apply transformations).
+// The only exception is `SparseToSparseConversionStrategy`; see the
+// discussion at its definition for the rationale.
+//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
@@ -21,6 +27,10 @@ namespace mlir {
// Forward.
class TypeConverter;
+//===----------------------------------------------------------------------===//
+// The Sparsification pass.
+//===----------------------------------------------------------------------===//
+
/// Defines a parallelization strategy. Any independent loop is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
@@ -51,7 +61,7 @@ enum class SparseVectorizationStrategy {
/// Converts command-line vectorization flag to the strategy enum.
SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag);
-/// Sparsification options.
+/// Options for the Sparsification pass.
struct SparsificationOptions {
SparsificationOptions(SparseParallelizationStrategy p,
SparseVectorizationStrategy v, unsigned vl, bool e)
@@ -71,14 +81,56 @@ void populateSparsificationPatterns(
RewritePatternSet &patterns,
const SparsificationOptions &options = SparsificationOptions());
-/// Sets up sparse tensor conversion rules.
-void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
std::unique_ptr<Pass> createSparsificationPass();
std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions &options);
+
+//===----------------------------------------------------------------------===//
+// The SparseTensorConversion pass.
+//===----------------------------------------------------------------------===//
+
+/// Defines a strategy for implementing sparse-to-sparse conversion.
+/// `kAuto` leaves it up to the compiler to automatically determine
+/// the method used. `kViaCOO` converts the source tensor to COO and
+/// then converts the COO to the target format. `kDirect` converts
+/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
+/// however, beware that there are many formats not supported by this
+/// conversion method.
+///
+/// The presence of the `kAuto` option violates our usual goal of keeping
+/// policy completely separated from mechanism. It exists because (at
+/// present) this strategy can only be specified on a per-file basis. To
+/// see why this is a problem, note that `kDirect` cannot support certain
+/// conversions; so if there is no `kAuto` setting, then whenever a file
+/// contains a single non-`kDirect`-able conversion the user would be
+/// forced to use `kViaCOO` for all conversions in that file! In the
+/// future, instead of using this enum as a `Pass` option, we could
+/// instead move it to being an attribute on the conversion op; at which
+/// point `kAuto` would no longer be necessary.
+enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect };
+
+/// Converts command-line sparse2sparse flag to the strategy enum:
+/// 1 maps to `kViaCOO`, 2 maps to `kDirect`, and 0 or any other value
+/// maps to `kAuto`.
+SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag);
+
+/// Options for the SparseTensorConversion pass. Default-constructed
+/// options use `SparseToSparseConversionStrategy::kAuto`.
+struct SparseTensorConversionOptions {
+ SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s)
+ : sparseToSparseStrategy(s2s) {}
+ SparseTensorConversionOptions()
+ : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) {
+ }
+ SparseToSparseConversionStrategy sparseToSparseStrategy;
+};
+
+/// Sets up sparse tensor conversion rules. The `options` are forwarded to
+/// the conversion rule for the convert operator.
+void populateSparseTensorConversionPatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ const SparseTensorConversionOptions &options =
+ SparseTensorConversionOptions());
+
std::unique_ptr<Pass> createSparseTensorConversionPass();
+std::unique_ptr<Pass>
+createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
//===----------------------------------------------------------------------===//
// Registration.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 31b08af00ae37..89aacd69b67a0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -114,6 +114,10 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
"sparse_tensor::SparseTensorDialect",
"vector::VectorDialect",
];
+ // Must be kept in sync with `SparseCompilerOptions::sparseToSparse`.
+ // The flag is decoded by `sparseToSparseConversionStrategy`:
+ // 0 (default) = auto, 1 = via COO, 2 = direct.
+ let options = [
+ Option<"sparseToSparse", "s2s-strategy", "int32_t", "0",
+ "Set the strategy for sparse-to-sparse conversion">,
+ ];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 1f5e26689e056..54dac3d7ec441 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -33,7 +33,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addNestedPass<FuncOp>(createLinalgGeneralizationPass());
pm.addPass(createLinalgElementwiseOpFusionPass());
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
- pm.addPass(createSparseTensorConversionPass());
+ // Forward the pipeline's `s2s-strategy` setting to the conversion pass.
+ pm.addPass(createSparseTensorConversionPass(
+ options.sparseTensorConversionOptions()));
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 17a07da564e25..11329f6abc7be 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -453,7 +453,18 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+ /// Options controlling how the convert operator is lowered (e.g. the
+ /// sparse-to-sparse conversion strategy).
+ SparseTensorConversionOptions options;
+
+public:
using OpConversionPattern::OpConversionPattern;
+ /// Same as the inherited constructors, but also capture the conversion
+ /// options.
+ SparseTensorConvertConverter(MLIRContext *context,
+ SparseTensorConversionOptions o)
+ : OpConversionPattern<ConvertOp>(context), options(o) {}
+ SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
+ SparseTensorConversionOptions o)
+ : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
+
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -825,14 +836,17 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
-void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void mlir::populateSparseTensorConversionPatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ const SparseTensorConversionOptions &options) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
- SparseTensorInitConverter, SparseTensorConvertConverter,
- SparseTensorReleaseConverter, SparseTensorToPointersConverter,
- SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
- SparseTensorLoadConverter, SparseTensorLexInsertConverter,
- SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorOutConverter>(typeConverter, patterns.getContext());
+ SparseTensorInitConverter, SparseTensorReleaseConverter,
+ SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
+ SparseTensorToValuesConverter, SparseTensorLoadConverter,
+ SparseTensorLexInsertConverter, SparseTensorExpandConverter,
+ SparseTensorCompressConverter, SparseTensorOutConverter>(
+ typeConverter, patterns.getContext());
+ // `SparseTensorConvertConverter` additionally takes the conversion options,
+ // so it is registered separately from the rules above.
+ patterns.add<SparseTensorConvertConverter>(typeConverter,
+ patterns.getContext(), options);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 2d8b8585e1ec2..2124aecc128da 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -73,6 +73,13 @@ class SparseTensorTypeConverter : public TypeConverter {
struct SparseTensorConversionPass
: public SparseTensorConversionBase<SparseTensorConversionPass> {
+
+ SparseTensorConversionPass() = default;
+ SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
+ // Stores the strategy as the raw `int32_t` flag backing `s2s-strategy`.
+ // This relies on the enumerator values (kAuto=0, kViaCOO=1, kDirect=2)
+ // matching the encoding decoded by `sparseToSparseConversionStrategy`.
+ SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
+ sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
+ }
+
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
@@ -106,11 +113,14 @@ struct SparseTensorConversionPass
target
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
memref::MemRefDialect, scf::SCFDialect>();
+ // Decode the `s2s-strategy` flag back into strategy options.
+ SparseTensorConversionOptions options(
+ sparseToSparseConversionStrategy(sparseToSparse));
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
- populateSparseTensorConversionPatterns(converter, patterns);
+ populateSparseTensorConversionPatterns(converter, patterns, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -146,6 +156,18 @@ SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
}
}
+/// Decodes the `s2s-strategy` flag: 1 = kViaCOO, 2 = kDirect; 0 and any
+/// unrecognized value fall back to kAuto.
+SparseToSparseConversionStrategy
+mlir::sparseToSparseConversionStrategy(int32_t flag) {
+ switch (flag) {
+ default:
+ return SparseToSparseConversionStrategy::kAuto;
+ case 1:
+ return SparseToSparseConversionStrategy::kViaCOO;
+ case 2:
+ return SparseToSparseConversionStrategy::kDirect;
+ }
+}
+
std::unique_ptr<Pass> mlir::createSparsificationPass() {
return std::make_unique<SparsificationPass>();
}
@@ -158,3 +180,8 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
+
+/// Creates the sparse tensor conversion pass with the given options.
+std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
+ const SparseTensorConversionOptions &options) {
+ return std::make_unique<SparseTensorConversionPass>(options);
+}
More information about the Mlir-commits
mailing list