[Mlir-commits] [mlir] b44defa - [mlir][linalg][bufferize] Generalize destination-passing style detection

Matthias Springer llvmlistbot at llvm.org
Wed Jan 19 01:21:45 PST 2022


Author: Matthias Springer
Date: 2022-01-19T18:21:29+09:00
New Revision: b44defa5a5964ea7a6ce0f1ada9b59510e59f6d7

URL: https://github.com/llvm/llvm-project/commit/b44defa5a5964ea7a6ce0f1ada9b59510e59f6d7
DIFF: https://github.com/llvm/llvm-project/commit/b44defa5a5964ea7a6ce0f1ada9b59510e59f6d7.diff

LOG: [mlir][linalg][bufferize] Generalize destination-passing style detection

If allow-return-memref is not set, raise an error if a new memory allocation is returned/yielded from a block. We do not check for new allocations directly, but for ops that yield/return values that are not equivalent to values that are defined outside of the current block.

Note: We still need to check that scf.for yield values and bbArgs are aliasing to ensure that getAliasingOpOperand/getAliasingOpResult is correct.

Differential Revision: https://reviews.llvm.org/D116687

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index a93caea59f043..f86550e359010 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -19,9 +19,11 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
-/// Equivalence analysis for scf.for. Raise an error if iter_args are not
-/// equivalent to their corresponding loop yield values.
-struct AssertDestinationPassingStyle : public PostAnalysisStep {
+/// Assert that yielded values of an scf.for op are aliasing their corresponding
+/// bbArgs. This is required because the i-th OpResult of an scf.for op is
+/// currently assumed to alias with the i-th iter_arg (in the absence of
+/// conflicts).
+struct AssertScfForAliasingProperties : public PostAnalysisStep {
   LogicalResult run(Operation *op, BufferizationState &state,
                     BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 99c28fe124f76..1b1467ec36a63 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -96,6 +96,7 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
 
   LINK_LIBS PUBLIC
   MLIRBufferizableOpInterface
+  MLIRControlFlowInterfaces
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index ffebdfc665061..7955f7b35b61b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -54,6 +54,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -559,6 +560,76 @@ annotateOpsWithBufferizationMarkers(Operation *op,
   });
 }
 
+/// Assert that IR is in destination-passing style. I.e., every value that is
+/// returned or yielded from a block is:
+/// * aliasing a bbArg of that block or a parent block, or
+/// * aliasing an OpResult of an op in a parent block.
+///
+/// Example:
+/// ```
+/// %0 = "some_op" : tensor<?xf32>
+/// %1 = scf.if %c -> (tensor<?xf32>) {
+///   scf.yield %0 : tensor<?xf32>
+/// } else {
+///   %t = linalg.init_tensor : tensor<?xf32>
+///   scf.yield %t : tensor<?xf32>
+/// }
+/// ```
+/// In the above example, the first scf.yield op satisfies destination-passing
+/// style because the yielded value %0 is defined in the parent block. The
+/// second scf.yield op does not satisfy destination-passing style because the
+/// yielded value %t is defined in the same block as the scf.yield op.
+// TODO: The current implementation checks for equivalent values instead of
+// aliasing values, which is stricter than needed. We can currently not check
+// for aliasing values because the analysis is a maybe-alias analysis and we
+// need a must-alias analysis here.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
+                    SmallVector<Operation *> &newOps) override {
+    LogicalResult status = success();
+    DominanceInfo domInfo(op);
+    op->walk([&](Operation *returnOp) {
+      if (!isRegionReturnLike(returnOp))
+        return WalkResult::advance();
+
+      for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
+        Value returnVal = returnValOperand.get();
+        // Skip non-tensor values.
+        if (!returnVal.getType().isa<TensorType>())
+          continue;
+
+        bool foundEquivValue = false;
+        aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
+          if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
+            Operation *definingOp = bbArg.getOwner()->getParentOp();
+            if (definingOp->isProperAncestor(returnOp))
+              foundEquivValue = true;
+            return;
+          }
+
+          Operation *definingOp = equivVal.getDefiningOp();
+          if (definingOp->getBlock()->findAncestorOpInBlock(
+                  *returnOp->getParentOp()))
+            // Skip ops that happen after `returnOp` and parent ops.
+            if (happensBefore(definingOp, returnOp, domInfo))
+              foundEquivValue = true;
+        });
+
+        if (!foundEquivValue)
+          status =
+              returnOp->emitError()
+              << "operand #" << returnValOperand.getOperandNumber()
+              << " of ReturnLike op does not satisfy destination passing style";
+      }
+
+      return WalkResult::advance();
+    });
+
+    return status;
+  }
+};
+
 /// Rewrite pattern that bufferizes bufferizable ops.
 struct BufferizationPattern
     : public OpInterfaceRewritePattern<BufferizableOpInterface> {
@@ -643,6 +714,13 @@ mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op,
     equivalenceAnalysis(newOps, aliasInfo, state);
   }
 
+  if (!options.allowReturnMemref) {
+    SmallVector<Operation *> newOps;
+    if (failed(
+            AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps)))
+      return failure();
+  }
+
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly)
     annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 2fd89de1e4aa0..6337cd023dd22 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -391,70 +391,37 @@ struct ForOpInterface
   }
 };
 
-// TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that
-// do not bufferize inplace. (Requires a few more changes for ConstantOp,
-// InitTensorOp, CallOp.)
-LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
-    AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
-                                       BufferizationAliasInfo &aliasInfo,
-                                       SmallVector<Operation *> &newOps) {
+LogicalResult
+mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
+    run(Operation *op, BufferizationState &state,
+        BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
-  op->walk([&](scf::YieldOp yieldOp) {
-    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
-      for (OpOperand &operand : yieldOp->getOpOperands()) {
-        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-        if (!tensorType)
-          continue;
-
-        OpOperand &forOperand = forOp.getOpOperandForResult(
-            forOp->getResult(operand.getOperandNumber()));
-        auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-        if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-          // TODO: this could get resolved with copies but it can also turn into
-          // swaps so we need to be careful about order of copies.
-          status =
-              yieldOp->emitError()
-              << "Yield operand #" << operand.getOperandNumber()
-              << " does not bufferize to an equivalent buffer to the matching"
-              << " enclosing scf::for operand";
-          return WalkResult::interrupt();
-        }
-      }
-    }
-
-    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
-      // IfOps are in destination passing style if all yielded tensors are
-      // a value or equivalent to a value that is defined outside of the IfOp.
-      for (OpOperand &operand : yieldOp->getOpOperands()) {
-        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-        if (!tensorType)
-          continue;
-
-        bool foundOutsideEquivalent = false;
-        aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
-          Operation *valueOp = value.getDefiningOp();
-          if (value.isa<BlockArgument>())
-            valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
-
-          bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
-          bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
-
-          if (!inThenBlock && !inElseBlock)
-            foundOutsideEquivalent = true;
-        });
 
-        if (!foundOutsideEquivalent) {
-          status = yieldOp->emitError()
-                   << "Yield operand #" << operand.getOperandNumber()
-                   << " does not bufferize to a buffer that is equivalent to a"
-                   << " buffer defined outside of the scf::if op";
-          return WalkResult::interrupt();
-        }
+  op->walk([&](scf::ForOp forOp) {
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to a buffer that is aliasing the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt();
       }
     }
-
     return WalkResult::advance();
   });
+
   return status;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f0f1beb53ab01..f4e9476727f07 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -95,8 +95,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
         linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   }
 
-  if (!allowReturnMemref)
-    options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+  // Only certain scf.for ops are supported by the analysis.
+  options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 6db938f72323a..609a0df7a7cb6 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // CHECK-LABEL: func @use_tensor_func_arg(
 //  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index f9a809ea15784..447d41bd0d4c2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -41,16 +41,34 @@ func @scf_if_not_equivalent(
   %r = scf.if %cond -> (tensor<?xf32>) {
     scf.yield %t1 : tensor<?xf32>
   } else {
-    // This buffer aliases, but is not equivalent.
+    // This buffer aliases, but it is not equivalent.
     %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
-    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
+    // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
     scf.yield %t2 : tensor<?xf32>
   }
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r : tensor<?xf32>
 }
 
 // -----
 
+func @scf_if_not_aliasing(
+    %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
+    %idx: index) -> f32 {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    // This buffer does not alias.
+    %t2 = linalg.init_tensor [%idx] : tensor<?xf32>
+    // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
+    scf.yield %t2 : tensor<?xf32>
+  }
+  %f = tensor.extract %r[%idx] : tensor<?xf32>
+  return %f : f32
+}
+
+// -----
+
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func @foo() {
@@ -80,7 +98,7 @@ func @scf_for(%A : tensor<?xf32>,
     // Throw a wrench in the system by swapping yielded values: this result in a
     // ping-pong of values at each iteration on which we currently want to fail.
 
-    // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is aliasing}}
     scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
   }
 
@@ -101,7 +119,7 @@ func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iter
   %c1 = arith.constant 1 : index
   %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor<?xf32>) {
     %r = call @foo(%A) : (tensor<?xf32>) -> (tensor<?xf32>)
-    // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is aliasing}}
     scf.yield %r : tensor<?xf32>
   }
   call @fun_with_side_effects(%res) : (tensor<?xf32>) -> ()
@@ -110,7 +128,6 @@ func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iter
 
 // -----
 
-// expected-error @+1 {{memref return type is unsupported}}
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   ->  tensor<4xf32>
 {
@@ -122,12 +139,12 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   //     argument aliasing).
   %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
 
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r0: tensor<4xf32>
 }
 
 // -----
 
-// expected-error @+1 {{memref return type is unsupported}}
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
   %r = scf.if %b -> (tensor<4xf32>) {
@@ -135,6 +152,7 @@ func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32
   } else {
     scf.yield %B : tensor<4xf32>
   }
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r: tensor<4xf32>
 }
 
@@ -142,29 +160,31 @@ func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32
 
 func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{op was not bufferized}}
+  // expected-error: @+1 {{op was not bufferized}}
   %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>)
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r: tensor<4xf32>
 }
 
 // -----
 
-// expected-error @+1 {{memref return type is unsupported}}
 func @mini_test_case1() -> tensor<10x20xf32> {
   %f0 = arith.constant 0.0 : f32
   %t = linalg.init_tensor [10, 20] : tensor<10x20xf32>
   %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32>
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r : tensor<10x20xf32>
 }
 
 // -----
 
-// expected-error @+1 {{memref return type is unsupported}}
 func @main() -> tensor<4xi32> {
   %r = scf.execute_region -> tensor<4xi32> {
     %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>
   }
+
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r: tensor<4xi32>
 }
 
@@ -203,12 +223,42 @@ func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
 
 func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
   %0 = linalg.init_tensor [5] : tensor<5xf32>
+  // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %0 : tensor<5xf32>
 }
 
+// Note: This function is not analyzed because there was an error in the
+// previous one.
 func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
-  // expected-error @+2 {{call to FuncOp that returns non-equivalent tensors not supported}}
-  // expected-error @+1 {{op was not bufferized}}
   call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
   return
 }
+
+// -----
+
+func @destination_passing_style_dominance_test_1(%cst : f32, %idx : index,
+                                                 %idx2 : index) -> f32 {
+  %0 = scf.execute_region -> tensor<?xf32> {
+    %1 = linalg.init_tensor [%idx] : tensor<?xf32>
+    // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
+    scf.yield %1 : tensor<?xf32>
+  }
+  %2 = tensor.insert %cst into %0[%idx] : tensor<?xf32>
+  %r = tensor.extract %2[%idx2] : tensor<?xf32>
+  return %r : f32
+}
+
+// -----
+
+func @destination_passing_style_dominance_test_2(%cst : f32, %idx : index,
+                                                 %idx2 : index) -> f32 {
+  %1 = linalg.init_tensor [%idx] : tensor<?xf32>
+
+  %0 = scf.execute_region -> tensor<?xf32> {
+    // This YieldOp is in destination-passing style, thus no error.
+    scf.yield %1 : tensor<?xf32>
+  }
+  %2 = tensor.insert %cst into %0[%idx] : tensor<?xf32>
+  %r = tensor.extract %2[%idx2] : tensor<?xf32>
+  return %r : f32
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index 909c41cef97f2..aadbeaff86ff8 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
 // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=scf allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index e6c521ddfdc24..a739fc4645ed0 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #map>) -> vector<4xf32> {
 func @transfer_read(

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 4006c38400afa..d0852e9c65189 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -104,7 +104,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
   auto options = std::make_unique<BufferizationOptions>();
 
   if (!allowReturnMemref)
-    options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+    options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
 
   options->allowReturnMemref = allowReturnMemref;
   options->allowUnknownOps = allowUnknownOps;

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 09b524baca2b3..b1cf056683b6b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6898,6 +6898,7 @@ cc_library(
     deps = [
         ":BufferizableOpInterface",
         ":BufferizationDialect",
+        ":ControlFlowInterfaces",
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",


        


More information about the Mlir-commits mailing list