[Mlir-commits] [mlir] fe2bd54 - [mlir] Add file to implement bufferization for shape ops.

Tres Popp llvmlistbot at llvm.org
Tue Oct 6 02:36:10 PDT 2020


Author: Tres Popp
Date: 2020-10-06T11:35:16+02:00
New Revision: fe2bd543f5e82bc14ef37dc5ec2228812098cf7a

URL: https://github.com/llvm/llvm-project/commit/fe2bd543f5e82bc14ef37dc5ec2228812098cf7a
DIFF: https://github.com/llvm/llvm-project/commit/fe2bd543f5e82bc14ef37dc5ec2228812098cf7a.diff

LOG: [mlir] Add file to implement bufferization for shape ops.

This adds a shape bufferization pass (-shape-tensor-to-memref) and implements
the conversion pattern for shape.assuming.

Differential Revision: https://reviews.llvm.org/D88083
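
A minimal usage sketch (not part of the committed diff): the new pass is exposed
to mlir-opt as -shape-tensor-to-memref (see the test added below) and can also
be scheduled from C++. Only createShapeTensorToMemrefPass() comes from this
patch; the surrounding module/pipeline setup is illustrative.

#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Runs the shape bufferization pass added in this patch on every function in
// `module`. The helper itself is illustrative, not part of the patch.
static mlir::LogicalResult bufferizeShapeOps(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The pass is a FunctionPass, so nest it to run once per function.
  pm.addNestedPass<mlir::FuncOp>(mlir::createShapeTensorToMemrefPass());
  return pm.run(module);
}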

Added: 
    mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
    mlir/test/Dialect/Shape/shape-type-conversion.mlir

Modified: 
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
    mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 543ffc617a5c..72816b72f41e 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -16,6 +16,10 @@
 
 #include "mlir/Pass/Pass.h"
 
+namespace mlir {
+class BufferAssignmentTypeConverter;
+} // namespace mlir
+
 namespace mlir {
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
 /// dialect to be convertible to Standard. For example, `shape.num_elements` get
@@ -36,6 +40,13 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
+void populateShapeTypeConversionPatterns(
+    MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+/// Creates a pass that replaces the tensor operands and results of shape
+/// operations with buffers. This only modifies the shape operations.
+std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

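A reuse sketch (not part of the committed diff): populateShapeTypeConversionPatterns()
is declared separately from the pass so that a downstream bufferization pipeline
can mix the shape patterns with patterns from other dialects. Only that function
and the MLIR types in its signature come from this patch;
populateMyDialectBufferizePatterns() is a hypothetical placeholder.

#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/BufferPlacement.h"

// Collects tensor->memref conversion patterns from several dialects into one
// list that a single dialect conversion can then apply.
static void collectBufferizePatterns(mlir::MLIRContext *ctx,
                                     mlir::BufferAssignmentTypeConverter &converter,
                                     mlir::OwningRewritePatternList &patterns) {
  // Shape dialect patterns declared in this patch.
  mlir::populateShapeTypeConversionPatterns(ctx, &converter, &patterns);
  // Hypothetical: other dialects' patterns sharing the same type converter.
  // populateMyDialectBufferizePatterns(ctx, &converter, &patterns);
}
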
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 022bd3773ce2..09cc7a1a5c93 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -21,4 +21,9 @@ def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+// TODO(tpopp): Generalize this to allow any type conversions desired.
+def ShapeTensorToMemref : FunctionPass<"shape-tensor-to-memref"> {
+  let summary = "Replace tensors involving shape operations with memrefs";
+  let constructor = "mlir::createShapeTensorToMemrefPass()";
+}
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 987f9c544b33..9df40a0fb740 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
+  ShapeTypeConversion.cpp
 
   ADDITIONAL_HEADER_DIRS

diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
new file mode 100644
index 000000000000..98398fbc70e6
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -0,0 +1,98 @@
+//===- ShapeTypeConversion.cpp - Shape Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines patterns that convert the types of the operands and
+// results of shape operations from tensors to memrefs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferPlacement.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+// Propagate tensor to memref conversions through shape.assuming ops.
+class TypeConversionAssumingOpConverter
+    : public BufferAssignmentOpConversionPattern<shape::AssumingOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      shape::AssumingOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(assumingOp.getNumResults());
+    for (auto result : assumingOp.getResults()) {
+      auto originalType = result.getType();
+      Type convertedType = converter->convertType(originalType);
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        assumingOp.getLoc(), newResultTypes, assumingOp.witness());
+
+    // Handle the region transfer carefully here to avoid assertions that both
+    // operations are valid at replacement time.
+    newAssumingOp.doRegion().push_back(new Block());
+    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
+    newAssumingOp.doRegion().takeBody(assumingOp.doRegion());
+
+    return success();
+  }
+};
+
+struct ShapeTensorToMemrefPass
+    : public ShapeTensorToMemrefBase<ShapeTensorToMemrefPass> {
+  void runOnFunction() override {
+    MLIRContext &ctx = getContext();
+
+    OwningRewritePatternList patterns;
+    BufferAssignmentTypeConverter converter;
+    populateShapeTypeConversionPatterns(&ctx, &converter, &patterns);
+
+    ConversionTarget target(getContext());
+    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
+
+    target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
+      return std::all_of(op.result_type_begin(), op.result_type_end(),
+                         isMemRefType);
+    });
+
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Populates `patterns` with patterns that convert the tensor operands and
+/// results of shape operations to memrefs.
+// TODO(tpopp): Change this to work generally with any type conversions.
+void mlir::populateShapeTypeConversionPatterns(
+    MLIRContext *context, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns) {
+  patterns->insert<TypeConversionAssumingOpConverter>(context, converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeTensorToMemrefPass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<FunctionPass> mlir::createShapeTensorToMemrefPass() {
+  return std::make_unique<ShapeTensorToMemrefPass>();
+}

diff --git a/mlir/test/Dialect/Shape/shape-type-conversion.mlir b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
new file mode 100644
index 000000000000..8985a6da0251
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -shape-tensor-to-memref <%s | FileCheck %s
+
+// -----
+// Check that shape.assuming returns a memref.
+//
+// CHECK-LABEL: @shape_assuming_returns_memref
+func @shape_assuming_returns_memref() {
+  %0 = shape.const_witness true
+  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
+  %1 = shape.assuming %0 -> (tensor<2xf16>) {
+    %2 = "test.source"() : () -> (tensor<2xf16>)
+    shape.assuming_yield %2 : tensor<2xf16>
+  }
+  "test.sink"(%1) : (tensor<2xf16>) -> ()
+  return
+}
+
+


        

