[Mlir-commits] [mlir] 57b338c - [mlir][shape] Split out structural type conversions for shape dialect.

Sean Silva llvmlistbot at llvm.org
Wed Oct 21 11:59:04 PDT 2020


Author: Sean Silva
Date: 2020-10-21T11:58:27-07:00
New Revision: 57b338c08a4942bda6e58c77870c657c53b6fb5b

URL: https://github.com/llvm/llvm-project/commit/57b338c08a4942bda6e58c77870c657c53b6fb5b
DIFF: https://github.com/llvm/llvm-project/commit/57b338c08a4942bda6e58c77870c657c53b6fb5b.diff

LOG: [mlir][shape] Split out structural type conversions for shape dialect.

A "structural" type conversion is one where the underlying ops are
completely agnostic to the actual types involved and simply need to update
their types. An example of this is shape.assuming -- the shape.assuming op
and the corresponding shape.assuming_yield op need to update their types
according to the TypeConverter, but otherwise don't care what type
conversions are happening.

Also, the previous conversion code would not correctly materialize
conversions for the shape.assuming_yield op. This should have caused a
verification failure, but shape.assuming's verifier wasn't calling
RegionBranchOpInterface::verifyTypes (which for reasons can't be called
automatically as part of the trait verification, and requires being
called manually). This patch also adds that verification.

Differential Revision: https://reviews.llvm.org/D89833

Added: 
    mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
    mlir/test/Dialect/Shape/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index c6c52f2eb6ee..6541cfadfc1b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -635,6 +635,7 @@ def Shape_AssumingOp : Shape_Op<"assuming",
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
 
   let extraClassDeclaration = [{
     // Inline the region into the region containing the AssumingOp and delete

diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index f8976e9c75eb..6df12998566a 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -17,7 +17,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class BufferizeTypeConverter;
+class ConversionTarget;
+class TypeConverter;
 } // namespace mlir
 
 namespace mlir {
@@ -40,9 +41,21 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
-void populateShapeTypeConversionPatterns(MLIRContext *ctx,
-                                         BufferizeTypeConverter &converter,
-                                         OwningRewritePatternList &patterns);
+/// Populates patterns for shape dialect structural type conversions and sets up
+/// the provided ConversionTarget with the appropriate legality configuration
+/// for the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types consistently. An example of this is shape.assuming -- the
+/// shape.assuming op and the corresponding shape.assuming_yield op need to have
+/// consistent types, but the exact types don't matter. So all that we need to
+/// do for a structural type conversion is to update both of their types
+/// consistently to the new types prescribed by the TypeConverter.
+void populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before

diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index bdebfa9a32d7..20cd960e040f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,82 +8,30 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
-using namespace mlir::shape;
 
 namespace {
-// Propagate tensor to memref conversions through shape.assuming ops.
-class TypeConversionAssumingOpConverter
-    : public BufferizeOpConversionPattern<shape::AssumingOp> {
-public:
-  using BufferizeOpConversionPattern<
-      shape::AssumingOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Type, 2> newResultTypes;
-    newResultTypes.reserve(assumingOp.getNumResults());
-    for (auto result : assumingOp.getResults()) {
-      auto originalType = result.getType();
-      Type convertedType = converter.convertType(originalType);
-      newResultTypes.push_back(convertedType);
-    }
-
-    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
-        assumingOp.getLoc(), newResultTypes, assumingOp.witness());
-
-    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
-    rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
-                                newAssumingOp.doRegion().end());
-
-    return success();
-  }
-};
-
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
     OwningRewritePatternList patterns;
-    BufferizeTypeConverter converter;
-    populateShapeTypeConversionPatterns(&ctx, converter, patterns);
-
+    BufferizeTypeConverter typeConverter;
     ConversionTarget target(getContext());
-    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
 
-    target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
-      return std::all_of(op.result_type_begin(), op.result_type_end(),
-                         isMemRefType);
-    });
+    populateBufferizeMaterializationLegality(target);
+    populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
+                                                      patterns, target);
 
-    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
 };
-
 } // namespace
 
-/// Populates `patterns` with the conversion patterns of tensor->memref.
-//
-// TODO: Change this to work generally with any type conversions.
-void mlir::populateShapeTypeConversionPatterns(
-    MLIRContext *context, BufferizeTypeConverter &converter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeBufferizePass construction
-//===----------------------------------------------------------------------===//
-
 std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
   return std::make_unique<ShapeBufferizePass>();
 }

diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index ce413f57d989..123a3664df89 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms

diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 000000000000..61e862836a73
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,71 @@
+//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(op.getNumResults());
+    for (auto result : op.getResults()) {
+      auto originalType = result.getType();
+      Type convertedType = getTypeConverter()->convertType(originalType);
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp =
+        rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
+
+    rewriter.replaceOp(op, newAssumingOp.getResults());
+    rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
+                                newAssumingOp.doRegion().end());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertAssumingYieldOpTypes
+    : public OpConversionPattern<AssumingYieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
+      typeConverter, context);
+  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
+    return typeConverter.isLegal(op.getResultTypes());
+  });
+  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
+    return typeConverter.isLegal(op.getOperandTypes());
+  });
+}

diff --git a/mlir/test/Dialect/Shape/bufferize.mlir b/mlir/test/Dialect/Shape/bufferize.mlir
index 7393de101466..cb65a5d42d4b 100644
--- a/mlir/test/Dialect/Shape/bufferize.mlir
+++ b/mlir/test/Dialect/Shape/bufferize.mlir
@@ -1,12 +1,20 @@
 // RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
 
 // -----
-// Check that shape.assuming returns a memref.
-//
-// CHECK-LABEL: @shape_assuming_returns_memref
-func @shape_assuming_returns_memref() {
+
+// CHECK-LABEL:   func @shape_assuming() {
+// CHECK:           %[[WTRUE:.*]] = shape.const_witness true
+// CHECK:           %[[MEMREF:.*]] = shape.assuming %[[WTRUE]] -> (memref<2xf16>) {
+// CHECK:             %[[TENSOR_VAL:.*]] = "test.source"() : () -> tensor<2xf16>
+// CHECK:             %[[YIELDED_MEMREF:.*]] = tensor_to_memref %[[TENSOR_VAL]] : memref<2xf16>
+// CHECK:             shape.assuming_yield %[[YIELDED_MEMREF]] : memref<2xf16>
+// CHECK:           }
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<2xf16>
+// CHECK:           "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
+// CHECK:           return
+// CHECK:         }
+func @shape_assuming() {
   %0 = shape.const_witness true
-  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
   %1 = shape.assuming %0 -> (tensor<2xf16>) {
     %2 = "test.source"() : () -> (tensor<2xf16>)
     shape.assuming_yield %2 : tensor<2xf16>


        


More information about the Mlir-commits mailing list