[Mlir-commits] [mlir] b15b015 - [mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps

Matthias Springer llvmlistbot at llvm.org
Wed Jan 5 07:29:12 PST 2022


Author: Matthias Springer
Date: 2022-01-06T00:28:47+09:00
New Revision: b15b0156cae73c12fe9251688266c0b2302e5d05

URL: https://github.com/llvm/llvm-project/commit/b15b0156cae73c12fe9251688266c0b2302e5d05
DIFF: https://github.com/llvm/llvm-project/commit/b15b0156cae73c12fe9251688266c0b2302e5d05.diff

LOG: [mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps

There is no need to inspect the ReturnOp of the called function.

This change also refactors the bufferization of CallOps so that `lookupBuffer` is called only once per tensor operand. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.)

Note: This change modifies a test case but is marked as NFC. There is no change of functionality, but calls to FuncOps with empty bodies that return a tensor are now reported with a different error message ("call to FuncOp that returns non-equivalent tensors not supported").

Differential Revision: https://reviews.llvm.org/D116446

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index f919bbfc2363..c84545dc21df 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,17 +490,16 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace std_ext {
 
-/// Return the index of the parent function's bbArg that is equivalent to the
-/// given ReturnOp operand (if any).
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
 static Optional<int64_t>
-getEquivalentFuncArgIdx(ModuleBufferizationState &state,
-                        OpOperand &returnOperand) {
-  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
-  if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+                        int64_t returnValIdx) {
+  if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
     // Return value has no equivalent bbArg.
     return None;
 
-  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+  return state.equivalentFuncArgs[funcOp][returnValIdx];
 }
 
 struct CallOpInterface
@@ -529,6 +528,7 @@ struct CallOpInterface
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     unsigned numResults = callOp.getNumResults();
+    unsigned numOperands = callOp->getNumOperands();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected CallOp to a FuncOp");
@@ -542,54 +542,48 @@ struct CallOpInterface
     // For non-tensor results: A mapping from return val indices of the old
     // CallOp to return val indices of the bufferized CallOp.
     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
-
-    if (funcOp.body().empty()) {
-      // The callee is bodiless / external, so we cannot inspect it and we
-      // cannot assume anything. We can just assert that it does not return a
-      // tensor as this would have to bufferize to "return a memref", whose
-      // semantics is ill-defined.
-      for (int i = 0; i < numResults; ++i) {
-        Type returnType = callOp.getResult(i).getType();
-        if (isaTensor(returnType))
-          return callOp->emitError()
-                 << "cannot bufferize bodiless function that returns a tensor";
+    // Operands of the bufferized CallOp.
+    SmallVector<Value> newOperands(numOperands, Value());
+
+    // Based on previously gathered equivalence information, we know if a
+    // tensor result folds onto an operand. These are the only tensor value
+    // results that are supported at the moment.
+    //
+    // For tensor return values that do not fold onto an operand, additional
+    // work is needed (TODO) to either:
+    // * hoist a result into an inplaceable operand or
+    // * devise a better representation to truly return a buffer.
+    //
+    // Note: If a function has no body, no equivalence information is
+    // available. Consequently, a tensor return value cannot be proven to fold
+    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
+    // the moment.
+
+    // 1. Compute the result types of the new CallOp. Tensor results that are
+    // equivalent to a FuncOp bbArg are no longer returned.
+    for (auto it : llvm::enumerate(callOp.getResultTypes())) {
+      unsigned returnValIdx = it.index();
+      Type returnType = it.value();
+      if (!isaTensor(returnType)) {
+        // Non-tensor values are returned.
+        retValMapping[returnValIdx] = resultTypes.size();
         resultTypes.push_back(returnType);
-        retValMapping[i] = i;
+        continue;
       }
-    } else {
-      // The callee has a body. Based on previously gathered equivalence
-      // information, we know if a tensor result folds onto an operand. These
-      // are the only tensor value returns that are supported at the moment.
-      //
-      // For tensors return values that do not fold onto an operand, additional
-      // work is needed (TODO) to either:
-      // * hoist a result into an inplaceable operand or
-      // * devise a better representation to truly return a buffer.
-      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      assert(returnOp && "expected func with single return op");
-
-      // For each FuncOp result, keep track of which inplace argument it reuses.
-      for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-        unsigned returnIdx = returnOperand.getOperandNumber();
-        Type returnType = returnOperand.get().getType();
-        if (!isaTensor(returnType)) {
-          // Non-tensor values are returned.
-          retValMapping[returnIdx] = resultTypes.size();
-          resultTypes.push_back(returnType);
-          continue;
-        }
-
-        if (Optional<int64_t> bbArgIdx =
-                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
-          // Return operands that are equivalent to some bbArg, are not
-          // returned.
-          replacementValues[returnIdx] =
-              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
-          continue;
-        }
-
-        llvm_unreachable("returning non-equivalent tensors not supported");
+
+      if (Optional<int64_t> bbArgIdx =
+              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+        // Return operands that are equivalent to some bbArg, are not
+        // returned.
+        Value buffer =
+            state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+        replacementValues[returnValIdx] = buffer;
+        newOperands[*bbArgIdx] = buffer;
+        continue;
       }
+
+      return callOp->emitError(
+          "call to FuncOp that returns non-equivalent tensors not supported");
     }
 
     // 2. Compute bufferized FunctionType.
@@ -601,23 +595,26 @@ struct CallOpInterface
                                           moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    SmallVector<Value> newOperands;
-    newOperands.reserve(callOp->getNumOperands());
     for (OpOperand &opOperand : callOp->getOpOperands()) {
+      unsigned idx = opOperand.getOperandNumber();
       Value tensorOperand = opOperand.get();
+
       // Non-tensor operands are just copied.
       if (!tensorOperand.getType().isa<TensorType>()) {
-        newOperands.push_back(tensorOperand);
+        newOperands[idx] = tensorOperand;
         continue;
       }
 
-      // Tensor operands are guaranteed to have been buferized.
-      int64_t idx = opOperand.getOperandNumber();
-      Value buffer = state.lookupBuffer(rewriter, tensorOperand);
+      // Retrieve buffers for tensor operands. Tensor operand buffers, whose
+      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
+      // already stored in `newOperands` during Step 1.
+      Value buffer = newOperands[idx]
+                         ? newOperands[idx]
+                         : state.lookupBuffer(rewriter, tensorOperand);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
-      // Since we don't yet have a clear layout story, buffer_cast may
+      // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
@@ -627,20 +624,21 @@ struct CallOpInterface
                                                            memRefType, buffer);
         buffer = castBuffer;
       }
-      newOperands.push_back(buffer);
+      newOperands[idx] = buffer;
     }
 
     // 4. Create the new CallOp.
     Operation *newCallOp = rewriter.create<CallOp>(
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
-
-    // 5. Replace the old op with the new op.
+    // Get replacement values for non-tensor / non-equivalent results.
     for (int i = 0; i < replacementValues.size(); ++i) {
       if (replacementValues[i])
         continue;
       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
     }
+
+    // 5. Replace the old op with the new op.
     state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 02431d9175a9..14bb6ce48f2d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
 
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
+func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
+  %0 = linalg.init_tensor [5] : tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list