[Mlir-commits] [mlir] fe2bd54 - [mlir] Add file to implement bufferization for shape ops.
Tres Popp
llvmlistbot at llvm.org
Tue Oct 6 02:36:10 PDT 2020
Author: Tres Popp
Date: 2020-10-06T11:35:16+02:00
New Revision: fe2bd543f5e82bc14ef37dc5ec2228812098cf7a
URL: https://github.com/llvm/llvm-project/commit/fe2bd543f5e82bc14ef37dc5ec2228812098cf7a
DIFF: https://github.com/llvm/llvm-project/commit/fe2bd543f5e82bc14ef37dc5ec2228812098cf7a.diff
LOG: [mlir] Add file to implement bufferization for shape ops.
This adds a shape-tensor-to-memref pass (bufferization for shape ops) and
implements the conversion pattern for shape.assuming.
Differential Revision: https://reviews.llvm.org/D88083
Added:
mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
mlir/test/Dialect/Shape/shape-type-conversion.mlir
Modified:
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 543ffc617a5c..72816b72f41e 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -16,6 +16,10 @@
#include "mlir/Pass/Pass.h"
+namespace mlir {
+// Forward-declared so that this header only needs the name of the type
+// converter, not its full definition.
+class BufferAssignmentTypeConverter;
+} // namespace mlir
+
namespace mlir {
/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
/// dialect to be convertible to Standard. For example, `shape.num_elements` get
@@ -36,6 +40,13 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
+/// Collects a set of patterns to replace tensors as inputs and outputs to
+/// shape operations with buffers, using `converter` to compute the new types.
+/// This only modifies the shape operations.
+void populateShapeTypeConversionPatterns(
+    MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+
+/// Creates a function pass that replaces tensors used as inputs and outputs of
+/// shape operations with memrefs, via populateShapeTypeConversionPatterns.
+std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 022bd3773ce2..09cc7a1a5c93 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -21,4 +21,9 @@ def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
let constructor = "mlir::createShapeToShapeLowering()";
}
+// TODO(tpopp): Generalize this to allow any type conversions desired.
+def ShapeTensorToMemref : FunctionPass<"shape-tensor-to-memref"> {
+  let summary = "Replace tensors involving shape operations with memrefs";
+  let description = [{
+    Rewrites shape operations that produce tensors so that they produce
+    memrefs instead. Currently only the results of `shape.assuming` are
+    converted; its region body is moved unchanged into the rewritten op.
+  }];
+  let constructor = "mlir::createShapeTensorToMemrefPass()";
+}
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 987f9c544b33..9df40a0fb740 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRShapeOpsTransforms
RemoveShapeConstraints.cpp
+ ShapeTypeConversion.cpp
ShapeToShapeLowering.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
new file mode 100644
index 000000000000..98398fbc70e6
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -0,0 +1,98 @@
+//===- ShapeTypeConversion.cpp - Shape Type Conversions ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines patterns to convert types of inputs and outputs to shape
+// operations to be memrefs instead of tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferPlacement.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+// Propagates tensor->memref type conversions through shape.assuming ops.
+//
+// The op is recreated with converted result types and the original body is
+// moved (not cloned) into the new op. Ops nested in the body, such as
+// shape.assuming_yield, are not rewritten by this pattern.
+class TypeConversionAssumingOpConverter
+    : public BufferAssignmentOpConversionPattern<shape::AssumingOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      shape::AssumingOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(assumingOp.getNumResults());
+    for (auto result : assumingOp.getResults()) {
+      Type convertedType = converter->convertType(result.getType());
+      // Bail out before touching the IR if a result type cannot be
+      // converted; building an op with a null result type is invalid.
+      if (!convertedType)
+        return failure();
+      newResultTypes.push_back(convertedType);
+    }
+
+    // `operands` holds the remapped operands of the matched op; its single
+    // operand is the witness.
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        assumingOp.getLoc(), newResultTypes, operands.front());
+
+    // Move the body through the rewriter instead of mutating the regions
+    // directly, so the conversion driver records (and can undo) the move.
+    // Moving it before replaceOp means the body's ops are no longer nested in
+    // the op being replaced.
+    Region &newRegion = newAssumingOp.doRegion();
+    rewriter.inlineRegionBefore(assumingOp.doRegion(), newRegion,
+                                newRegion.end());
+    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
+
+    return success();
+  }
+};
+
+// Function pass that converts shape.assuming ops with tensor results into ones
+// with memref results, using the patterns from
+// populateShapeTypeConversionPatterns under a partial conversion.
+struct ShapeTensorToMemrefPass
+    : public ShapeTensorToMemrefBase<ShapeTensorToMemrefPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+
+    // A shape.assuming op is legal once every one of its results is a memref.
+    auto hasOnlyMemRefResults = [](shape::AssumingOp op) {
+      return std::all_of(op.result_type_begin(), op.result_type_end(),
+                         [](Type t) { return t.isa<BaseMemRefType>(); });
+    };
+    ConversionTarget target(*context);
+    target.addDynamicallyLegalOp<AssumingOp>(hasOnlyMemRefResults);
+
+    BufferAssignmentTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    populateShapeTypeConversionPatterns(context, &typeConverter, &patterns);
+
+    LogicalResult result =
+        mlir::applyPartialConversion(getFunction(), target, patterns);
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Adds to `patterns` the conversion patterns that rewrite tensor-typed shape
+/// operations (currently only shape.assuming) to use memrefs; `converter`
+/// supplies the converted types.
+///
+/// TODO(tpopp): Change this to work generally with any type conversions.
+void mlir::populateShapeTypeConversionPatterns(
+    MLIRContext *context, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns) {
+  OwningRewritePatternList &patternList = *patterns;
+  patternList.insert<TypeConversionAssumingOpConverter>(context, converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeTensorToMemrefPass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<FunctionPass> mlir::createShapeTensorToMemrefPass() {
+ return std::make_unique<ShapeTensorToMemrefPass>();
+}
diff --git a/mlir/test/Dialect/Shape/shape-type-conversion.mlir b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
new file mode 100644
index 000000000000..8985a6da0251
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -shape-tensor-to-memref <%s | FileCheck %s
+
+// -----
+// Check that shape.assuming returns a memref.
+//
+// The "test.source" producer and "test.sink" consumer are unregistered ops,
+// which is why the RUN line passes -allow-unregistered-dialect. Only the
+// result type printed on the shape.assuming op itself is verified here.
+//
+// CHECK-LABEL: @shape_assuming_returns_memref
+func @shape_assuming_returns_memref() {
+  %0 = shape.const_witness true
+  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
+  %1 = shape.assuming %0 -> (tensor<2xf16>) {
+    %2 = "test.source"() : () -> (tensor<2xf16>)
+    shape.assuming_yield %2 : tensor<2xf16>
+  }
+  "test.sink"(%1) : (tensor<2xf16>) -> ()
+  return
+}
+
+
More information about the Mlir-commits
mailing list