[Mlir-commits] [mlir] 57b338c - [mlir][shape] Split out structural type conversions for shape dialect.
Sean Silva
llvmlistbot at llvm.org
Wed Oct 21 11:59:04 PDT 2020
Author: Sean Silva
Date: 2020-10-21T11:58:27-07:00
New Revision: 57b338c08a4942bda6e58c77870c657c53b6fb5b
URL: https://github.com/llvm/llvm-project/commit/57b338c08a4942bda6e58c77870c657c53b6fb5b
DIFF: https://github.com/llvm/llvm-project/commit/57b338c08a4942bda6e58c77870c657c53b6fb5b.diff
LOG: [mlir][shape] Split out structural type conversions for shape dialect.
A "structural" type conversion is one where the underlying ops are
completely agnostic to the actual types involved and simply need to update
their types. An example of this is shape.assuming -- the shape.assuming op
and the corresponding shape.assuming_yield op need to update their types
according to the TypeConverter, but otherwise don't care what type
conversions are happening.
Also, the previous conversion code did not update or materialize conversions
for the operands of the shape.assuming_yield op, so the yielded types could
disagree with the converted shape.assuming result types. This should have
caused a verification failure, but shape.assuming's verifier was not calling
RegionBranchOpInterface::verifyTypes (which, for implementation reasons,
cannot be called automatically as part of the trait verification and must be
called manually from the op's verifier). This patch also adds that
verification.
Differential Revision: https://reviews.llvm.org/D89833
Added:
mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
mlir/test/Dialect/Shape/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index c6c52f2eb6ee..6541cfadfc1b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -635,6 +635,7 @@ def Shape_AssumingOp : Shape_Op<"assuming",
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
+ let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
let extraClassDeclaration = [{
// Inline the region into the region containing the AssumingOp and delete
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index f8976e9c75eb..6df12998566a 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -17,7 +17,8 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-class BufferizeTypeConverter;
+class ConversionTarget;
+class TypeConverter;
} // namespace mlir
namespace mlir {
@@ -40,9 +41,21 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
-void populateShapeTypeConversionPatterns(MLIRContext *ctx,
- BufferizeTypeConverter &converter,
- OwningRewritePatternList &patterns);
+/// Populates patterns for shape dialect structural type conversions and sets up
+/// the provided ConversionTarget with the appropriate legality configuration
+/// for the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types consistently. An example of this is shape.assuming -- the
+/// shape.assuming op and the corresponding shape.assuming_yield op need to have
+/// consistent types, but the exact types don't matter. So all that we need to
+/// do for a structural type conversion is to update both of their types
+/// consistently to the new types prescribed by the TypeConverter.
+void populateShapeStructuralTypeConversionsAndLegality(
+ MLIRContext *context, TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns, ConversionTarget &target);
+
// Bufferizes shape dialect ops.
//
// Note that most shape dialect ops must be converted to std before
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index bdebfa9a32d7..20cd960e040f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,82 +8,30 @@
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
-using namespace mlir::shape;
namespace {
-// Propagate tensor to memref conversions through shape.assuming ops.
-class TypeConversionAssumingOpConverter
- : public BufferizeOpConversionPattern<shape::AssumingOp> {
-public:
- using BufferizeOpConversionPattern<
- shape::AssumingOp>::BufferizeOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- SmallVector<Type, 2> newResultTypes;
- newResultTypes.reserve(assumingOp.getNumResults());
- for (auto result : assumingOp.getResults()) {
- auto originalType = result.getType();
- Type convertedType = converter.convertType(originalType);
- newResultTypes.push_back(convertedType);
- }
-
- auto newAssumingOp = rewriter.create<shape::AssumingOp>(
- assumingOp.getLoc(), newResultTypes, assumingOp.witness());
-
- rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
- rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
- newAssumingOp.doRegion().end());
-
- return success();
- }
-};
-
+/// Function pass that updates shape.assuming / shape.assuming_yield so their
+/// types agree with BufferizeTypeConverter (the shape_assuming test expects
+/// tensor results to become memrefs), using a partial dialect conversion.
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
OwningRewritePatternList patterns;
- BufferizeTypeConverter converter;
- populateShapeTypeConversionPatterns(&ctx, converter, patterns);
-
+    // Type mapping shared by the structural patterns and legality checks below.
+ BufferizeTypeConverter typeConverter;
ConversionTarget target(getContext());
- auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
- target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
- return std::all_of(op.result_type_begin(), op.result_type_end(),
- isMemRefType);
- });
+    // Keep the bufferization materialization ops (tensor_to_memref /
+    // tensor_load in the test expectations) legal in the converted IR.
+ populateBufferizeMaterializationLegality(target);
+ populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
+ patterns, target);
- if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+ if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
};
-
} // namespace
-/// Populates `patterns` with the conversion patterns of tensor->memref.
-//
-// TODO: Change this to work generally with any type conversions.
-void mlir::populateShapeTypeConversionPatterns(
- MLIRContext *context, BufferizeTypeConverter &converter,
- OwningRewritePatternList &patterns) {
- patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeBufferizePass construction
-//===----------------------------------------------------------------------===//
-
+/// Creates a new ShapeBufferizePass instance.
std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
return std::make_unique<ShapeBufferizePass>();
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index ce413f57d989..123a3664df89 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
Bufferize.cpp
RemoveShapeConstraints.cpp
ShapeToShapeLowering.cpp
+ StructuralTypeConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 000000000000..61e862836a73
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,71 @@
+//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Rebuilds a shape.assuming op with the result types prescribed by the
+/// pattern's TypeConverter. The op is agnostic to the concrete types, so its
+/// body region is moved into the new op unchanged; the terminating
+/// shape.assuming_yield is updated by ConvertAssumingYieldOpTypes.
+class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(op.getNumResults());
+    for (Type originalType : op.getResultTypes()) {
+      Type convertedType = getTypeConverter()->convertType(originalType);
+      // A null type means the converter could not handle this type. Bail out
+      // instead of creating an op with an invalid (null) result type.
+      if (!convertedType)
+        return failure();
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp =
+        rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
+
+    rewriter.replaceOp(op, newAssumingOp.getResults());
+    // Move the original body into the freshly created (empty) region.
+    rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
+                                newAssumingOp.doRegion().end());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Rebuilds shape.assuming_yield from the remapped `operands` supplied by the
+/// conversion driver, so the yielded values carry the converted types that
+/// the enclosing (converted) shape.assuming expects.
+struct ConvertAssumingYieldOpTypes : OpConversionPattern<AssumingYieldOp> {
+  using OpConversionPattern<AssumingYieldOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // The yield has no state other than its operands, so an op rebuilt from
+    // the remapped values fully replaces the original.
+    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the shape.assuming / shape.assuming_yield structural conversion
+/// patterns to `patterns` and marks both ops dynamically legal on `target`
+/// once the types they expose (results of shape.assuming, operands of
+/// shape.assuming_yield) are all legal under `typeConverter`.
+///
+/// The legality callbacks hold a reference to `typeConverter`, so it must
+/// outlive every use of `target`.
+void mlir::populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
+      typeConverter, context);
+
+  // shape.assuming is legal once every result type is legal.
+  auto assumingResultsAreLegal = [&typeConverter](AssumingOp op) {
+    return typeConverter.isLegal(op.getResultTypes());
+  };
+  // shape.assuming_yield is legal once every yielded (operand) type is legal.
+  auto yieldOperandsAreLegal = [&typeConverter](AssumingYieldOp op) {
+    return typeConverter.isLegal(op.getOperandTypes());
+  };
+
+  target.addDynamicallyLegalOp<AssumingOp>(assumingResultsAreLegal);
+  target.addDynamicallyLegalOp<AssumingYieldOp>(yieldOperandsAreLegal);
+}
diff --git a/mlir/test/Dialect/Shape/bufferize.mlir b/mlir/test/Dialect/Shape/bufferize.mlir
index 7393de101466..cb65a5d42d4b 100644
--- a/mlir/test/Dialect/Shape/bufferize.mlir
+++ b/mlir/test/Dialect/Shape/bufferize.mlir
@@ -1,12 +1,20 @@
// RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
// -----
-// Check that shape.assuming returns a memref.
-//
-// CHECK-LABEL: @shape_assuming_returns_memref
-func @shape_assuming_returns_memref() {
+
+// CHECK-LABEL: func @shape_assuming() {
+// CHECK: %[[WTRUE:.*]] = shape.const_witness true
+// CHECK: %[[MEMREF:.*]] = shape.assuming %[[WTRUE]] -> (memref<2xf16>) {
+// CHECK: %[[TENSOR_VAL:.*]] = "test.source"() : () -> tensor<2xf16>
+// CHECK: %[[YIELDED_MEMREF:.*]] = tensor_to_memref %[[TENSOR_VAL]] : memref<2xf16>
+// CHECK: shape.assuming_yield %[[YIELDED_MEMREF]] : memref<2xf16>
+// CHECK: }
+// The tensor_load must consume the shape.assuming result captured above, so
+// reuse the variable instead of re-defining it with a wildcard pattern.
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]] : memref<2xf16>
+// CHECK: "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
+// CHECK: return
+// CHECK: }
+func @shape_assuming() {
%0 = shape.const_witness true
- // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
%1 = shape.assuming %0 -> (tensor<2xf16>) {
%2 = "test.source"() : () -> (tensor<2xf16>)
shape.assuming_yield %2 : tensor<2xf16>
More information about the Mlir-commits
mailing list