[Mlir-commits] [mlir] ca01c99 - [mlir][sparse] Add SparseTensorStorageExpansion Pass to expand compounded sparse tensor tuples

Peiming Liu llvmlistbot at llvm.org
Thu Sep 1 15:47:40 PDT 2022


Author: Peiming Liu
Date: 2022-09-01T22:47:31Z
New Revision: ca01c996b2185af08e076d92b39afd96e0567faf

URL: https://github.com/llvm/llvm-project/commit/ca01c996b2185af08e076d92b39afd96e0567faf
DIFF: https://github.com/llvm/llvm-project/commit/ca01c996b2185af08e076d92b39afd96e0567faf.diff

LOG: [mlir][sparse] Add SparseTensorStorageExpansion Pass to expand compounded sparse tensor tuples

This patch adds the SparseTensorStorageExpansion pass, which flattens the tuple
used to store a sparse tensor handle.

Right now, it only sets up the skeleton for the pass; more lowering rules for
sparse tensor storage operations still need to be added.
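
For illustration, the intended rewrite on a function that simply passes the
storage tuple through (this mirrors the example in Passes.td and the test
added below): the tuple-typed argument and result

  func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
      -> tuple<memref<?xf64>, memref<?xf64>, f64> {
    return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
  }

are expanded into one SSA value per tuple field:

  func.func @sparse_storage_expand(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64)
      -> (memref<?xf64>, memref<?xf64>, f64) {
    return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
  }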

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133125

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
    mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 83eb1bb2d2f75..523f1fb057d47 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -161,6 +161,22 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
 
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
+//===----------------------------------------------------------------------===//
+// The SparseTensorStorageExpansion pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage type converter from compound to expanded form.
+class SparseTensorStorageTupleExpander : public TypeConverter {
+public:
+  SparseTensorStorageTupleExpander();
+};
+
+/// Sets up sparse tensor storage expansion rules.
+void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
+                                                  RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
+
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c8e71237a2280..d765a10701acb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -146,4 +146,39 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   ];
 }
 
+def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> {
+  let summary = "Expand compounded sparse tensor storage into individual SSA values";
+  let description = [{
+    A pass that expands sparse tensor storage (aggregated in a tuple) into
+    individual SSA values. It also lowers sparse tensor storage operations,
+    e.g., sparse_tensor.storage_get and sparse_tensor.storage_set.
+
+    Example of the conversion:
+
+    ```mlir
+    Before:
+      func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>,
+                                                 memref<?xf64>,
+                                                 f64>)
+                                        -> tuple<memref<?xf64>,
+                                                 memref<?xf64>,
+                                                 f64> {
+        return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+      }
+    After:
+      func.func @sparse_storage_set(%arg0: memref<?xf64>,
+                                    %arg1: memref<?xf64>,
+                                    %arg2: f64)
+                                    -> (memref<?xf64>, memref<?xf64>, f64) {
+        return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
+      }
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorStorageExpansionPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 640ee67302b1b..39b633a6c7f6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseTensorStorageExpansion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b30d0d2b927f0..d5e2b96089d5b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,6 +24,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -185,6 +186,44 @@ struct SparseTensorCodegenPass
   }
 };
 
+struct SparseTensorStorageExpansionPass
+    : public impl::SparseTensorStorageExpansionBase<
+          SparseTensorStorageExpansionPass> {
+
+  SparseTensorStorageExpansionPass() = default;
+  SparseTensorStorageExpansionPass(
+      const SparseTensorStorageExpansionPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseTensorStorageTupleExpander converter;
+    ConversionTarget target(*ctx);
+    // Now, everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, return.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // Populate with rules and apply rewriting rules.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseTensorStorageExpansionPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -255,3 +294,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
+  return std::make_unique<SparseTensorStorageExpansionPass>();
+}

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
new file mode 100644
index 0000000000000..c1305eba79009
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -0,0 +1,96 @@
+//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The sparse tensor storage expansion pass expands the compound storage for
+// sparse tensors (a tuple) into flattened SSA values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Flattens a sparse tensor storage tuple into its element types.
+static Optional<LogicalResult>
+convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
+  if (auto tuple = t.dyn_cast<TupleType>()) {
+    // Note that this does not handle nested tuples, but that is fine
+    // for the sparse compiler since they will not be generated.
+    result.append(tuple.getTypes().begin(), tuple.getTypes().end());
+    return success();
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage conversion rule for returns.
+class SparseStorageReturnConverter
+    : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> flattened;
+    for (auto operand : adaptor.getOperands()) {
+      if (auto cast =
+              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+          cast && cast->getResultTypes()[0].isa<TupleType>())
+        // An unrealized_conversion_cast is inserted by the type converter to
+        // bridge the 1:N conversion between the tuple and its element types.
+        // In this case, take the operands of the cast and replace the tuple
+        // result with the flattened values.
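+        //
+        // Illustrative IR sketch (with placeholder values %m0, %m1, %v):
+        //   %0 = builtin.unrealized_conversion_cast %m0, %m1, %v
+        //          : memref<?xf64>, memref<?xf64>, f64
+        //          to tuple<memref<?xf64>, memref<?xf64>, f64>
+        //   return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+        // is rewritten here into:
+        //   return %m0, %m1, %v : memref<?xf64>, memref<?xf64>, f64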
+        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+      else
+        flattened.push_back(operand);
+    }
+    // Create a return with the flattened values extracted from the tuple.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor storage expansion
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorStorageTuple);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required
+/// to expand compounded sparse tensor tuples.
+void mlir::populateSparseTensorStorageExpansionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.add<SparseStorageReturnConverter>(typeConverter,
+                                             patterns.getContext());
+}

diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
new file mode 100644
index 0000000000000..445b234a2a8d2
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+
+// CHECK-LABEL:  func @sparse_storage_expand(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64
+// CHECK:          return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                     -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}


        

