[Mlir-commits] [mlir] 3b021fb - [MLIR][LinAlg] Detensorize internal function control flow.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 2 02:49:23 PST 2021


Author: KareemErgawy-TomTom
Date: 2021-03-02T11:46:20+01:00
New Revision: 3b021fbdc04b627b8bc1f53835dc2f6aefddd7c2

URL: https://github.com/llvm/llvm-project/commit/3b021fbdc04b627b8bc1f53835dc2f6aefddd7c2
DIFF: https://github.com/llvm/llvm-project/commit/3b021fbdc04b627b8bc1f53835dc2f6aefddd7c2.diff

LOG: [MLIR][LinAlg] Detensorize internal function control flow.

This patch continues the detensorizing implementation by detensoring
internal control flow in functions.

In order to detensorize functions, the arguments of all non-entry blocks
are detensored, and branches between such blocks are updated to reflect
the detensored types as well. The function entry block (i.e. the
function's signature) is left intact.

This continues work towards handling github/google/iree#1159.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D97148

Added: 
    mlir/test/Dialect/Linalg/detensorized_while.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index e51d08d3770d..ec54e93c988d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -151,12 +151,16 @@ def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
     linalg-on-tensor op is checked to see whether *all* its operands can be
     detensored. If so, those operands are converted to their primitive
     counterparts and the linalg op is replaced by an equivalent op that takes
-    those new primitive values as operands. Therefore, the detensoring process
-    can be divided into 2 main logical phases:
+    those new primitive values as operands. Therefore, detensoring an op can be
+    divided into 2 main logical phases:
 
     1. Detect/match an op that can be detensored.
     2. Detensor the operands of the op and replace it with a primitive
        equivalent.
+
+    In addition to detensoring individual ops, this pass detensors internal
+    control flow inside a function. All blocks except for the entry block are
+    detensored by converting their arguments whenever possible.
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index 55da3af88758..1a0308d96259 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -16,7 +16,9 @@
 namespace mlir {
 
 // Forward declarations.
+class ConversionTarget;
 class MLIRContext;
+class Operation;
 class OwningRewritePatternList;
 class TypeConverter;
 
@@ -26,13 +28,38 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
                                          MLIRContext *ctx,
                                          TypeConverter &converter);
 
-/// Add a pattern to the given pattern list to rewrite branch operations and
-/// `return` to use operands that have been legalized by the conversion
-/// framework. This can only be done if the branch operation implements the
-/// BranchOpInterface. Only needed for partial conversions.
-void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+/// Add a pattern to the given pattern list to rewrite branch operations to use
+/// operands that have been legalized by the conversion framework. This can only
+/// be done if the branch operation implements the BranchOpInterface. Only
+/// needed for partial conversions.
+void populateBranchOpInterfaceTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter);
+
+/// Return true if op is a BranchOpInterface op whose successor operands are all
+/// legal according to converter.
+bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
+                                                      TypeConverter &converter);
+
+/// Add a pattern to the given pattern list to rewrite `return` ops to use
+/// operands that have been legalized by the conversion framework.
+void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx,
+                                           TypeConverter &converter);
+
+/// If op is a `return` and returnOpAlwaysLegal is false, return whether op is
+/// legal according to converter. Otherwise, return true for ReturnLike ops
+/// (including `return`) and false for all other ops.
+bool isLegalForReturnOpTypeConversionPattern(Operation *op,
+                                             TypeConverter &converter,
+                                             bool returnOpAlwaysLegal = false);
+
+/// Return true if op cannot be a terminator or is not the last op in its block,
+/// i.e. it is not a terminator whose branch/return operands need legalizing.
+/// TODO Try to get rid of this function and invert the meaning of
+/// `isLegalForBranchOpInterfaceTypeConversionPattern` and
+/// `isLegalForReturnOpTypeConversionPattern`.
+bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 0c1161b9945c..fa1752d96268 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -474,6 +474,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
+  /// Convert the types of block arguments within the given region except for
+  /// the entry block. This replaces each non-entry block with a new block
+  /// containing the updated signature.
+  LogicalResult convertNonEntryRegionTypes(Region *region,
+                                           TypeConverter &converter);
+
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 2e2e3b94a34a..2d34468dae72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -21,6 +21,20 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
+                                           ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  // A detensored value is converted back by creating a new tensor from its
+  // element(s).
+  auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+      loc, inputs[0].getType(), inputs[0]);
+
+  // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
+  // a tensor<dtype> instead.
+  return builder.create<linalg::TensorReshapeOp>(
+      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+}
+
 namespace {
 /// Defines the criteria a TensorType must follow in order to be considered
 /// "detensorable".
@@ -64,6 +78,29 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
   }
 };
 
+/// A conversion pattern for detensoring internal (non-entry) blocks within a
+/// function.
+struct FunctionNonEntryBlockConversion : public ConversionPattern {
+  FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
+                                  MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.startRootUpdate(op);
+
+    if (failed(rewriter.convertNonEntryRegionTypes(
+            &mlir::impl::getFunctionBody(op), *typeConverter))) {
+      rewriter.cancelRootUpdate(op);
+      return failure();
+    }
+
+    rewriter.finalizeRootUpdate(op);
+    return success();
+  }
+};
+
 class DetensorizeTypeConverter : public TypeConverter {
 public:
   DetensorizeTypeConverter() {
@@ -84,18 +121,8 @@ class DetensorizeTypeConverter : public TypeConverter {
       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
     });
 
-    // A detensored value is converted back by creating a new tensor from its
-    // element(s).
-    addSourceMaterialization([](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) -> Value {
-      auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
-          loc, inputs[0].getType(), inputs[0]);
-
-      // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
-      // a tensor<dtype> instead.
-      return builder.create<linalg::TensorReshapeOp>(
-          loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
-    });
+    addSourceMaterialization(sourceMaterializationCallback);
+    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
@@ -139,22 +166,43 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
-    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
-    target.addLegalDialect<linalg::LinalgDialect>();
     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
-      // If any of the operands or results cannot be detensored, the op is
-      // considered legal and won't be detensored.
-      return llvm::any_of(
-          op.getShapedOperandTypes(), [](ShapedType shapedType) {
-            assert(shapedType.isa<TensorType>());
-            return !canBeDetensored(shapedType.cast<TensorType>());
-          });
+      // If any of the operands or results cannot be detensored (i.e. any of
+      // them is legal according to the DetensorizeTypeConverter), the op is
+      // considered legal and won't be detensored.
+      return llvm::any_of(op.getShapedOperandTypes(),
+                          [&](ShapedType shapedType) {
+                            return typeConverter.isLegal(shapedType);
+                          });
     });
 
-    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      // A function is legal if all of its non-entry blocks are legal. We don't
+      // legalize the entry block (i.e. the function's signature) since
+      // detensoring can't happen along external calling convention boundaries,
+      // which we conservatively approximate as all function signatures.
+      return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
+        return typeConverter.isLegal(block.getArgumentTypes());
+      });
+    });
+
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+             isLegalForBranchOpInterfaceTypeConversionPattern(op,
+                                                              typeConverter) ||
+             isLegalForReturnOpTypeConversionPattern(
+                 op, typeConverter, /*returnOpAlwaysLegal*/ true);
+    });
 
-    if (failed(
-            applyPartialConversion(getFunction(), target, std::move(patterns))))
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+    patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+                                                     context, typeConverter);
+    // Since non-entry block arguments get detensorized, we also need to update
+    // the control flow inside the function to reflect the correct types.
+    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
+                                                   typeConverter);
+
+    if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
 
     OwningRewritePatternList canonPatterns;
@@ -162,8 +210,6 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
     if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(canonPatterns))))
       signalPassFailure();
-
-    // TODO Properly handle control flow within function boundaries.
   }
 };
 } // namespace

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
index c63150f3ab87..4b5a2d632670 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -40,39 +40,17 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
     target.addDynamicallyLegalOp<CallOp>(
         [&](CallOp op) { return typeConverter.isLegal(op); });
 
-    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
-                                                              typeConverter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
+                                                   typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, context, typeConverter);
     target.addLegalOp<ModuleOp, ModuleTerminatorOp, TensorLoadOp,
                       TensorToMemrefOp>();
-    target.addDynamicallyLegalOp<ReturnOp>(
-        [&](ReturnOp op) { return typeConverter.isLegal(op); });
-    // Mark terminators as legal if they have the ReturnLike trait or
-    // implement the BranchOpInterface and have valid types. If they do not
-    // implement the trait or interface, mark them as illegal no matter what.
+
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      // If it is not a terminator, ignore it.
-      if (!op->mightHaveTrait<OpTrait::IsTerminator>())
-        return true;
-      // If it is not the last operation in the block, also ignore it. We do
-      // this to handle unknown operations, as well.
-      Block *block = op->getBlock();
-      if (!block || &block->back() != op)
-        return true;
-      // ReturnLike operations have to be legalized with their parent. For
-      // return this is handled, for other ops they remain as is.
-      if (op->hasTrait<OpTrait::ReturnLike>())
-        return true;
-      // All successor operands of branch like operations must be rewritten.
-      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-        for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
-          auto successorOperands = branchOp.getSuccessorOperands(p);
-          if (successorOperands.hasValue() &&
-              !typeConverter.isLegal(successorOperands.getValue().getTypes()))
-            return false;
-        }
-        return true;
-      }
-      return false;
+      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+             isLegalForBranchOpInterfaceTypeConversionPattern(op,
+                                                              typeConverter) ||
+             isLegalForReturnOpTypeConversionPattern(op, typeConverter);
     });
 
     if (failed(applyFullConversion(module, target, std::move(patterns))))

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index 07d7c59e192b..4ba2069817a3 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -102,9 +102,61 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
 };
 } // end anonymous namespace
 
-void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+void mlir::populateBranchOpInterfaceTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &typeConverter) {
-  patterns.insert<BranchOpInterfaceTypeConversion, ReturnOpTypeConversion>(
-      typeConverter, ctx);
+  patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+}
+
+bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
+    Operation *op, TypeConverter &converter) {
+  // All successor operands of branch like operations must be rewritten.
+  if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+    for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
+      auto successorOperands = branchOp.getSuccessorOperands(p);
+      if (successorOperands.hasValue() &&
+          !converter.isLegal(successorOperands.getValue().getTypes()))
+        return false;
+    }
+    return true;
+  }
+
+  return false;
+}
+
+void mlir::populateReturnOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &typeConverter) {
+  patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
+}
+
+bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
+                                                   TypeConverter &converter,
+                                                   bool returnOpAlwaysLegal) {
+  // If this is a `return` and the user pass wants to convert/transform across
+  // function boundaries, then `converter` is invoked to check whether the
+  // `return` op is legal.
+  if (dyn_cast<ReturnOp>(op) && !returnOpAlwaysLegal)
+    return converter.isLegal(op);
+
+  // ReturnLike operations have to be legalized with their parent. For
+  // return this is handled, for other ops they remain as is.
+  if (op->hasTrait<OpTrait::ReturnLike>())
+    return true;
+
+  return false;
+}
+
+bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
+  // If it is not a terminator, ignore it.
+  if (!op->mightHaveTrait<OpTrait::IsTerminator>())
+    return true;
+
+  // If it is not the last operation in the block, also ignore it. We do
+  // this to handle unknown operations, as well.
+  Block *block = op->getBlock();
+  if (!block || &block->back() != op)
+    return true;
+
+  return false;
 }

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 002843c27c6f..ae5b566f32d1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -749,6 +749,10 @@ struct ConversionPatternRewriterImpl {
   convertRegionTypes(Region *region, TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
+  /// Convert the types of non-entry block arguments within the given region.
+  LogicalResult convertNonEntryRegionTypes(Region *region,
+                                           TypeConverter &converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1150,13 +1154,25 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
   if (region->empty())
     return nullptr;
 
-  // Convert the arguments of each block within the region.
+  if (failed(convertNonEntryRegionTypes(region, converter)))
+    return failure();
+
   FailureOr<Block *> newEntry =
       convertBlockSignature(&region->front(), converter, entryConversion);
+  return newEntry;
+}
+
+LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
+    Region *region, TypeConverter &converter) {
+  argConverter.setConverter(region, &converter);
+  if (region->empty())
+    return success();
+
+  // Convert the arguments of each block within the region.
   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
     if (failed(convertBlockSignature(&block, converter)))
       return failure();
-  return newEntry;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1323,6 +1339,11 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
+LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
+    Region *region, TypeConverter &converter) {
+  return impl->convertNonEntryRegionTypes(region, converter);
+}
+
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
   LLVM_DEBUG({

diff  --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
new file mode 100644
index 000000000000..a227e753006c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_while.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<i32>)
+
+^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic #attrs
+    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%1 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+
+^bb2(%4: tensor<i32>):  // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic #attrs
+    ins(%4, %4 : tensor<i32>, tensor<i32>)
+    outs(%5 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %8 = addi %arg0, %arg1 : i32
+      linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+
+^bb3(%7: tensor<i32>):  // pred: ^bb1
+  return %7 : tensor<i32>
+}
+
+// CHECK-LABEL: func @main
+// CHECK-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// CHECK:         tensor.extract {{.*}}
+// CHECK:         br ^[[bb1:.*]](%{{.*}} : i32)
+// CHECK:       ^[[bb1]](%{{.*}}: i32)
+// CHECK:         cmpi slt, {{.*}}
+// CHECK:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK:       ^[[bb2]](%{{.*}}: i32)
+// CHECK:         addi {{.*}}
+// CHECK:         br ^[[bb1]](%{{.*}} : i32)
+// CHECK:       ^[[bb3]](%{{.*}}: i32)
+// CHECK:         tensor.from_elements {{.*}}
+// CHECK:         linalg.tensor_reshape {{.*}}
+// CHECK:         return %{{.*}} : tensor<i32>


        


More information about the Mlir-commits mailing list