[Mlir-commits] [mlir] e8f7d01 - [mlir] Add a flag to allow equivalent results.
Alexander Belyaev
llvmlistbot at llvm.org
Wed May 4 08:48:27 PDT 2022
Author: Alexander Belyaev
Date: 2022-05-04T17:48:18+02:00
New Revision: e8f7d019fc21a300cee0dc9281706ee6d2e4d793
URL: https://github.com/llvm/llvm-project/commit/e8f7d019fc21a300cee0dc9281706ee6d2e4d793
DIFF: https://github.com/llvm/llvm-project/commit/e8f7d019fc21a300cee0dc9281706ee6d2e4d793.diff
LOG: [mlir] Add a flag to allow equivalent results.
Differential Revision: https://reviews.llvm.org/D124931
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 859b138110449..2a6020118c2f7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -47,6 +47,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
/// Specifies whether returning newly allocated memrefs should be allowed.
/// Otherwise, a pass failure is triggered.
bool allowReturnAllocs = false;
+
+ /// Specifies whether buffer return values that are equivalent to a FuncOp
+ /// bbArg should be dropped.
+ bool dropEquivalentFuncResults = true;
};
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 37d4a3e7fb529..1820df8cb9b80 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -230,6 +230,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
`test-analysis-only`.
}];
let options = [
+ Option<"dropEquivalentFuncResults", "drop-equivalent-func-results", "bool",
+ /*default=*/"true",
+ "Drop buffer return values that are equivalent to a FuncOp arg.">,
Option<"allowReturnAllocs", "allow-return-allocs", "bool",
/*default=*/"false",
"Allows returning/yielding new allocations from a block.">,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index a16b19148bb8b..cee7dfc1d432e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -169,6 +169,7 @@ struct OneShotBufferizePass
if (!options) {
// Make new bufferization options if none were provided when creating the
// pass.
+ opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
opt.allowReturnAllocs = allowReturnAllocs;
opt.allowUnknownOps = allowUnknownOps;
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index df29e7a1ae1bf..25d3df2fac287 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -269,17 +269,19 @@ struct CallOpInterface
continue;
}
- if (Optional<int64_t> bbArgIdx =
- getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
- // Return operands that are equivalent to some bbArg, are not
- // returned.
- FailureOr<Value> bufferOrFailure =
- state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
- if (failed(bufferOrFailure))
- return failure();
- replacementValues[returnValIdx] = *bufferOrFailure;
- newOperands[*bbArgIdx] = *bufferOrFailure;
- continue;
+ if (options.dropEquivalentFuncResults) {
+ if (Optional<int64_t> bbArgIdx =
+ getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
+ // Return operands that are equivalent to some bbArg, are not
+ // returned.
+ FailureOr<Value> bufferOrFailure =
+ state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
+ if (failed(bufferOrFailure))
+ return failure();
+ replacementValues[returnValIdx] = *bufferOrFailure;
+ newOperands[*bbArgIdx] = *bufferOrFailure;
+ continue;
+ }
}
if (!options.allowReturnAllocs)
@@ -404,7 +406,8 @@ struct FuncOpInterface
FunctionType funcType = funcOp.getFunctionType();
const FuncAnalysisState &funcState =
getFuncAnalysisState(state.getAnalysisState());
- const BufferizationOptions &options = state.getOptions();
+ const OneShotBufferizationOptions &options =
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
// Construct the bufferized function type.
SmallVector<Type> argTypes;
@@ -479,20 +482,23 @@ struct FuncOpInterface
}
// If return operand is equivalent to some bbArg, no need to return it.
- if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
- funcOp, funcState, returnOperand.getOperandNumber())) {
- rewriter.setInsertionPoint(returnOp);
- Location loc = returnOp.getLoc();
- Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
- returnVal);
- BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
- // Note: This copy will fold away. It must be inserted here to ensure
- // that `returnVal` still has at least one use and does not fold away.
- if (failed(
- createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
- return funcOp->emitError("could not generate copy for bbArg");
- continue;
+ if (options.dropEquivalentFuncResults) {
+ if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
+ funcOp, funcState, returnOperand.getOperandNumber())) {
+ rewriter.setInsertionPoint(returnOp);
+ Location loc = returnOp.getLoc();
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ loc,
+ getMemRefType(returnVal.getType().cast<TensorType>(), options),
+ returnVal);
+ BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
+ // Note: This copy will fold away. It must be inserted here to ensure
+ // that `returnVal` still has at least one use and does not fold away.
+ if (failed(
+ createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
+ return funcOp->emitError("could not generate copy for bbArg");
+ continue;
+ }
}
returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
index 509c96fac4323..4ef9995c07921 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs drop-equivalent-func-results=false" -split-input-file | FileCheck %s --check-prefix=EQUIV
// Run fuzzer with different seeds.
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
@@ -62,3 +63,16 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
%r2 = tensor.extract %filled[%idx] : tensor<?xf32>
return %r1, %r2 : f32, f32
}
+
+// -----
+
+func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
+ func.return %A : tensor<?xf32>
+}
+// CHECK-LABEL: func @return_arg
+// CHECK-SAME: %[[A:.*]]: memref<?xf32
+// CHECK-NOT: return %[[A]]
+
+// EQUIV-LABEL: func @return_arg
+// EQUIV-SAME: %[[A:.*]]: memref<?xf32
+// EQUIV: return %[[A]]
More information about the Mlir-commits
mailing list