[Mlir-commits] [mlir] a98c5a0 - [mlir][linalg][bufferize] Fix CallOps with non-tensor operands

Matthias Springer llvmlistbot at llvm.org
Wed Jan 5 07:19:45 PST 2022


Author: Matthias Springer
Date: 2022-01-06T00:19:23+09:00
New Revision: a98c5a08b15ebc31b9303206b10a50d9949eb4d4

URL: https://github.com/llvm/llvm-project/commit/a98c5a08b15ebc31b9303206b10a50d9949eb4d4
DIFF: https://github.com/llvm/llvm-project/commit/a98c5a08b15ebc31b9303206b10a50d9949eb4d4.diff

LOG: [mlir][linalg][bufferize] Fix CallOps with non-tensor operands

CallOps whose callee returns non-tensor values were not handled properly: when computing the new result types (and replacement values) of the bufferized CallOp, non-tensor return values were not accounted for.

Differential Revision: https://reviews.llvm.org/D116445

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ee4ced0b17396..f919bbfc2363d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,6 +490,19 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace std_ext {
 
+/// Return the index of the parent function's bbArg that is equivalent to the
+/// given ReturnOp operand, or None if there is no such bbArg.
+static Optional<int64_t>
+getEquivalentFuncArgIdx(ModuleBufferizationState &state,
+                        OpOperand &returnOperand) {
+  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
+  auto &equivArgs = state.equivalentFuncArgs[funcOp];
+  auto it = equivArgs.find(returnOperand.getOperandNumber());
+  if (it == equivArgs.end())
+    return None; // Return value has no equivalent bbArg.
+  return it->second;
+}
+
 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -515,57 +528,67 @@ struct CallOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
+    unsigned numResults = callOp.getNumResults();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
-           "expected Callop to a FuncOp");
+           "expected CallOp to a FuncOp");
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    // 1. Filter return types:
-    //    - if the callee is bodiless / external, we cannot inspect it and we
-    //      cannot assume anything. We can just assert that it does not return a
-    //      tensor as this would have to bufferize to "return a memref", whose
-    //      semantics is ill-defined.
-    //    - if the callee has a body, we perform inter-procedural equivalence
-    //      analysis. When successful, a result folds onto an operand. When
-    //      unsuccessful, additional work is needed (TODO) to either:
-    //        * hoist a result into an inplaceable operand or
-    //        * devise a better representation to truly return a buffer.
+    // Result types of the bufferized CallOp.
     SmallVector<Type> resultTypes;
+    // Replacement values for the existing CallOp. These are usually the results
+    // of the bufferized CallOp, unless a tensor result folds onto an operand.
+    SmallVector<Value> replacementValues(numResults, Value());
+    // For non-tensor results: A mapping from return val indices of the old
+    // CallOp to return val indices of the bufferized CallOp.
+    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
+
     if (funcOp.body().empty()) {
-      if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
-        return callOp->emitError()
-               << "cannot bufferize bodiless function that returns a tensor";
+      // The callee is bodiless / external, so we cannot inspect it and we
+      // cannot assume anything. We can just assert that it does not return a
+      // tensor as this would have to bufferize to "return a memref", whose
+      // semantics is ill-defined.
+      for (unsigned i = 0; i < numResults; ++i) {
+        Type returnType = callOp.getResult(i).getType();
+        if (isaTensor(returnType))
+          return callOp->emitError()
+                 << "cannot bufferize bodiless function that returns a tensor";
+        resultTypes.push_back(returnType);
+        retValMapping[i] = i;
+      }
     } else {
+      // The callee has a body. Based on previously gathered equivalence
+      // information, we know if a tensor result folds onto an operand. These
+      // are the only tensor value returns that are supported at the moment.
+      //
+      // For tensor return values that do not fold onto an operand, additional
+      // work is needed (TODO) to either:
+      // * hoist a result into an inplaceable operand or
+      // * devise a better representation to truly return a buffer.
       ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
       assert(returnOp && "expected func with single return op");
 
       // For each FuncOp result, keep track of which inplace argument it reuses.
       for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+        unsigned returnIdx = returnOperand.getOperandNumber();
         Type returnType = returnOperand.get().getType();
         if (!isaTensor(returnType)) {
+          // Non-tensor values are returned.
+          retValMapping[returnIdx] = resultTypes.size();
           resultTypes.push_back(returnType);
           continue;
         }
 
-        // If return operand is equivalent to some bbArg, no need to return it.
-        if (moduleState.equivalentFuncArgs[funcOp].count(
-                returnOperand.getOperandNumber())) {
-          int64_t idx =
-              moduleState
-                  .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
-          Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
-          // Add a ToTensorOp to kill all uses of the CallOp return.
-          // Replace all uses of the CallOp results so we can erase the CallOp.
-          // This ToTensorOp must fold/DCE away or bufferization should be
-          // considered failed.
-          Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
-              callOp.getLoc(), buffer);
-          oldRes.replaceAllUsesWith(toTensorOp);
+        if (Optional<int64_t> bbArgIdx =
+                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
+          // Return operands that are equivalent to some bbArg are not
+          // returned. Use the buffer of the corresponding call operand instead.
+          replacementValues[returnIdx] =
+              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
           continue;
         }
 
-        resultTypes.push_back(returnType);
+        return op->emitError("returning non-equivalent tensors not supported");
       }
     }
 
@@ -612,8 +635,13 @@ struct CallOpInterface
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
 
-    // 5. Delete the op at the end of bufferization.
-    callOp->erase();
+    // 5. Replace the old op with the new op.
+    for (unsigned i = 0, e = replacementValues.size(); i < e; ++i) {
+      if (replacementValues[i])
+        continue;
+      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
+    }
+    state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index f55f3008f2bd2..8f08e37c6774e 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1000,6 +1000,36 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
 
 // -----
 
+// Callee returns a tensor that folds onto its bbArg, plus a non-tensor f32.
+// CHECK-LABEL: func @inner_func(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+  // CHECK-NOT: copy
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+  %1 = tensor.extract %0[%c1] : tensor<?xf32>
+  // CHECK: return %[[load]] : f32
+  return %0, %1 : tensor<?xf32>, f32
+}
+
+// CHECK-LABEL: func @call_func_with_non_tensor_return(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @call_func_with_non_tensor_return(
+    %t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
+  // CHECK-NOT: copy
+  // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+  // CHECK-SAME: -> f32
+  %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+  // CHECK: return %[[call]] : f32
+  return %1, %0 : f32, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @func_without_tensor_args
 func @func_without_tensor_args(%v : vector<10xf32>) -> () {
   // CHECK: %[[alloc:.*]] = memref.alloc()


        


More information about the Mlir-commits mailing list