[Mlir-commits] [mlir] 4002eaa - [mlir][bufferize] Improve analysis of external functions
Matthias Springer
llvmlistbot at llvm.org
Fri Dec 9 05:36:41 PST 2022
Author: Matthias Springer
Date: 2022-12-09T14:36:33+01:00
New Revision: 4002eaaa01b71ab1f74fc15b151c72c0626d9d2d
URL: https://github.com/llvm/llvm-project/commit/4002eaaa01b71ab1f74fc15b151c72c0626d9d2d
DIFF: https://github.com/llvm/llvm-project/commit/4002eaaa01b71ab1f74fc15b151c72c0626d9d2d.diff
LOG: [mlir][bufferize] Improve analysis of external functions
External functions have no body, so they cannot be analyzed. Assume conservatively that each tensor bbArg may be aliasing with each tensor result. Furthermore, assume that each function arg is read and written-to after bufferization. This default behavior can be controlled with `bufferization.access` (similar to `bufferization.memory_layout`) in test cases.
Also fix a bug in the dialect attribute verifier, which did not run for region argument attributes.
Differential Revision: https://reviews.llvm.org/D139517
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
index 09f3e73ffb2b1..280bfdb380177 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
@@ -30,11 +30,29 @@ def Bufferization_Dialect : Dialect {
];
let extraClassDeclaration = [{
+ /// Verify an attribute from this dialect on the argument at 'argIndex' for
+ /// the region at 'regionIndex' on the given operation. Returns failure if
+ /// the verification failed, success otherwise. This hook may optionally be
+ /// invoked from any operation containing a region.
+ LogicalResult verifyRegionArgAttribute(Operation *,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute) override;
+
/// An attribute that can override writability of buffers of tensor function
/// arguments during One-Shot Module Bufferize.
constexpr const static ::llvm::StringLiteral
kWritableAttrName = "bufferization.writable";
+ /// An attribute for function arguments that describes how the function
+ /// accesses the buffer. Can be one of "none", "read", "write" or "read-write".
+ ///
+ /// When no attribute is specified, the analysis tries to infer the access
+ /// behavior from its body. In case of external functions, for which no
+ /// function body is available, "read-write" is assumed by default.
+ constexpr const static ::llvm::StringLiteral
+ kBufferAccessAttrName = "bufferization.access";
+
/// Attribute name used to mark the bufferization layout for region
/// arguments during One-Shot Module Bufferize.
constexpr const static ::llvm::StringLiteral
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 1798346b7deef..a8e6d51940395 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -59,19 +59,34 @@ void mlir::bufferization::BufferizationDialect::initialize() {
addInterfaces<BufferizationInlinerInterface>();
}
-LogicalResult
-BufferizationDialect::verifyOperationAttribute(Operation *op,
- NamedAttribute attr) {
- using bufferization::BufferizableOpInterface;
-
+LogicalResult BufferizationDialect::verifyRegionArgAttribute(
+ Operation *op, unsigned /*regionIndex*/, unsigned argIndex,
+ NamedAttribute attr) {
if (attr.getName() == kWritableAttrName) {
if (!attr.getValue().isa<BoolAttr>()) {
return op->emitError() << "'" << kWritableAttrName
<< "' is expected to be a boolean attribute";
}
if (!isa<FunctionOpInterface>(op))
- return op->emitError() << "expected " << attr.getName()
- << " to be used on function-like operations";
+ return op->emitError() << "expected '" << kWritableAttrName
+ << "' to be used on function-like operations";
+ if (cast<FunctionOpInterface>(op).isExternal())
+ return op->emitError() << "'" << kWritableAttrName
+ << "' is invalid on external functions";
+ return success();
+ }
+ if (attr.getName() == kBufferAccessAttrName) {
+ if (!attr.getValue().isa<StringAttr>()) {
+ return op->emitError() << "'" << kBufferAccessAttrName
+ << "' is expected to be a string attribute";
+ }
+ StringRef str = attr.getValue().cast<StringAttr>().getValue();
+ if (str != "none" && str != "read" && str != "write" && str != "read-write")
+ return op->emitError()
+ << "invalid value for '" << kBufferAccessAttrName << "'";
+ if (!isa<FunctionOpInterface>(op))
+ return op->emitError() << "expected '" << kBufferAccessAttrName
+ << "' to be used on function-like operations";
return success();
}
if (attr.getName() == kBufferLayoutAttrName) {
@@ -80,10 +95,20 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
<< "' is expected to be a affine map attribute";
}
if (!isa<FunctionOpInterface>(op))
- return op->emitError() << "expected " << attr.getName()
- << " to be used on function-like operations";
+ return op->emitError() << "expected '" << kBufferLayoutAttrName
+ << "' to be used on function-like operations";
return success();
}
+ return op->emitError() << "attribute '" << kBufferLayoutAttrName
+ << "' not supported as a region arg attribute by the "
+ "bufferization dialect";
+}
+
+LogicalResult
+BufferizationDialect::verifyOperationAttribute(Operation *op,
+ NamedAttribute attr) {
+ using bufferization::BufferizableOpInterface;
+
if (attr.getName() == kEscapeAttrName) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr)
@@ -116,6 +141,7 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
return success();
}
- return op->emitError() << "attribute '" << attr.getName()
- << "' not supported by the bufferization dialect";
+ return op->emitError()
+ << "attribute '" << attr.getName()
+ << "' not supported as an op attribute by the bufferization dialect";
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index e2878b2e1df76..87cd11657f92a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -127,6 +127,25 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
+ if (funcOp.getBody().empty()) {
+ // No function body available. Conservatively assume that every tensor
+ // return value may alias with any tensor bbArg.
+ FunctionType type = funcOp.getFunctionType();
+ for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+ if (!inputIt.value().isa<TensorType>())
+ continue;
+ for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+ if (!resultIt.value().isa<TensorType>())
+ continue;
+ int64_t returnIdx = resultIt.index();
+ int64_t bbArgIdx = inputIt.index();
+ funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
+ funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
+ }
+ }
+ return success();
+ }
+
// Support only single return-terminated block in the function.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
@@ -151,8 +170,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
-static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
- bool isRead, bool isWritten) {
+static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+ bool isWritten) {
OpBuilder b(funcOp.getContext());
Attribute accessType;
if (isRead && isWritten) {
@@ -164,7 +183,8 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
} else {
accessType = b.getStringAttr("none");
}
- funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
+ funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
+ accessType);
}
/// Determine which FuncOp bbArgs are read and which are written. When run on a
@@ -173,28 +193,37 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- // If the function has no body, conservatively assume that all args are
- // read + written.
- if (funcOp.getBody().empty()) {
- for (BlockArgument bbArg : funcOp.getArguments()) {
- funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
- funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+ for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+ ++idx) {
+ // Skip non-tensor arguments.
+ if (!funcOp.getFunctionType().getInput(idx).isa<TensorType>())
+ continue;
+ bool isRead;
+ bool isWritten;
+ if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
+ idx, BufferizationDialect::kBufferAccessAttrName)) {
+ // Buffer access behavior is specified on the function. Skip the analysis.
+ StringRef str = accessAttr.getValue();
+ isRead = str == "read" || str == "read-write";
+ isWritten = str == "write" || str == "read-write";
+ } else if (funcOp.getBody().empty()) {
+ // If the function has no body, conservatively assume that all args are
+ // read + written.
+ isRead = true;
+ isWritten = true;
+ } else {
+ // Analyze the body of the function.
+ BlockArgument bbArg = funcOp.getArgument(idx);
+ isRead = state.isValueRead(bbArg);
+ isWritten = state.isValueWritten(bbArg);
}
- return success();
- }
-
- for (BlockArgument bbArg : funcOp.getArguments()) {
- if (!bbArg.getType().isa<TensorType>())
- continue;
- bool isRead = state.isValueRead(bbArg);
- bool isWritten = state.isValueWritten(bbArg);
if (state.getOptions().testAnalysisOnly)
- annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
+ annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
if (isRead)
- funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.readBbArgs[funcOp].insert(idx);
if (isWritten)
- funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.writtenBbArgs[funcOp].insert(idx);
}
return success();
@@ -351,10 +380,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// Analyze ops.
for (func::FuncOp funcOp : orderedFuncOps) {
- // No body => no analysis.
- if (funcOp.getBody().empty())
- continue;
-
// Now analyzing function.
funcState.startFunctionAnalysis(funcOp);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 4a637a64e2ad0..9069010bc4cf1 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1280,3 +1280,66 @@ func.func @write_to_same_alloc_tensor_out_of_place(
return %r0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func private @ext_func(tensor<*xf32> {bufferization.access = "read-write"})
+func.func private @ext_func(%t: tensor<*xf32>)
+
+// CHECK: func.func @private_func_read_write(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
+func.func @private_func_read_write(%t: tensor<5xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ // Bufferizes out-of-place because `ext_func` may modify the buffer.
+ // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["false"]}
+ %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
+ func.call @ext_func(%0) : (tensor<*xf32>) -> ()
+ %1 = tensor.extract %t[%c0] : tensor<5xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @print_buffer(tensor<*xf32> {bufferization.access = "read"})
+func.func private @print_buffer(%t: tensor<*xf32> {bufferization.access = "read"})
+
+// CHECK: func.func @private_func_read(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
+func.func @private_func_read(%t: tensor<5xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ // Bufferizes in-place because `print_buffer` is read-only.
+ // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["true"]}
+ %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
+ // CHECK: call @print_buffer(%cast) {__inplace_operands_attr__ = ["true"]}
+ func.call @print_buffer(%0) : (tensor<*xf32>) -> ()
+ %1 = tensor.extract %t[%c0] : tensor<5xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"}, tensor<?xf32> {bufferization.access = "read-write"})
+func.func private @ext_func(%t1: tensor<?xf32>, %t2: tensor<?xf32>)
+
+// CHECK: func.func @private_func_two_params_writing(%{{.*}}: tensor<?xf32> {bufferization.access = "read"})
+func.func @private_func_two_params_writing(%t: tensor<?xf32>) {
+ // Both operands bufferize out-of-place because both bufferize to a memory
+ // write.
+ // CHECK: call @ext_func(%{{.*}}, %{{.*}}) {__inplace_operands_attr__ = ["false", "false"]}
+ func.call @ext_func(%t, %t) : (tensor<?xf32>, tensor<?xf32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"}) -> (tensor<5xf32>, tensor<6xf32>)
+func.func private @ext_func(%t: tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
+
+// CHECK: func.func @private_func_aliasing(%{{.*}}: tensor<?xf32> {bufferization.access = "read"})
+func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ // Bufferizes out-of-place because either one of the two results may alias
+ // with the argument and one of the results is read afterwards.
+ // CHECK: call @ext_func(%{{.*}}) {__inplace_operands_attr__ = ["false"]} : (tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
+ %0, %1 = func.call @ext_func(%t) : (tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
+ %2 = tensor.extract %1[%c0] : tensor<6xf32>
+ return %2 : f32
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index a0f72f3598737..da0fe74db60a8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -158,7 +158,7 @@ func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>,
// -----
-func.func private @fun_with_side_effects(%A: tensor<?xf32> {bufferization.writable = true})
+func.func private @fun_with_side_effects(%A: tensor<?xf32>)
func.func @foo(%A: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>) {
call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index a7dc5e07310fd..32c22c167a5e7 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -78,3 +78,20 @@ func.func @sparse_alloc_call() {
call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> ()
return
}
+
+// -----
+
+// expected-error @+1{{invalid value for 'bufferization.access'}}
+func.func private @invalid_buffer_access_type(tensor<*xf32> {bufferization.access = "foo"})
+
+// -----
+
+// expected-error @+1{{'bufferization.writable' is invalid on external functions}}
+func.func private @invalid_writable_attribute(tensor<*xf32> {bufferization.writable = false})
+
+// -----
+
+func.func @invalid_writable_on_op() {
+ // expected-error @+1{{attribute '"bufferization.writable"' not supported as an op attribute by the bufferization dialect}}
+ arith.constant {bufferization.writable = true} 0 : index
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index ce408f7f1a8c2..90c88d86a11c1 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -129,7 +129,7 @@ func.func @scf_for_with_tensor.insert_slice(
// CHECK-LABEL: func @execute_region_with_conflict(
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
func.func @execute_region_with_conflict(
- %t1 : tensor<?xf32> {bufferization.writable = "true"})
+ %t1 : tensor<?xf32> {bufferization.writable = true})
-> (f32, tensor<?xf32>, f32)
{
%f1 = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list