[Mlir-commits] [mlir] 86b22d3 - [mlir][sparse] start a sparse codegen conversion pass

Aart Bik llvmlistbot at llvm.org
Mon Aug 29 09:39:45 PDT 2022


Author: Aart Bik
Date: 2022-08-29T09:39:33-07:00
New Revision: 86b22d312053f38c7ea94af49dd0e93c660ffec8

URL: https://github.com/llvm/llvm-project/commit/86b22d312053f38c7ea94af49dd0e93c660ffec8
DIFF: https://github.com/llvm/llvm-project/commit/86b22d312053f38c7ea94af49dd0e93c660ffec8.diff

LOG: [mlir][sparse] start a sparse codegen conversion pass

This new pass provides an alternative to the current conversion pass, which
converts sparse tensor types to opaque pointers and sparse primitives to calls
into a runtime support library. The new pass instead maps sparse tensor types
to actual data structures and sparse primitives to actual code. In the long
run, this new pass will remove our dependence on the support library, avoid
the need to link in fully templated and expanded code, and provide much better
opportunities for optimization of the generated code.
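
The contrast with the existing conversion pass can be illustrated with the
@sparse_nop tests added in this revision (a sketch only; the mapping to a
one-dimensional memref is still the dummy placeholder rule of this first
step, not the eventual sparse storage layout):

```mlir
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }>

// Input: a function that simply passes a sparse vector through.
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>)
    -> tensor<?xf64, #SparseVector> {
  return %arg0 : tensor<?xf64, #SparseVector>
}

// After --sparse-tensor-conversion, the sparse tensor type becomes an
// opaque pointer into the runtime support library:
//   func.func @sparse_nop(%arg0: !llvm.ptr<i8>) -> !llvm.ptr<i8>
//
// After --sparse-tensor-codegen, the sparse tensor type becomes a
// compiler-visible buffer:
//   func.func @sparse_nop(%arg0: memref<?xf64>) -> memref<?xf64>
```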

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D132766

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 71afddcb49245..2d4bdb3a4b5e6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -11,8 +11,6 @@
 // In general, this file takes the approach of keeping "mechanism" (the
 // actual steps of applying a transformation) completely separate from
 // "policy" (heuristics for when and where to apply transformations).
-// The only exception is in `SparseToSparseConversionStrategy`; for which,
-// see further discussion there.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,15 +19,13 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 namespace bufferization {
 struct OneShotBufferizationOptions;
 } // namespace bufferization
 
-// Forward.
-class TypeConverter;
-
 //===----------------------------------------------------------------------===//
 // The Sparsification pass.
 //===----------------------------------------------------------------------===//
@@ -95,6 +91,12 @@ createSparsificationPass(const SparsificationOptions &options);
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor type converter into an opaque pointer.
+class SparseTensorTypeToPtrConverter : public TypeConverter {
+public:
+  SparseTensorTypeToPtrConverter();
+};
+
 /// Defines a strategy for implementing sparse-to-sparse conversion.
 /// `kAuto` leaves it up to the compiler to automatically determine
 /// the method used.  `kViaCOO` converts the source tensor to COO and
@@ -138,6 +140,22 @@ std::unique_ptr<Pass> createSparseTensorConversionPass();
 std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
+//===----------------------------------------------------------------------===//
+// The SparseTensorCodegen pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor type converter into an actual buffer.
+class SparseTensorTypeToBufferConverter : public TypeConverter {
+public:
+  SparseTensorTypeToBufferConverter();
+};
+
+/// Sets up sparse tensor code generation rules.
+void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseTensorCodegenPass();
+
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 6e36259de9490..4ca224b167f20 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -77,16 +77,16 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
 }
 
 def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
-  let summary = "Apply conversion rules to sparse tensor primitives and types";
+  let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
-    A pass that converts sparse tensor primitives to calls into a runtime
-    support library. All sparse tensor types are converted into opaque
-    pointers to the underlying sparse storage schemes.
+    A pass that converts sparse tensor primitives to calls into a runtime
+    support library. Sparse tensor types are converted into opaque pointers
+    to the underlying sparse storage schemes.
 
-    Note that this is a current implementation choice to keep the conversion
-    relatively simple. In principle, these primitives could also be
-    converted to actual elaborate IR code that implements the primitives
-    on the selected sparse tensor storage schemes.
+    The use of opaque pointers together with a runtime support library keeps
+    the conversion relatively simple, but at the expense of IR opacity,
+    which obscures opportunities for subsequent optimization of the IR.
+    An alternative is provided by the SparseTensorCodegen pass.
 
     Example of the conversion:
 
@@ -122,4 +122,28 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   ];
 }
 
+def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
+  let summary = "Convert sparse tensors and primitives to actual code";
+  let description = [{
+    A pass that converts sparse tensor types and primitives to actual
+    compiler-visible buffers and compiler IR that implements these
+    primitives on the selected sparse tensor storage schemes.
+
+    This pass provides an alternative to the SparseTensorConversion pass,
+    eliminating the dependence on a runtime support library, and providing
+    many more opportunities for subsequent compiler optimization of the
+    generated code.
+
+    Example of the conversion:
+
+    ```mlir
+    TBD
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorCodegenPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 9d99d2f7a5c8b..640ee67302b1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenUtils.cpp
   DenseBufferizationPass.cpp
   Sparsification.cpp
+  SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
new file mode 100644
index 0000000000000..86669260c970f
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -0,0 +1,82 @@
+//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that converts sparse tensor types and primitives to actual
+// compiler-visible buffers and compiler IR that implements these primitives on
+// the selected sparse tensor storage schemes. This pass provides an alternative
+// to the SparseTensorConversion pass, eliminating the dependence on a runtime
+// support library, and providing many more opportunities for subsequent
+// compiler optimization of the generated code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Maps each sparse tensor type to the appropriate buffer.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+  if (getSparseTensorEncoding(type) != nullptr) {
+    // TODO: this is just a dummy rule to get the ball rolling....
+    RankedTensorType rTp = type.cast<RankedTensorType>();
+    return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType());
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse conversion rule for returns.
+class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into an actual buffer.
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required to
+/// convert sparse tensor types and primitives to actual code.
+void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+                                               RewritePatternSet &patterns) {
+  patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 716101562ead0..5bd10e89caa95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -6,11 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Convert sparse tensor primitives to calls into a runtime support library.
-// Note that this is a current implementation choice to keep the conversion
-// simple. In principle, these primitives could also be converted to actual
-// elaborate IR code that implements the primitives on the selected sparse
-// tensor storage schemes.
+// A pass that converts sparse tensor primitives to calls into a runtime
+// support library. Sparse tensor types are converted into opaque pointers
+// to the underlying sparse storage schemes. The use of opaque pointers
+// together with a runtime support library keeps the conversion relatively
+// simple, but at the expense of IR opacity, which obscures opportunities
+// for subsequent optimization of the IR. An alternative is provided by
+// the SparseTensorCodegen pass.
 //
 //===----------------------------------------------------------------------===//
 
@@ -48,6 +50,13 @@ static Type getOpaquePointerType(OpBuilder &builder) {
   return LLVM::LLVMPointerType::get(builder.getI8Type());
 }
 
+/// Maps each sparse tensor type to an opaque pointer.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+  if (getSparseTensorEncoding(type) != nullptr)
+    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
+  return llvm::None;
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
@@ -1345,6 +1354,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
     return success();
   }
 };
+
 /// Sparse conversion rule for the output operator.
 class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
 public:
@@ -1387,6 +1397,15 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into opaque pointer.
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 2014781610536..643cff9844492 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,20 +67,6 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
   }
 };
 
-class SparseTensorTypeConverter : public TypeConverter {
-public:
-  SparseTensorTypeConverter() {
-    addConversion([](Type type) { return type; });
-    addConversion(convertSparseTensorTypes);
-  }
-  // Maps each sparse tensor type to an opaque pointer.
-  static Optional<Type> convertSparseTensorTypes(Type type) {
-    if (getSparseTensorEncoding(type) != nullptr)
-      return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
-    return llvm::None;
-  }
-};
-
 struct SparseTensorConversionPass
     : public SparseTensorConversionBase<SparseTensorConversionPass> {
 
@@ -93,7 +79,7 @@ struct SparseTensorConversionPass
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    SparseTensorTypeConverter converter;
+    SparseTensorTypeToPtrConverter converter;
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
@@ -158,8 +144,49 @@ struct SparseTensorConversionPass
   }
 };
 
+struct SparseTensorCodegenPass
+    : public SparseTensorCodegenBase<SparseTensorCodegenPass> {
+
+  SparseTensorCodegenPass() = default;
+  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseTensorTypeToBufferConverter converter;
+    ConversionTarget target(*ctx);
+    // Everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // Populate with rules and apply rewriting rules.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseTensorCodegenPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Strategy flag methods.
+//===----------------------------------------------------------------------===//
+
 SparseParallelizationStrategy
 mlir::sparseParallelizationStrategy(int32_t flag) {
   switch (flag) {
@@ -199,6 +226,10 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Pass creation methods.
+//===----------------------------------------------------------------------===//
+
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
   return std::make_unique<SparsificationPass>();
 }
@@ -216,3 +247,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
     const SparseTensorConversionOptions &options) {
   return std::make_unique<SparseTensorConversionPass>(options);
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
+  return std::make_unique<SparseTensorCodegenPass>();
+}

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
new file mode 100644
index 0000000000000..a3cecaf367a40
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+// TODO: just a dummy memref rewriting to get the ball rolling....
+
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A:.*]]: memref<?xf64>) -> memref<?xf64> {
+//       CHECK: return %[[A]] : memref<?xf64>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 336d8158eb2af..4de4021b26876 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -25,6 +25,13 @@
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
 }>
 
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}
+
 // CHECK-LABEL: func @sparse_dim1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = arith.constant 0 : index


        

