[Mlir-commits] [mlir] ba9d886 - [mlir][bufferization][NFC] Bufferize with PostOrder traversal

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 03:46:05 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T12:42:41+02:00
New Revision: ba9d886db4fbb2dfd6787bfa073811e77eacbfe7

URL: https://github.com/llvm/llvm-project/commit/ba9d886db4fbb2dfd6787bfa073811e77eacbfe7
DIFF: https://github.com/llvm/llvm-project/commit/ba9d886db4fbb2dfd6787bfa073811e77eacbfe7.diff

LOG: [mlir][bufferization][NFC] Bufferize with PostOrder traversal

Bufferizing ops in post-order (nested ops before their parent op) is useful because the bufferized result type of an op can sometimes be inferred from its body (e.g., `scf.if`). Subsequent changes will make use of this.

This change also introduces a new `getBufferType` interface method on `BufferizableOpInterface`. The method computes the bufferized type of a block argument with respect to the types of the parent op's OpOperands.

Differential Revision: https://reviews.llvm.org/D128420

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index ccc9d1d706261..4c56a8196d455 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -337,7 +337,24 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*defaultImplementation=*/[{
           return success();
         }]
-      >
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the bufferized type of the given tensor block argument. The
+          block argument is guaranteed to belong to a block of this op.
+        }],
+        /*retType=*/"BaseMemRefType",
+        /*methodName=*/"getBufferType",
+        /*args=*/(ins "BlockArgument":$bbArg,
+                      "const BufferizationOptions &":$options),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          assert(bbArg.getOwner()->getParentOp() == $_op &&
+                 "bbArg must belong to this op");
+          auto tensorType = bbArg.getType().cast<TensorType>();
+          return bufferization::getMemRefType(tensorType, options);
+        }]
+      >,
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 3fe32bc293ea8..8a01acaf9374a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -482,8 +482,10 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 
 Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                const BufferizationOptions &options) {
+#ifndef NDEBUG
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
+#endif // NDEBUG
 
   // Replace "%t = to_tensor %m" with %m.
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
@@ -492,7 +494,7 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  Type memrefType = getMemRefType(tensorType, options);
+  Type memrefType = getBufferType(value, options);
   ensureToMemrefOpIsValid(value, memrefType);
   return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                     value);
@@ -507,6 +509,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
     return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
 
+  if (auto bbArg = value.dyn_cast<BlockArgument>())
+    if (auto bufferizableOp =
+            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+      return bufferizableOp.getBufferType(bbArg, options);
+
   return getMemRefType(tensorType, options);
 }
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f9c809081723e..c68d1d120be6a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -393,9 +393,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   // Otherwise, we have to use a memref type with a fully dynamic layout map to
   // avoid copies. We are currently missing patterns for layout maps to
   // canonicalize away (or canonicalize to more precise layouts).
+  //
+  // FuncOps must be bufferized before their bodies, so add them to the worklist
+  // first.
   SmallVector<Operation *> worklist;
-  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (hasTensorSemantics(op))
+  op->walk([&](func::FuncOp funcOp) {
+    if (hasTensorSemantics(funcOp))
+      worklist.push_back(funcOp);
+  });
+  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
       worklist.push_back(op);
   });
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 26f5e2b0b5519..785cd0d7806dc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -725,7 +725,7 @@ struct WhileOpInterface
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return getBufferType(bbArg, options).cast<Type>();
+          return bufferization::getBufferType(bbArg, options).cast<Type>();
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
@@ -1107,7 +1107,7 @@ struct ParallelInsertSliceOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &b,
                           const BufferizationOptions &options) const {
     // Will be bufferized as part of ForeachThreadOp.
-    return failure();
+    return success();
   }
 
   // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share

diff  --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 177d820ccefcb..02a073de45699 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -154,7 +154,7 @@ struct AssumingYieldOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     // Op is bufferized as part of AssumingOp.
-    return failure();
+    return success();
   }
 };
 

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index d067704112754..cc0357a055d71 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -313,10 +313,10 @@ func.func @scf_for_swapping_yields(
 //       CHECK:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
 //       CHECK:     memref.copy %[[iter2]], %[[alloc2]]
 //       CHECK:     memref.dealloc %[[iter2]]
-//       CHECK:     %[[casted2:.*]] = memref.cast %[[alloc2]]
 //       CHECK:     %[[alloc1:.*]] = memref.alloc(%{{.*}})
 //       CHECK:     memref.copy %[[iter1]], %[[alloc1]]
 //       CHECK:     memref.dealloc %[[iter1]]
+//       CHECK:     %[[casted2:.*]] = memref.cast %[[alloc2]]
 //       CHECK:     %[[casted1:.*]] = memref.cast %[[alloc1]]
 //       CHECK:     %[[cloned1:.*]] = bufferization.clone %[[casted1]]
 //       CHECK:     memref.dealloc %[[alloc1]]
@@ -384,10 +384,10 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
     // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w1]], %[[a1]]
     // CHECK: memref.dealloc %[[w1]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
     // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
     // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
     // CHECK: memref.dealloc %[[a0]]
@@ -437,10 +437,10 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
     // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w1]], %[[a1]]
     // CHECK: memref.dealloc %[[w1]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
     // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
     // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
     // CHECK: memref.dealloc %[[a0]]
@@ -457,9 +457,9 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
     // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b1]], %[[a3]]
     // CHECK: memref.dealloc %[[b1]]
-    // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
     // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b0]], %[[a2]]
+    // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
     // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
     // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
     // CHECK: memref.dealloc %[[a2]]


        


More information about the Mlir-commits mailing list