[Mlir-commits] [mlir] c89c31a - [mlir][bufferization] Fix bufferization of repetitive regions

Matthias Springer llvmlistbot at llvm.org
Mon Feb 6 07:23:23 PST 2023


Author: Matthias Springer
Date: 2023-02-06T16:23:08+01:00
New Revision: c89c31a2306e07662f1e711a4bc84de1060e0def

URL: https://github.com/llvm/llvm-project/commit/c89c31a2306e07662f1e711a4bc84de1060e0def
DIFF: https://github.com/llvm/llvm-project/commit/c89c31a2306e07662f1e711a4bc84de1060e0def.diff

LOG: [mlir][bufferization] Fix bufferization of repetitive regions

The previous strategy was too complex and faulty. Op dominance cannot be used to rule out RaW conflicts due to op ordering if the reading op and the conflicting writing op are in a sub repetitive region of the closest enclosing repetitive region of the definition of the read value.

Differential Revision: https://reviews.llvm.org/D143087

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index a93a5d9a2cfed..417e3ff3871a6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -575,6 +575,11 @@ Region *getEnclosingRepetitiveRegion(Value value,
 Region *getEnclosingRepetitiveRegion(Block *block,
                                      const BufferizationOptions &options);
 
+/// Assuming that the given region is repetitive, find the next enclosing
+/// repetitive region.
+Region *getNextEnclosingRepetitiveRegion(Region *region,
+                                         const BufferizationOptions &options);
+
 namespace detail {
 /// This is the default implementation of
 /// BufferizableOpInterface::getAliasingOpOperands. Should not be called from

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 12fc89740d15c..c9d57413ed68b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -41,6 +41,15 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
 using namespace mlir;
 using namespace bufferization;
 
+static bool isRepetitiveRegion(Region *region,
+                               const BufferizationOptions &options) {
+  Operation *op = region->getParentOp();
+  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+    if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+      return true;
+  return false;
+}
+
 Region *bufferization::getEnclosingRepetitiveRegion(
     Operation *op, const BufferizationOptions &options) {
   if (!op->getBlock())
@@ -52,11 +61,9 @@ Region *bufferization::getEnclosingRepetitiveRegion(
     Value value, const BufferizationOptions &options) {
   Region *region = value.getParentRegion();
   while (region) {
-    Operation *op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-    region = op->getParentRegion();
+    if (isRepetitiveRegion(region, options))
+      return region;
+    region = region->getParentRegion();
   }
   return nullptr;
 }
@@ -67,13 +74,22 @@ Region *bufferization::getEnclosingRepetitiveRegion(
   Operation *op = nullptr;
   do {
     op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
+    if (isRepetitiveRegion(region, options))
+      return region;
   } while ((region = op->getParentRegion()));
   return nullptr;
 }
 
+Region *bufferization::getNextEnclosingRepetitiveRegion(
+    Region *region, const BufferizationOptions &options) {
+  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
+  while ((region = region->getParentRegion())) {
+    if (isRepetitiveRegion(region, options))
+      break;
+  }
+  return region;
+}
+
 Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 7aaf79b8d72f9..3cd84c529be61 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -346,25 +346,27 @@ static bool happensBefore(Operation *a, Operation *b,
   return false;
 }
 
-/// Return `true` if op dominance can be used to rule out read-after-write
-/// conflicts wrt. the given reads and writes.
+/// Return `true` if op dominance can be used to rule out a read-after-write
+/// conflict based on the ordering of ops.
 ///
-/// Op dominance can often be used to rule out potential conflicts such as
-/// "read" happens before "write". E.g., the following IR is not a RaW conflict
-/// because the the read happens *before* the write.
+/// Generalized op dominance can often be used to rule out potential conflicts
+/// due to "read happens before write". E.g., the following IR is not a RaW
+/// conflict because the read happens *before* the write.
 ///
-/// %0 = ... : tensor<?xf32>
-/// "reading_op"(%0) : tensor<?xf32>
-/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
+/// Example 1:
+/// %0 = ... : tensor<?xf32>                                // DEF
+/// "reading_op"(%0) : tensor<?xf32>                        // READ
+/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
 ///
 /// This is no longer true inside loops (or repetitive regions). In such cases,
 /// there may not be a meaningful `happensBefore` relationship because ops
 /// could be executed multiple times. E.g.:
 ///
-/// %0 = ... : tensor<?xf32>
+/// Example 2:
+/// %0 = ... : tensor<?xf32>                                  // DEF
 /// scf.for ... {
-///   "reading_op"(%0) : tensor<?xf32>
-///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
+///   "reading_op"(%0) : tensor<?xf32>                        // READ
+///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
 ///   ...
 /// }
 ///
@@ -374,92 +376,78 @@ static bool happensBefore(Operation *a, Operation *b,
 /// execution of writing_op. This is problematic because the tensor %0 they
 /// operate on (i.e., the "definition") is defined outside of the loop.
 ///
-/// Counter example:
+/// At a high level, there is a potential RaW in a program if there exists a
+/// possible program execution such that there is a sequence of DEF, followed
+/// by WRITE, followed by READ. Each additional DEF resets the sequence.
 ///
+/// E.g.:
+/// No conflict:        DEF, WRITE, DEF, READ
+/// Potential conflict: DEF, READ, WRITE, READ, WRITE
+///
+/// Example 1 has no conflict:          DEF, READ, WRITE
+/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
+///
+/// Example 3:
 /// scf.for ... {
 ///   %0 = ... : tensor<?xf32>
 ///   "reading_op"(%0) : tensor<?xf32>
 ///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
 ///   ...
 /// }
+/// This has no conflict: (DEF, READ, WRITE)*
 ///
-/// In this example, the definition %0 is in the same repetitive region as
-/// "writing_op", so op dominance can be used to compute the `happensBefore`
-/// relationship.
-///
-/// Whether op dominance can be used or not is decided as follows: Find the
-/// closest enclosing repetitive region of all buffer writes wrt. the given
-/// tensor reads and writes. (The given sets of reads and writes contain the
-/// entire alias set.) In case of a read, we look at the op that defines the
-/// read value. In case of a write, we look at the op that is writing. If all of
-/// those ops are in the same closest enclosing repetitive region (nullptr in
-/// case of "no repetitive region" found at all), then op dominance can be used.
-/// Otherwise, it cannot be used.
-///
-/// Example: The common enclosing repetitive region is the scf.for loop.
-///          Op dominance can be used.
+/// Example 4:
+/// %0 = ... : tensor<?xf32>
 /// scf.for ... {
-///   %0 = tensor.generate
-///   "read"(%0)
+///   scf.for ... { "reading_op"(%0) }
+///   %1 = "writing_op"(%0)
 /// }
+/// This has a potential conflict: DEF, ((READ)*, WRITE)*
 ///
-/// Example: The common enclosing repetitive region is nullptr: There is no
-///          repetitive region around the tensor.generate. Op dominance can be
-///          used.
-/// %0 = tensor.generate
-/// scf.for ... { "read"(%0) }
+/// Example 5:
+/// %0 = ... : tensor<?xf32>
+/// scf.for ... { %1 = "writing_op"(%0) }
+/// scf.for ... { "reading_op"(%0) }
+/// This has a potential conflict: DEF, WRITE*, READ*
 ///
-/// Example: The common enclosing repetitive regions of tensor.generate and
-///          "write" differ. Op dominance cannot be used.
-/// %0 = tensor.generate
-/// scf.for ... {
-///   "read"(%0)
-///   "write"(%0)
-/// }
+/// The following rules are used to rule out RaW conflicts via ordering of ops:
 ///
-/// Example: The common enclosing repetitive regions of tensor.generate and
-///          "write" differ, but there is no read of %0, so op dominance can be
-///          used.
-/// %0 = tensor.generate
-/// scf.for ... {
-///   "write"(%0)
-/// }
+/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
+///    a repetitive region that encloses both READ and WRITE, we cannot rule
+///    out RaW conflicts due to the ordering of ops.
+/// 2. Otherwise: There are no loops that interfere with our analysis; for
+///    analysis purposes, we can assume that there are no loops/repetitive
+///    regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
+///    or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
 ///
-/// Note: iter_args of loops are not aliases of their respective block
-/// arguments, so op domanice can be used when analyzing ops that operate
-/// on them.
-bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
-                       const DenseSet<OpOperand *> &usesWrite,
+bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
+                       const SetVector<Value> &definitions,
                        const AnalysisState &state) {
   const BufferizationOptions &options = state.getOptions();
-  std::optional<Region *> commonEnclosingRegion;
+  for (Value def : definitions) {
+    Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options);
+    Region *rDef = getEnclosingRepetitiveRegion(def, options);
 
-  // In case of a write, take the region in which the write takes place.
-  for (OpOperand *uWrite : usesWrite) {
-    Region *r = getEnclosingRepetitiveRegion(uWrite->getOwner(), options);
-    if (!commonEnclosingRegion.has_value()) {
-      commonEnclosingRegion = r;
+    // READ and DEF are in the same repetitive region. `happensBefore` can be
+    // used to rule out RaW conflicts due to op ordering.
+    if (rRead == rDef)
       continue;
-    }
-    if (*commonEnclosingRegion != r)
-      return false;
-  }
 
-  // In case of a read, take the region which the read value is defined.
-  for (OpOperand *uRead : usesRead) {
-    // Optimization: Skip reads of values that have no defined contents.
-    if (!state.bufferizesToMemoryWrite(uRead->get()))
-      continue;
-    Region *r = getEnclosingRepetitiveRegion(uRead->get(), options);
-    if (!commonEnclosingRegion.has_value()) {
-      commonEnclosingRegion = r;
-      continue;
+    // Find the enclosing repetitive region of READ that is closest to DEF but
+    // not the repetitive region of DEF itself.
+    while (true) {
+      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
+      if (nextRegion == rDef)
+        break;
+      assert(nextRegion && "expected to find another repetitive region");
+      rRead = nextRegion;
     }
-    if (*commonEnclosingRegion != r)
+
+    // We cannot use op dominance if WRITE is inside the same repetitive region.
+    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
       return false;
   }
-
-  return commonEnclosingRegion.has_value();
+  return true;
 }
 
 /// Annotate IR with details about the detected RaW conflict.
@@ -507,10 +495,6 @@ static bool hasReadAfterWriteInterference(
     AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
   const BufferizationOptions &options = state.getOptions();
 
-  // Check if op dominance can be used to rule out read-after-write conflicts.
-  bool useDominance = canUseOpDominance(usesRead, usesWrite, state);
-  LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
-
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
     LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
@@ -542,6 +526,12 @@ static bool hasReadAfterWriteInterference(
                               << uConflictingWrite->getOperandNumber() << " of "
                               << *uConflictingWrite->getOwner() << "\n");
 
+      // Check if op dominance can be used to rule out read-after-write
+      // conflicts.
+      bool useDominance =
+          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
+      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
+
       // Throughout this loop, check for multiple requirements that have to be
       // met for uConflictingWrite to be an actual conflict.
       Operation *conflictingWritingOp = uConflictingWrite->getOwner();

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
index d4a8febf4722e..b764b41877250 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -697,3 +697,104 @@ func.func @no_raw_conflict_after_repetitive_use(%arg0: tensor<4xf32>,
 
   return %2, %7 : tensor<4xf32>, tensor<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @read_of_bbarg_in_repetitive_region(
+func.func @read_of_bbarg_in_repetitive_region(
+    %t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
+  // CHECK: scf.for
+  scf.for %iv = %a to %b step %c {
+    // Must bufferize out-of-place because definition of read is in a different
+    // repetitive region.
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["false"]}
+    %2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
+    %3 = tensor.extract %2[%a] : tensor<4xf32>
+    vector.print %3 : f32
+    // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+    %4 = tensor.insert %cst into %2[%a] : tensor<4xf32>
+    %5 = tensor.extract %4[%a] : tensor<4xf32>
+    vector.print %5 : f32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_definition_in_same_repetitive_region_as_write(
+func.func @read_definition_in_same_repetitive_region_as_write(
+    %t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
+  // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
+  // CHECK: scf.for
+  scf.for %iv = %a to %b step %c {
+    // Can bufferize in-place.
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
+    %2 = tensor.extract_slice %1[0][4][1] : tensor<10xf32> to tensor<4xf32>
+    %3 = tensor.extract %2[%a] : tensor<4xf32>
+    vector.print %3 : f32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_definition_in_same_repetitive_region_as_conflicting_write(
+func.func @read_definition_in_same_repetitive_region_as_conflicting_write(
+    %t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
+  // Cannot bufferize in-place according to normal op dominance rules.
+  // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "false", "none"]}
+  %1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
+  // CHECK: scf.for
+  scf.for %iv = %a to %b step %c {
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
+    %2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
+    %3 = tensor.extract %2[%a] : tensor<4xf32>
+    vector.print %3 : f32
+  }
+  return
+}
+
+// -----
+
+// CHECK: func @write_value_in_repetitive_region(
+func.func @write_value_in_repetitive_region(
+    %t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
+  %0 = tensor.extract %t[%a] : tensor<10xf32>
+  vector.print %0 : f32
+
+  scf.for %iv = %a to %b step %c {
+    // No further read of %0, so this can bufferize in-place.
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
+    %2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
+    // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
+    %filled = linalg.fill ins(%cst : f32) outs(%2 : tensor<4xf32>) -> tensor<4xf32>
+    %3 = tensor.extract %filled[%a] : tensor<4xf32>
+    vector.print %3 : f32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @nesting_op_repetitive_regions(
+func.func @nesting_op_repetitive_regions(
+    %t: tensor<10xf32>, %a: index, %b: index, %c: index, %cst: f32) {
+  // Cannot bufferize in-place according to normal op dominance rules.
+  // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "false", "none"]}
+  %1 = tensor.insert %cst into %t[%a] : tensor<10xf32>
+  // CHECK: scf.for
+  scf.for %iv1 = %a to %b step %c {
+    // CHECK: scf.for
+    scf.for %iv2 = %a to %b step %c {
+      // CHECK: scf.for
+      scf.for %iv3 = %a to %b step %c {
+        // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
+        %2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
+        %3 = tensor.extract %2[%a] : tensor<4xf32>
+        vector.print %3 : f32
+      }
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list