[Mlir-commits] [mlir] 9c55e71 - [mlir][linalg][bufferize] Bufferize using PostOrder traversal

Matthias Springer llvmlistbot at llvm.org
Thu Oct 21 01:22:04 PDT 2021


Author: Matthias Springer
Date: 2021-10-21T17:21:52+09:00
New Revision: 9c55e718f537577f2aac9e52b2dce9e01aadd1d7

URL: https://github.com/llvm/llvm-project/commit/9c55e718f537577f2aac9e52b2dce9e01aadd1d7
DIFF: https://github.com/llvm/llvm-project/commit/9c55e718f537577f2aac9e52b2dce9e01aadd1d7.diff

LOG: [mlir][linalg][bufferize] Bufferize using PostOrder traversal

This is required for bufferization of scf::IfOp, which is added in a subsequent commit.

Some ops (scf::ForOp, TiledLoopOp) still require PreOrder traversal, so that their block arguments (bbArgs) are mapped before the loop body is bufferized.

Differential Revision: https://reviews.llvm.org/D111924

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 3d7919b6e7125..e0e4105870b4b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1625,6 +1625,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
   Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
                                           resultTypes, newOperands);
   newCallOp->setAttrs(callOp->getAttrs());
+  callOp->erase();
   return success();
 }
 
@@ -2316,34 +2317,44 @@ static LogicalResult bufferizeFuncOpInternals(
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-  /// Start by bufferizing `funcOp` arguments.
+
+  // Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
 
-  // Walk in PreOrder to ensure ops with regions are handled before their body.
-  // Since walk has to be PreOrder, we need to erase ops that require it
-  // separately: this is the case for CallOp
-  SmallVector<Operation *> toErase;
-  if (funcOp
-          .walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
-            if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                                   &globalCreator)))
-              return failure();
-            // Register post-walk erasure, if necessary.
-            if (isa<CallOpInterface>(op))
-              if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-                  llvm::any_of(op->getResultTypes(), isaTensor))
-                toErase.push_back(op);
-            return success();
-          })
-          .wasInterrupted())
+  // Bufferize the function body. `bufferizedOps` keeps track of ops that were
+  // already bufferized with pre-order traversal.
+  DenseSet<Operation *> bufferizedOps;
+  auto walkFunc = [&](Operation *op) -> WalkResult {
+    // Collect ops that need to be bufferized before `op`.
+    SmallVector<Operation *> preorderBufferize;
+    Operation *parentOp = op->getParentOp();
+    // scf::ForOp and TiledLoopOp must be bufferized before their blocks
+    // ("pre-order") because BBargs must be mapped when bufferizing children.
+    while (isa_and_nonnull<scf::ForOp, TiledLoopOp>(parentOp)) {
+      if (bufferizedOps.contains(parentOp))
+        break;
+      bufferizedOps.insert(parentOp);
+      preorderBufferize.push_back(parentOp);
+      parentOp = parentOp->getParentOp();
+    }
+
+    for (Operation *op : llvm::reverse(preorderBufferize))
+      if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                             &globalCreator)))
+        return failure();
+
+    if (!bufferizedOps.contains(op) &&
+        failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                           &globalCreator)))
+      return failure();
+    return success();
+  };
+  if (funcOp.walk(walkFunc).wasInterrupted())
     return failure();
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
-  for (Operation *op : toErase)
-    op->erase();
-
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 0b77335ebac46..88cfbb4d68b77 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -113,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
 
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{unsupported op with tensors}}
   %r = scf.if %b -> (tensor<4xf32>) {
+    // expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
     scf.yield %A : tensor<4xf32>
   } else {
     scf.yield %B : tensor<4xf32>
@@ -144,7 +144,7 @@ func @mini_test_case1() -> tensor<10x20xf32> {
 // -----
 
 func @main() -> tensor<4xi32> {
-  // expected-error @+1 {{unsupported op with tensors}}
+  // expected-error @+1 {{expected result-less scf.execute_region containing op}}
   %r = scf.execute_region -> tensor<4xi32> {
     %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 07966f57255ea..3df671251b196 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -305,6 +305,28 @@ func @scf_for_yield_only(%A : tensor<?xf32>,
 
 // -----
 
+// Ensure that the function bufferizes without error. This tests pre-order
+// traversal of scf.for loops during bufferization. No need to check the IR,
+// just want to make sure that it does not crash.
+
+// CHECK-LABEL: func @nested_scf_for
+func @nested_scf_for(%A : tensor<?xf32> {linalg.inplaceable = true},
+                     %v : vector<5xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %r1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%B = %A) -> tensor<?xf32> {
+    %r2 = scf.for %j = %c0 to %c10 step %c1 iter_args(%C = %B) -> tensor<?xf32> {
+      %w = vector.transfer_write %v, %C[%c0] : vector<5xf32>, tensor<?xf32>
+      scf.yield %w : tensor<?xf32>
+    }
+    scf.yield %r2 : tensor<?xf32>
+  }
+  return %r1 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 // CHECK-LABEL: func @scf_for_with_tensor.insert_slice
@@ -767,7 +789,7 @@ func @tensor_cast_not_in_place(
 // CHECK-LABEL: func @dominance_violation_bug_1
 func @dominance_violation_bug_1(%A : tensor<?x?xf32>, %idx : index) -> tensor<?x?xf32> {
   %f0 = arith.constant 0.0 : f32
-  
+
   %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
   %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32>


        


More information about the Mlir-commits mailing list