[Mlir-commits] [mlir] dacfb24 - [mlir] Support inlining into affine operations

Alex Zinenko llvmlistbot at llvm.org
Fri Dec 11 07:24:37 PST 2020


Author: Alex Zinenko
Date: 2020-12-11T16:24:27+01:00
New Revision: dacfb24b301d2f0422f2c7a23e2919e2f35cd932

URL: https://github.com/llvm/llvm-project/commit/dacfb24b301d2f0422f2c7a23e2919e2f35cd932
DIFF: https://github.com/llvm/llvm-project/commit/dacfb24b301d2f0422f2c7a23e2919e2f35cd932.diff

LOG: [mlir] Support inlining into affine operations

Introduce support for inlining into affine operations. This uses the generic
inline infrastructure and boils down to checking that, if applied, the inlining
doesn't violate the affine dimension/symbol value categorization. Given valid
IR, only the values that are valid dimensions/symbols thanks to being top-level
in their affine scope need special handling.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D92770

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/inlining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 005db18c54e5..d1d577799b39 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
@@ -25,6 +26,99 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "affine-analysis"
 
+/// A utility function to check if a value is defined at the top level of
+/// `region` or is an argument of `region`. A value of index type defined at the
+/// top level of an `AffineScope` region is always a valid symbol for all
+/// uses in that region.
+static bool isTopLevelValue(Value value, Region *region) {
+  if (auto arg = value.dyn_cast<BlockArgument>())
+    return arg.getParentRegion() == region;
+  return value.getDefiningOp()->getParentRegion() == region;
+}
+
+/// Checks if `value` known to be a legal affine dimension or symbol in `src`
+/// region remains legal if the operation that uses it is inlined into `dest`
+/// with the given value mapping. `legalityCheck` is either `isValidDim` or
+/// `isValidSymbol`, depending on the value being required to remain a valid
+/// dimension or symbol.
+static bool
+remainsLegalAfterInline(Value value, Region *src, Region *dest,
+                        const BlockAndValueMapping &mapping,
+                        function_ref<bool(Value, Region *)> legalityCheck) {
+  // If the value is a valid dimension for any other reason than being
+  // a top-level value, it will remain valid: constants get inlined
+  // with the function, transitive affine applies also get inlined and
+  // will be checked themselves, etc.
+  if (!isTopLevelValue(value, src))
+    return true;
+
+  // If it's a top-level value because it's a block operand, i.e. a
+  // function argument, check whether the value replacing it after
+  // inlining is a valid dimension in the new region.
+  if (value.isa<BlockArgument>())
+    return legalityCheck(mapping.lookup(value), dest);
+
+  // If it's a top-level value because it's defined in the region,
+  // it can only be inlined if the defining op is a constant or a
+  // `dim`, which can appear anywhere and be valid, since the defining
+  // op won't be top-level anymore after inlining.
+  Attribute operandCst;
+  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
+         value.getDefiningOp<DimOp>();
+}
+
+/// Checks if all values known to be legal affine dimensions or symbols in `src`
+/// remain so if their respective users are inlined into `dest`.
+static bool
+remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
+                        const BlockAndValueMapping &mapping,
+                        function_ref<bool(Value, Region *)> legalityCheck) {
+  return llvm::all_of(values, [&](Value v) {
+    return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
+  });
+}
+
+/// Checks if an affine read or write operation remains legal after inlining
+/// from `src` to `dest`.
+template <typename OpTy>
+static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
+                                    const BlockAndValueMapping &mapping) {
+  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
+                                AffineWriteOpInterface>::value,
+                "only ops with affine read/write interface are supported");
+
+  AffineMap map = op.getAffineMap();
+  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
+  ValueRange symbolOperands =
+      op.getMapOperands().take_back(map.getNumSymbols());
+  if (!remainsLegalAfterInline(
+          dimOperands, src, dest, mapping,
+          static_cast<bool (*)(Value, Region *)>(isValidDim)))
+    return false;
+  if (!remainsLegalAfterInline(
+          symbolOperands, src, dest, mapping,
+          static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
+    return false;
+  return true;
+}
+
+/// Checks if an affine apply operation remains legal after inlining from `src`
+/// to `dest`.
+template <>
+bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
+                             const BlockAndValueMapping &mapping) {
+  // If it's a valid dimension, we need to check that it remains so.
+  if (isValidDim(op.getResult(), src))
+    return remainsLegalAfterInline(
+        op.getMapOperands(), src, dest, mapping,
+        static_cast<bool (*)(Value, Region *)>(isValidDim));
+
+  // Otherwise it must be a valid symbol, check that it remains so.
+  return remainsLegalAfterInline(
+      op.getMapOperands(), src, dest, mapping,
+      static_cast<bool (*)(Value, Region *)>(isValidSymbol));
+}
+
 //===----------------------------------------------------------------------===//
 // AffineDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -41,22 +135,62 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
 
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
+  /// 'wouldBeCloned' is set if the region is cloned into its new location
+  /// rather than moved, indicating there may be other users.
   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                        BlockAndValueMapping &valueMapping) const final {
-    // Conservatively don't allow inlining into affine structures.
-    return false;
+    // We can inline into affine loops and conditionals if this doesn't break
+    // affine value categorization rules.
+    Operation *destOp = dest->getParentOp();
+    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
+      return false;
+
+    // Multi-block regions cannot be inlined into affine constructs, all of
+    // which require single-block regions.
+    if (!llvm::hasSingleElement(*src))
+      return false;
+
+    // Side-effecting operations that the affine dialect cannot understand
+    // should not be inlined.
+    Block &srcBlock = src->front();
+    for (Operation &op : srcBlock) {
+      // Ops with no side effects are fine.
+      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
+        if (iface.hasNoEffect())
+          continue;
+      }
+
+      // Assuming the inlined region is valid, we only need to check if the
+      // inlining would change it.
+      bool remainsValid =
+          llvm::TypeSwitch<Operation *, bool>(&op)
+              .Case<AffineApplyOp, AffineReadOpInterface,
+                    AffineWriteOpInterface>([&](auto op) {
+                return remainsLegalAfterInline(op, src, dest, valueMapping);
+              })
+              .Default([](Operation *) {
+                // Conservatively disallow inlining ops we cannot reason about.
+                return false;
+              });
+
+      if (!remainsValid)
+        return false;
+    }
+
+    return true;
   }
 
   /// Returns true if the given operation 'op', that is registered to this
   /// dialect, can be inlined into the given region, false otherwise.
   bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                        BlockAndValueMapping &valueMapping) const final {
-    // Always allow inlining affine operations into the top-level region of a
-    // function. There are some edge cases when inlining *into* affine
-    // structures, but that is handled in the other 'isLegalToInline' hook
-    // above.
-    // TODO: We should be able to inline into other regions than functions.
-    return isa<FuncOp>(region->getParentOp());
+    // Always allow inlining affine operations into a region that is marked as
+    // affine scope, or into affine loops and conditionals. There are some edge
+    // cases when inlining *into* affine structures, but that is handled in the
+    // other 'isLegalToInline' hook above.
+    Operation *parentOp = region->getParentOp();
+    return parentOp->hasTrait<OpTrait::AffineScope>() ||
+           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
   }
 
   /// Affine regions should be analyzed recursively.
@@ -101,16 +235,6 @@ bool mlir::isTopLevelValue(Value value) {
   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
 }
 
-/// A utility function to check if a value is defined at the top level of
-/// `region` or is an argument of `region`. A value of index type defined at the
-/// top level of a `AffineScope` region is always a valid symbol for all
-/// uses in that region.
-static bool isTopLevelValue(Value value, Region *region) {
-  if (auto arg = value.dyn_cast<BlockArgument>())
-    return arg.getParentRegion() == region;
-  return value.getDefiningOp()->getParentRegion() == region;
-}
-
 /// Returns the closest region enclosing `op` that is held by an operation with
 /// trait `AffineScope`; `nullptr` if there is no such region.
 //  TODO: getAffineScope should be publicly exposed for affine passes/utilities.

diff  --git a/mlir/test/Dialect/Affine/inlining.mlir b/mlir/test/Dialect/Affine/inlining.mlir
index e65ae5d0b73a..5879acdeaedb 100644
--- a/mlir/test/Dialect/Affine/inlining.mlir
+++ b/mlir/test/Dialect/Affine/inlining.mlir
@@ -54,16 +54,77 @@ func @not_inline_invalid_nest_op() {
 
 // -----
 
-// Test that calls are not inlined into affine structures.
+// Test that calls are inlined into affine structures.
 func @func_noop() {
   return
 }
 
-// CHECK-LABEL: func @not_inline_into_affine_ops
-func @not_inline_into_affine_ops() {
-  // CHECK: call @func_noop
+// CHECK-LABEL: func @inline_into_affine_ops
+func @inline_into_affine_ops() {
+  // CHECK-NOT: call @func_noop
   affine.for %i = 1 to 10 {
     call @func_noop() : () -> ()
   }
   return
 }
+
+// -----
+
+// Test that calls with dimension arguments are properly inlined.
+func @func_dim(%arg0: index, %arg1: memref<?xf32>) {
+  affine.load %arg1[%arg0] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @inline_dimension
+// CHECK: (%[[ARG0:.*]]: memref<?xf32>)
+func @inline_dimension(%arg0: memref<?xf32>) {
+  // CHECK: affine.for %[[IV:.*]] =
+  affine.for %i = 1 to 42 {
+    // CHECK-NOT: call @func_dim
+    // CHECK: affine.load %[[ARG0]][%[[IV]]]
+    call @func_dim(%i, %arg0) : (index, memref<?xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that calls with vector operations are also inlined.
+func @func_vector_dim(%arg0: index, %arg1: memref<32xf32>) {
+  affine.vector_load %arg1[%arg0] : memref<32xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @inline_dimension_vector
+// CHECK: (%[[ARG0:.*]]: memref<32xf32>)
+func @inline_dimension_vector(%arg0: memref<32xf32>) {
+  // CHECK: affine.for %[[IV:.*]] =
+  affine.for %i = 1 to 42 {
+    // CHECK-NOT: call @func_dim
+    // CHECK: affine.vector_load %[[ARG0]][%[[IV]]]
+    call @func_vector_dim(%i, %arg0) : (index, memref<32xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that calls that would result in violation of affine value
+// categorization (a top-level value stops being top-level) are not inlined.
+func private @get_index() -> index
+
+func @func_top_level(%arg0: memref<?xf32>) {
+  %0 = call @get_index() : () -> index
+  affine.load %arg0[%0] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @no_inline_not_top_level
+func @no_inline_not_top_level(%arg0: memref<?xf32>) {
+  affine.for %i = 1 to 42 {
+    // CHECK: call @func_top_level
+    call @func_top_level(%arg0) : (memref<?xf32>) -> ()
+  }
+  return
+}


        


More information about the Mlir-commits mailing list