[Mlir-commits] [mlir] 06dacf5 - [mlir][func][bufferization][NFC] Simplify implementation

Matthias Springer llvmlistbot at llvm.org
Tue Aug 15 03:00:21 PDT 2023


Author: Matthias Springer
Date: 2023-08-15T12:00:12+02:00
New Revision: 06dacf5ea798741930fa714ec752c2e971625679

URL: https://github.com/llvm/llvm-project/commit/06dacf5ea798741930fa714ec752c2e971625679
DIFF: https://github.com/llvm/llvm-project/commit/06dacf5ea798741930fa714ec752c2e971625679.diff

LOG: [mlir][func][bufferization][NFC] Simplify implementation

The bufferization implementation of `func.func` and `func.call` can be simplified. It still contained code that was only necessary back when One-Shot Bufferize itself removed return values; that functionality was extracted into a separate pass a while ago.

Differential Revision: https://reviews.llvm.org/D157893
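
For context, a hedged sketch (hypothetical IR, not taken from this patch) of what bufferizing `func.func` / `func.call` means; the exact layout maps depend on the chosen function-boundary type conversion:

    // Before bufferization: tensor semantics.
    func.func @callee(%t: tensor<?xf32>) -> tensor<?xf32> {
      return %t : tensor<?xf32>
    }
    func.func @caller(%t: tensor<?xf32>) -> tensor<?xf32> {
      %0 = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
      return %0 : tensor<?xf32>
    }

    // After bufferization: memref semantics, fully dynamic layouts.
    func.func @callee(%m: memref<?xf32, strided<[?], offset: ?>>)
        -> memref<?xf32, strided<[?], offset: ?>> {
      return %m : memref<?xf32, strided<[?], offset: ?>>
    }
    func.func @caller(%m: memref<?xf32, strided<[?], offset: ?>>)
        -> memref<?xf32, strided<[?], offset: ?>> {
      %0 = func.call @callee(%m) : (memref<?xf32, strided<[?], offset: ?>>)
          -> memref<?xf32, strided<[?], offset: ?>>
      return %0 : memref<?xf32, strided<[?], offset: ?>>
    }

Because callees are bufferized before their callers, the new `getBufferType` implementations in the diff below can read result and operand memref types directly off the already-bufferized callee signature.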

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 1f55728cd5f533..da9b1d9868b571 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -438,16 +438,9 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   // Otherwise, we have to use a memref type with a fully dynamic layout map to
   // avoid copies. We are currently missing patterns for layout maps to
   // canonicalize away (or canonicalize to more precise layouts).
-  //
-  // FuncOps must be bufferized before their bodies, so add them to the worklist
-  // first.
   SmallVector<Operation *> worklist;
-  op->walk([&](func::FuncOp funcOp) {
-    if (hasTensorSemantics(funcOp))
-      worklist.push_back(funcOp);
-  });
   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
-    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
+    if (hasTensorSemantics(op))
       worklist.push_back(op);
   });
 

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 89904db43508f5..fbefaf06e11737 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -197,68 +197,67 @@ struct CallOpInterface
     return result;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto callOp = cast<func::CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    // The callee was already bufferized, so we can directly take the type from
+    // its signature.
+    FunctionType funcType = funcOp.getFunctionType();
+    return cast<BaseMemRefType>(
+        funcType.getResult(cast<OpResult>(value).getResultNumber()));
+  }
+
   /// All function arguments are writable. It is the responsibility of the
   /// CallOp to insert buffer copies where necessary.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    unsigned numResults = callOp.getNumResults();
-    unsigned numOperands = callOp->getNumOperands();
-    FuncOp funcOp = getCalledFunction(callOp);
-    assert(funcOp && "expected CallOp to a FuncOp");
-    FunctionType funcType = funcOp.getFunctionType();
-
-    // Result types of the bufferized CallOp.
-    SmallVector<Type> resultTypes;
-    // Replacement values for the existing CallOp. These are usually the results
-    // of the bufferized CallOp, unless a tensor result folds onto an operand.
-    SmallVector<Value> replacementValues(numResults, Value());
-    // For non-tensor results: A mapping from return val indices of the old
-    // CallOp to return val indices of the bufferized CallOp.
-    SmallVector<std::optional<unsigned>> retValMapping(numResults,
-                                                       std::nullopt);
-    // Operands of the bufferized CallOp.
-    SmallVector<Value> newOperands(numOperands, Value());
 
     // 1. Compute the result types of the new CallOp.
-    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
-      unsigned returnValIdx = it.index();
-      Type returnType = it.value();
+    SmallVector<Type> resultTypes;
+    for (Value result : callOp.getResults()) {
+      Type returnType = result.getType();
       if (!isa<TensorType>(returnType)) {
         // Non-tensor values are returned.
-        retValMapping[returnValIdx] = resultTypes.size();
         resultTypes.push_back(returnType);
         continue;
       }
 
       // Returning a memref.
-      retValMapping[returnValIdx] = resultTypes.size();
-      resultTypes.push_back(funcType.getResult(resultTypes.size()));
+      FailureOr<BaseMemRefType> resultType =
+          bufferization::getBufferType(result, options);
+      if (failed(resultType))
+        return failure();
+      resultTypes.push_back(*resultType);
     }
 
-    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    for (OpOperand &opOperand : callOp->getOpOperands()) {
-      unsigned idx = opOperand.getOperandNumber();
-      Value tensorOperand = opOperand.get();
+    // 2. Rewrite tensor operands as memrefs based on type of the already
+    //    bufferized callee.
+    SmallVector<Value> newOperands;
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    FunctionType funcType = funcOp.getFunctionType();
 
+    for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(tensorOperand.getType())) {
-        newOperands[idx] = tensorOperand;
+      if (!isa<TensorType>(opOperand.get().getType())) {
+        newOperands.push_back(opOperand.get());
         continue;
       }
 
       // Retrieve buffers for tensor operands.
-      Value buffer = newOperands[idx];
-      if (!buffer) {
-        FailureOr<Value> maybeBuffer =
-            getBuffer(rewriter, opOperand.get(), options);
-        if (failed(maybeBuffer))
-          return failure();
-        buffer = *maybeBuffer;
-      }
+      FailureOr<Value> maybeBuffer =
+          getBuffer(rewriter, opOperand.get(), options);
+      if (failed(maybeBuffer))
+        return failure();
+      Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with a CastOp.
-      auto memRefType = funcType.getInput(idx);
+      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
@@ -272,22 +271,16 @@ struct CallOpInterface
                                                            memRefType, buffer);
         buffer = castBuffer;
       }
-      newOperands[idx] = buffer;
+      newOperands.push_back(buffer);
     }
 
     // 3. Create the new CallOp.
     Operation *newCallOp = rewriter.create<func::CallOp>(
         callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
-    // Get replacement values.
-    for (unsigned i = 0; i < replacementValues.size(); ++i) {
-      if (replacementValues[i])
-        continue;
-      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
-    }
 
     // 4. Replace the old op with the new op.
-    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
+    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());
 
     return success();
   }
@@ -326,6 +319,17 @@ struct ReturnOpInterface
 
 struct FuncOpInterface
     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto funcOp = cast<FuncOp>(op);
+    auto bbArg = cast<BlockArgument>(value);
+    // Unstructured control flow is not supported.
+    assert(bbArg.getOwner() == &funcOp.getBody().front() &&
+           "expected that block argument belongs to first block");
+    return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
+  }
+
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -384,9 +388,11 @@ struct FuncOpInterface
         bbArgUses.push_back(&use);
 
       // Change the bbArg type to memref.
-      Type memrefType =
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
-      bbArg.setType(memrefType);
+      FailureOr<BaseMemRefType> memrefType =
+          bufferization::getBufferType(bbArg, options);
+      if (failed(memrefType))
+        return failure();
+      bbArg.setType(*memrefType);
 
       // Replace all uses of the original tensor bbArg.
       rewriter.setInsertionPointToStart(&frontBlock);
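
A note on the cast handling in `CallOpInterface::bufferize` above: when the caller's buffer is less dynamic than what the callee's bufferized signature expects, the mismatch is bridged with a cast. A minimal hedged illustration (types invented for the example):

    // The callee expects a fully dynamic strided layout; the caller holds an
    // identity-layout buffer, so a memref.cast reconciles the two types.
    %cast = memref.cast %buffer
        : memref<10xf32> to memref<10xf32, strided<[?], offset: ?>>
    func.call @callee(%cast)
        : (memref<10xf32, strided<[?], offset: ?>>) -> ()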

