[Mlir-commits] [mlir] 2fe40c3 - [mlir][bufferize] Fix op filter
Matthias Springer
llvmlistbot at llvm.org
Thu May 12 00:33:15 PDT 2022
Author: Matthias Springer
Date: 2022-05-12T09:33:07+02:00
New Revision: 2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f
URL: https://github.com/llvm/llvm-project/commit/2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f
DIFF: https://github.com/llvm/llvm-project/commit/2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f.diff
LOG: [mlir][bufferize] Fix op filter
Bufferization has an optional filter to exclude certain ops from analysis and bufferization. There were a few remaining places in the codebase where the filter was not checked.
Differential Revision: https://reviews.llvm.org/D125356
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 29d983bdffdf1..f5b6203c2395f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -117,7 +117,7 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
if (Operation *op = result.getDefiningOp())
- if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+ if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
return bufferizableOp.getAliasingOpOperand(result, *this);
return {};
}
@@ -127,7 +127,7 @@ AnalysisState::getAliasingOpOperand(OpResult result) const {
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
if (auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+ getOptions().dynCastBufferizableOp(opOperand.getOwner()))
return bufferizableOp.getAliasingOpResult(opOperand, *this);
return {};
}
@@ -136,7 +136,7 @@ AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
if (auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+ getOptions().dynCastBufferizableOp(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -148,7 +148,7 @@ bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
if (auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+ getOptions().dynCastBufferizableOp(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -160,7 +160,7 @@ bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
if (auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+ getOptions().dynCastBufferizableOp(opOperand.getOwner()))
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index dd0ba1ca68a02..c15d3cd86bf9c 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,31 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
// CHECK: return %[[OUT_TENSOR]]
}
+
+// -----
+
+// This is a regression test. The linalg-bufferize pass should ignore all func
+// dialect ops.
+
+// CHECK-LABEL: func private @csum(tensor<6xi64>) -> tensor<6xi64>
+func.func private @csum(%arg0: tensor<6xi64>) -> tensor<6xi64>
+
+// CHECK: func public @main(%[[arg0:.*]]: tensor<2x3xi1>)
+// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[arg0]]
+// CHECK: %[[collapse_m:.*]] = bufferization.to_memref %[[collapse]]
+// CHECK: %[[alloc:.*]] = memref.alloc()
+// CHECK: linalg.generic {{.*}} ins(%[[collapse_m]] : memref<6xi1>) outs(%[[alloc]] : memref<6xi64>)
+// CHECK: %[[generic_t:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK: %[[call:.*]] = call @csum(%[[generic_t]])
+// CHECK: return %[[call]]
+func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xi1> into tensor<6xi1>
+ %1 = linalg.init_tensor [6] : tensor<6xi64>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0 : tensor<6xi1>) outs(%1 : tensor<6xi64>) {
+ ^bb0(%arg1: i1, %arg2: i64):
+ %4 = arith.extui %arg1 : i1 to i64
+ linalg.yield %4 : i64
+ } -> tensor<6xi64>
+ %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
+ return %3 : tensor<6xi64>
+}
More information about the Mlir-commits mailing list