[Mlir-commits] [mlir] ca01c99 - [mlir][sparse] Add SparseTensorStorageExpansion Pass to expand compounded sparse tensor tuples
Peiming Liu
llvmlistbot at llvm.org
Thu Sep 1 15:47:40 PDT 2022
Author: Peiming Liu
Date: 2022-09-01T22:47:31Z
New Revision: ca01c996b2185af08e076d92b39afd96e0567faf
URL: https://github.com/llvm/llvm-project/commit/ca01c996b2185af08e076d92b39afd96e0567faf
DIFF: https://github.com/llvm/llvm-project/commit/ca01c996b2185af08e076d92b39afd96e0567faf.diff
LOG: [mlir][sparse] Add SparseTensorStorageExpansion Pass to expand compounded sparse tensor tuples
This patch adds the SparseTensorStorageExpansion pass, which flattens the tuple used to store a sparse
tensor handle.
Right now, it only sets up the skeleton for the pass; more lowering rules for sparse tensor storage
operations need to be added.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133125
Added:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 83eb1bb2d2f75..523f1fb057d47 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -161,6 +161,22 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
std::unique_ptr<Pass> createSparseTensorCodegenPass();
+//===----------------------------------------------------------------------===//
+// The SparseTensorStorageExpansion pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage type converter from compound to expanded form:
+/// tuple types are expanded 1:N into their member types, and every other
+/// type is kept unchanged.
+class SparseTensorStorageTupleExpander : public TypeConverter {
+public:
+ SparseTensorStorageTupleExpander();
+};
+
+/// Populates `patterns` with the sparse tensor storage expansion rules, bound
+/// to `typeConverter`. NOTE(review): at the time of this patch only a rule for
+/// `func.return` is provided; storage-op lowerings are still to be added.
+void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Creates a pass that expands compound (tuple) sparse tensor storage into
+/// individual SSA values.
+std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
+
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c8e71237a2280..d765a10701acb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -146,4 +146,39 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
];
}
+def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion",
+                                        "ModuleOp"> {
+  let summary = "Expand compounded sparse tensor storage into individual SSA values";
+  let description = [{
+    A pass that expands sparse tensor storage (aggregated by tuple) into
+    individual SSA values. It is also intended to lower sparse tensor storage
+    operations, e.g., sparse_tensor.storage_get and sparse_tensor.storage_set;
+    those lowering rules are still to be added.
+
+    Example of the conversion:
+
+    ```mlir
+    Before:
+      func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>,
+                                                  memref<?xf64>,
+                                                  f64>)
+                                       -> tuple<memref<?xf64>,
+                                                memref<?xf64>,
+                                                f64> {
+        return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+      }
+    After:
+      func.func @sparse_storage_expand(%arg0: memref<?xf64>,
+                                       %arg1: memref<?xf64>,
+                                       %arg2: f64)
+                                       -> (memref<?xf64>, memref<?xf64>, f64) {
+        return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
+      }
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorStorageExpansionPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 640ee67302b1b..39b633a6c7f6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
+ SparseTensorStorageExpansion.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b30d0d2b927f0..d5e2b96089d5b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,6 +24,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -185,6 +186,44 @@ struct SparseTensorCodegenPass
}
};
+/// Pass that flattens tuple-typed sparse tensor storage into individual SSA
+/// values by running a partial dialect conversion over the module.
+struct SparseTensorStorageExpansionPass
+    : public impl::SparseTensorStorageExpansionBase<
+          SparseTensorStorageExpansionPass> {
+  SparseTensorStorageExpansionPass() = default;
+  SparseTensorStorageExpansionPass(
+      const SparseTensorStorageExpansionPass &pass) = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    SparseTensorStorageTupleExpander typeConverter;
+
+    // Legality: no sparse tensor dialect op may remain, and function
+    // definitions, calls and returns are legal only once the type converter
+    // considers all of their types legal (i.e. no tuples left).
+    ConversionTarget target(*context);
+    target.addIllegalDialect<SparseTensorDialect>();
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
+      return typeConverter.isSignatureLegal(funcOp.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp callOp) {
+      return typeConverter.isSignatureLegal(callOp.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp retOp) {
+      return typeConverter.isLegal(retOp.getOperandTypes());
+    });
+
+    // Rewrite rules: function signature, call and SCF structural type
+    // conversions, followed by the storage-specific expansion rules.
+    RewritePatternSet patterns(context);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+                                                         patterns, target);
+    populateSparseTensorStorageExpansionPatterns(typeConverter, patterns);
+
+    LogicalResult result = applyPartialConversion(getOperation(), target,
+                                                  std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -255,3 +294,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
return std::make_unique<SparseTensorCodegenPass>();
}
+
+/// Creates an instance of the sparse tensor storage expansion pass.
+std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
+  std::unique_ptr<Pass> pass =
+      std::make_unique<SparseTensorStorageExpansionPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
new file mode 100644
index 0000000000000..c1305eba79009
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -0,0 +1,96 @@
+//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The sparse tensor storage expansion pass expands the compound storage for
+// sparse tensors (using tuple) to flattened SSA values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Expands a sparse tensor storage tuple into its member types, appending them
+/// to `result`. Nested tuples are flattened recursively, e.g.
+/// `tuple<i32, tuple<f32, i64>>` expands to `(i32, f32, i64)`. Returns
+/// llvm::None for non-tuple types, leaving them to the other conversions
+/// registered on the type converter.
+static Optional<LogicalResult>
+convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
+  if (auto tuple = t.dyn_cast<TupleType>()) {
+    // The sparse compiler does not generate nested tuples, but flattening
+    // them recursively keeps the 1:N expansion total should one appear.
+    tuple.getFlattenedTypes(result);
+    return success();
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage conversion rule for returns: every operand that is
+/// the tuple result of an `unrealized_conversion_cast` is replaced by the
+/// cast's inputs, so the new return yields the flattened values.
+class SparseStorageReturnConverter
+    : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> flattened;
+    for (Value operand : adaptor.getOperands()) {
+      // The type converter registers no materializations of its own, so the
+      // 1:N tuple expansion is bridged by an unrealized_conversion_cast from
+      // the expanded values back to the tuple; forward the cast's inputs
+      // instead of the tuple. Operands may also be block arguments with no
+      // defining op, so use getDefiningOp<>() rather than dyn_cast on a
+      // possibly-null Operation pointer (which asserts).
+      auto castOp = operand.getDefiningOp<UnrealizedConversionCastOp>();
+      if (castOp && castOp->getNumResults() == 1 &&
+          castOp->getResultTypes()[0].isa<TupleType>())
+        llvm::append_range(flattened, castOp.getOperands());
+      else
+        flattened.push_back(operand);
+    }
+    // Create a return with the flattened values extracted from the tuples.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor storage expansion
+//===----------------------------------------------------------------------===//
+
+/// Builds the converter: tuple types are expanded 1:N into their member
+/// types; every other type is kept as is.
+mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
+ // Identity fallback for all non-tuple types.
+ addConversion([](Type type) { return type; });
+ // NOTE: MLIR's TypeConverter tries the most recently added conversion first,
+ // so this tuple rule takes precedence over the identity rule above; keep this
+ // registration order.
+ addConversion(convertSparseTensorStorageTuple);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Registers, into `patterns`, the rewrite rules that expand compounded sparse
+/// tensor storage tuples; the rules are bound to `typeConverter`.
+void mlir::populateSparseTensorStorageExpansionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<SparseStorageReturnConverter>(typeConverter, context);
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
new file mode 100644
index 0000000000000..445b234a2a8d2
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+
+// CHECK-LABEL:  func @sparse_storage_expand(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64
+// CHECK-SAME:     -> (memref<?xf64>, memref<?xf64>, f64)
+// CHECK:          return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                 -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
More information about the Mlir-commits
mailing list