[Mlir-commits] [mlir] 6c654b5 - [mlir][linalg][bufferize] Support std.select bufferization
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 12 00:51:22 PST 2022
Author: Matthias Springer
Date: 2022-01-12T17:46:44+09:00
New Revision: 6c654b51983543dbd65736e992bb1cbe50b4336e
URL: https://github.com/llvm/llvm-project/commit/6c654b51983543dbd65736e992bb1cbe50b4336e
DIFF: https://github.com/llvm/llvm-project/commit/6c654b51983543dbd65736e992bb1cbe50b4336e.diff
LOG: [mlir][linalg][bufferize] Support std.select bufferization
This op is an example of how to deal with ops whose OpResult may alias with one of multiple OpOperands.
Differential Revision: https://reviews.llvm.org/D116868
Added:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
index d3009ad251785..10afeea50b694 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
@@ -1,4 +1,4 @@
-//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index d67c4125a665d..0bbe3fcb6e6b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -346,10 +346,10 @@ class BufferizationState {
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
- llvm::SetVector<Value> findValueInReverseUseDefChain(
+ SetVector<Value> findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition) const;
- /// Find the Value of the last preceding write of a given Value.
+ /// Find the Values of the last preceding writes of a given Value.
///
/// Note: Unknown ops are handled conservatively and assumed to be writes.
/// Furthermore, BlockArguments are also assumed to be writes. There is no
@@ -357,7 +357,7 @@ class BufferizationState {
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
/// is returned regardless of whether it is a memory write or not.
- Value findLastPrecedingWrite(Value value) const;
+ SetVector<Value> findLastPrecedingWrite(Value value) const;
/// Creates a memref allocation.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index 7d155512ef87d..f881a964b905e 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -31,7 +31,7 @@ runComprehensiveBufferize(ModuleOp moduleOp,
namespace std_ext {
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+/// Registers the BufferizableOpInterface external models used by module
+/// bufferization (for CallOp, ReturnOp and FuncOp). Renamed so that it does not
+/// clash with std_ext::registerBufferizableOpInterfaceExternalModels, which
+/// registers the std.select model.
+void registerModuleBufferizationExternalModels(DialectRegistry &registry);
} // namespace std_ext
} // namespace comprehensive_bufferize
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
new file mode 100644
index 0000000000000..ae3b3db23e648
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace std_ext {
+
+/// Registers the BufferizableOpInterface external models for Standard dialect
+/// ops (currently only std.select) with `registry`. The models needed for
+/// module bufferization (call/return/func) are registered separately via
+/// registerModuleBufferizationExternalModels.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index fe5fc26c3d2ba..aa9c7bd6806c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -305,26 +305,18 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
return result;
}
-// Find the Value of the last preceding write of a given Value.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- findLastPrecedingWrite(Value value) const {
- SetVector<Value> result =
- findValueInReverseUseDefChain(value, [&](Value value) {
- Operation *op = value.getDefiningOp();
- if (!op)
- return true;
- auto bufferizableOp = options.dynCastBufferizableOp(op);
- if (!bufferizableOp)
- return true;
- return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
- });
-
- // To simplify the analysis, `scf.if` ops are considered memory writes. There
- // are currently no other ops where one OpResult may alias with multiple
- // OpOperands. Therefore, this function should return exactly one result at
- // the moment.
- assert(result.size() == 1 && "expected exactly one result");
- return result.front();
+// Find the Values of the last preceding writes of a given Value.
+//
+// Walks the reverse SSA use-def chain of `value` via
+// findValueInReverseUseDefChain. Because an OpResult may alias with one of
+// multiple OpOperands (e.g., the result of std.select), more than one Value
+// can be returned. BlockArguments and results of ops that are not
+// bufferizable are conservatively treated as writes.
+llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
+ BufferizationState::findLastPrecedingWrite(Value value) const {
+ return findValueInReverseUseDefChain(value, [&](Value value) {
+ Operation *op = value.getDefiningOp();
+ // BlockArguments have no defining op: treat them as writes.
+ if (!op)
+ return true;
+ auto bufferizableOp = options.dynCastBufferizableOp(op);
+ // Ops without a BufferizableOpInterface are conservatively writes.
+ if (!bufferizableOp)
+ return true;
+ return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
+ });
}
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
@@ -404,15 +396,19 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
- // Do not copy if the last preceding write of `operand` is an op that does
+ // Do not copy if the last preceding writes of `operand` are ops that do
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
- Value lastWrite = findLastPrecedingWrite(operand);
- if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
- if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
- return resultBuffer;
+ SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
+ if (llvm::none_of(lastWrites, [&](Value lastWrite) {
+ if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
+ return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
+ *this);
+ return true;
+ }))
+ return resultBuffer;
// Do not copy if the copied data is never read.
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 9cabc9ac05244..99c28fe124f76 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
SCFInterfaceImpl.cpp
+ StdInterfaceImpl.cpp
TensorInterfaceImpl.cpp
VectorInterfaceImpl.cpp
)
@@ -61,6 +62,14 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
MLIRSCF
)
+# BufferizableOpInterface external models for Standard dialect ops
+# (implemented in StdInterfaceImpl.cpp).
+add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
+ StdInterfaceImpl.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRBufferizableOpInterface
+ MLIRStandard
+)
+
add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
TensorInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 41d4e4fe62e9d..ffebdfc665061 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -219,7 +219,8 @@ static bool hasReadAfterWriteInterference(
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
- // Find most recent write of uRead by following the SSA use-def chain. E.g.:
+ // Find most recent writes of uRead by following the SSA use-def chain.
+ // E.g.:
//
// %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
// %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
@@ -228,7 +229,7 @@ static bool hasReadAfterWriteInterference(
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
- Value lastWrite = state.findLastPrecedingWrite(uRead->get());
+ SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
@@ -265,35 +266,38 @@ static bool hasReadAfterWriteInterference(
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
continue;
- // No conflict if the conflicting write happens before the last
- // write.
- if (Operation *writingOp = lastWrite.getDefiningOp()) {
- if (happensBefore(conflictingWritingOp, writingOp, domInfo))
- // conflictingWritingOp happens before writingOp. No conflict.
- continue;
- // No conflict if conflictingWritingOp is contained in writingOp.
- if (writingOp->isProperAncestor(conflictingWritingOp))
- continue;
- } else {
- auto bbArg = lastWrite.cast<BlockArgument>();
- Block *block = bbArg.getOwner();
- if (!block->findAncestorOpInBlock(*conflictingWritingOp))
- // conflictingWritingOp happens outside of the block. No
- // conflict.
- continue;
- }
+ // Check all possible last writes.
+ for (Value lastWrite : lastWrites) {
+ // No conflict if the conflicting write happens before the last
+ // write.
+ if (Operation *writingOp = lastWrite.getDefiningOp()) {
+ if (happensBefore(conflictingWritingOp, writingOp, domInfo))
+ // conflictingWritingOp happens before writingOp. No conflict.
+ continue;
+ // No conflict if conflictingWritingOp is contained in writingOp.
+ if (writingOp->isProperAncestor(conflictingWritingOp))
+ continue;
+ } else {
+ auto bbArg = lastWrite.cast<BlockArgument>();
+ Block *block = bbArg.getOwner();
+ if (!block->findAncestorOpInBlock(*conflictingWritingOp))
+ // conflictingWritingOp happens outside of the block. No
+ // conflict.
+ continue;
+ }
- // No conflict if the conflicting write and the last write are the same
- // use.
- if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
- continue;
+ // No conflict if the conflicting write and the last write are the same
+ // use.
+ if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
+ continue;
- // All requirements are met. Conflict found!
+ // All requirements are met. Conflict found!
- if (options.printConflicts)
- annotateConflict(uRead, uConflictingWrite, lastWrite);
+ if (options.printConflicts)
+ annotateConflict(uRead, uConflictingWrite, lastWrite);
- return true;
+ return true;
+ }
}
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 5bf26365caa6e..ec2d6d758e0e0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -938,7 +938,7 @@ struct FuncOpInterface
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::std_ext::
- registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
+ registerModuleBufferizationExternalModels(DialectRegistry ®istry) {
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
new file mode 100644
index 0000000000000..1f8cee5cf7adb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
@@ -0,0 +1,79 @@
+//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace std_ext {
+
+/// Bufferization of std.select. Just replace the operands.
+///
+/// The result of a select may alias either its true_value or its false_value
+/// operand (operands #1 and #2; operand #0 is the i1 condition), depending on
+/// the runtime condition. This is an example of an op whose OpResult may alias
+/// with one of multiple OpOperands.
+struct SelectOpInterface
+    : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
+                                                    SelectOp> {
+  /// A select only forwards one of its buffers; it does not read tensor data.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  /// A select does not write to memory.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// Each tensor operand may alias with the single op result.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return op->getOpResult(0) /*result*/;
+  }
+
+  /// The op result may alias with either tensor operand.
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    return {&op->getOpOperand(1) /*true_value*/,
+            &op->getOpOperand(2) /*false_value*/};
+  }
+
+  /// Replaces the tensor select with a select on the operand buffers. Fails if
+  /// a buffer for either operand cannot be obtained.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto selectOp = cast<SelectOp>(op);
+    // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
+    // Allocating such a copy can fail, so the FailureOr results must be
+    // checked before they are dereferenced.
+    // TODO: It would be more efficient to copy the result of the `select` op
+    // instead of its OpOperands. In the worst case, 2 copies are inserted at
+    // the moment (one for each tensor). When copying the op result, only one
+    // copy would be needed.
+    FailureOr<Value> trueBuffer =
+        state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
+    if (failed(trueBuffer))
+      return failure();
+    FailureOr<Value> falseBuffer =
+        state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
+    if (failed(falseBuffer))
+      return failure();
+    replaceOpWithNewBufferizedOp<SelectOp>(rewriter, op,
+                                           selectOp.getCondition(),
+                                           *trueBuffer, *falseBuffer);
+    return success();
+  }
+
+  /// The result is not equivalent to a single operand buffer: which buffer it
+  /// aliases depends on the runtime condition.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo,
+                                const BufferizationState &state) const {
+    return BufferRelation::None;
+  }
+};
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+/// Attaches the BufferizableOpInterface external model to std.select (the
+/// only Standard dialect op handled here so far).
+void mlir::linalg::comprehensive_bufferize::std_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 024499a01209c..118a9436609b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRSCFTransforms
+ MLIRStdBufferizableOpInterfaceImpl
MLIRPass
MLIRStandard
MLIRStandardOpsTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index a368f4e1653c7..2da7df6da834d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -51,6 +52,7 @@ struct LinalgComprehensiveModuleBufferize
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ std_ext::registerModuleBufferizationExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 929fc150f8946..9953900699367 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1710,3 +1710,84 @@ func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf
}
return %1: tensor<?xf32>
}
+
+// -----
+
+// %s may alias %t1 or %t2. %t1 is read after the write into %s, so the select
+// operand %t1 must bufferize out-of-place. %t2 is not read afterwards, so its
+// select operand can bufferize in-place.
+// CHECK-LABEL: func @write_after_select_read_one
+// CHECK-SAME: %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_read_one(
+ %t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %t2 : tensor<?xf32> {linalg.inplaceable = true},
+ %c : i1)
+ -> (f32, tensor<?xf32>)
+{
+ %cst = arith.constant 0.0 : f32
+ %idx = arith.constant 0 : index
+
+ // CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "true"]}
+ %s = std.select %c, %t1, %t2 : tensor<?xf32>
+ // CHECK: tensor.insert
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]}
+ %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+ // CHECK: tensor.extract
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+ %f = tensor.extract %t1[%idx] : tensor<?xf32>
+
+ return %f, %w : f32, tensor<?xf32>
+}
+
+// -----
+
+// Both %t1 and %t2 are read after the write into %s (which may alias either
+// of them), so both select operands must bufferize out-of-place.
+// CHECK-LABEL: func @write_after_select_read_both
+// CHECK-SAME: %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_read_both(
+ %t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %t2 : tensor<?xf32> {linalg.inplaceable = true},
+ %c : i1)
+ -> (f32, f32, tensor<?xf32>)
+{
+ %cst = arith.constant 0.0 : f32
+ %idx = arith.constant 0 : index
+
+ // CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "false"]}
+ %s = std.select %c, %t1, %t2 : tensor<?xf32>
+ // CHECK: tensor.insert
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]}
+ %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+ // CHECK: tensor.extract
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+ %f = tensor.extract %t1[%idx] : tensor<?xf32>
+ // CHECK: tensor.extract
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+ %f2 = tensor.extract %t2[%idx] : tensor<?xf32>
+
+ return %f, %f2, %w : f32, f32, tensor<?xf32>
+}
+
+// -----
+
+// Only the written value %w is read after the write, not %t1 or %t2, so there
+// is no conflict and both select operands bufferize in-place.
+// CHECK-LABEL: func @write_after_select_no_conflict
+// CHECK-SAME: %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_no_conflict(
+ %t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %t2 : tensor<?xf32> {linalg.inplaceable = true},
+ %c : i1)
+ -> (f32, tensor<?xf32>)
+{
+ %cst = arith.constant 0.0 : f32
+ %idx = arith.constant 0 : index
+
+ // CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]}
+ %s = std.select %c, %t1, %t2 : tensor<?xf32>
+ // CHECK: tensor.insert
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]}
+ %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+ // CHECK: tensor.extract
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+ %f = tensor.extract %w[%idx] : tensor<?xf32>
+
+ return %f, %w : f32, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index a9c2bcba865e6..753e099354cfa 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1227,7 +1227,7 @@ func @op_is_reading_but_following_ops_are_not(
// InitTensorOp elimination would produce SSA violations for the example below.
//===----------------------------------------------------------------------===//
-func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
-> tensor<?x1x6x8xf32> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
@@ -1243,3 +1243,54 @@ func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32
}
return %3 : tensor<?x1x6x8xf32>
}
+
+// -----
+
+// %t1 is written by tensor.insert, but its original value is still returned
+// through %s. The insert therefore bufferizes out-of-place (alloc + copy),
+// while the select keeps using the original %t1 buffer.
+// CHECK-LABEL: func @write_to_select_op_source
+// CHECK-SAME: %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+func @write_to_select_op_source(
+ %t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %t2 : tensor<?xf32> {linalg.inplaceable = true},
+ %c : i1)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ %cst = arith.constant 0.0 : f32
+ %idx = arith.constant 0 : index
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+ // CHECK: memref.store %{{.*}}, %[[alloc]]
+ %w = tensor.insert %cst into %t1[%idx] : tensor<?xf32>
+ // CHECK: %[[select:.*]] = select %{{.*}}, %[[t1]], %[[t2]]
+ %s = std.select %c, %t1, %t2 : tensor<?xf32>
+ // CHECK: return %[[select]], %[[alloc]]
+ return %s, %w : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// %t1 is read after the write into %s, which may alias %t1. The true_value
+// operand of the select is therefore copied into a new allocation that the
+// select uses instead. The store through the select result does not clobber
+// %t1, and the later load still reads the original %t1 buffer.
+// CHECK-LABEL: func @write_after_select_read_one
+// CHECK-SAME: %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+func @write_after_select_read_one(
+ %t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %t2 : tensor<?xf32> {linalg.inplaceable = true},
+ %c : i1)
+ -> (f32, tensor<?xf32>)
+{
+ %cst = arith.constant 0.0 : f32
+ %idx = arith.constant 0 : index
+
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+ // CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]]
+ %s = std.select %c, %t1, %t2 : tensor<?xf32>
+
+ // CHECK: memref.store %{{.*}}, %[[select]]
+ %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+
+ // CHECK: %[[f:.*]] = memref.load %[[t1]]
+ %f = tensor.extract %t1[%idx] : tensor<?xf32>
+
+ // CHECK: return %[[f]], %[[select]]
+ return %f, %w : f32, tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 4a6c9d67c0d6f..fad6ec91f7c5e 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -28,6 +28,7 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRPass
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
+ MLIRStdBufferizableOpInterfaceImpl
MLIRStandard
MLIRTensor
MLIRTensorBufferizableOpInterfaceImpl
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 5ae4efba9e1ac..506b99e5d655b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -55,13 +56,14 @@ struct TestComprehensiveFunctionBufferize
+/// Declares the dialects this test pass depends on and registers the
+/// BufferizableOpInterface external models for their ops.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
memref::MemRefDialect, tensor::TensorDialect,
- vector::VectorDialect, scf::SCFDialect,
+ vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
arith::ArithmeticDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0a997d17257af..c7a51a02b4c58 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6484,6 +6484,24 @@ cc_library(
],
)
+# BufferizableOpInterface external models for Standard dialect ops
+# (implemented in StdInterfaceImpl.cpp).
+cc_library(
+ name = "StdBufferizableOpInterfaceImpl",
+ srcs = [
+ "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
+ ],
+ hdrs = [
+ "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
+ ],
+ includes = ["include"],
+ deps = [
+ ":BufferizableOpInterface",
+ ":IR",
+ ":StandardOps",
+ ":Support",
+ "//llvm:Support",
+ ],
+)
+
cc_library(
name = "TensorBufferizableOpInterfaceImpl",
srcs = [
@@ -6743,6 +6761,7 @@ cc_library(
":SCFTransforms",
":StandardOps",
":StandardOpsTransforms",
+ ":StdBufferizableOpInterfaceImpl",
":Support",
":TensorBufferizableOpInterfaceImpl",
":TensorDialect",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 3d3d023610f33..62450c85030a5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -399,6 +399,7 @@ cc_library(
"//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",
+ "//mlir:StdBufferizableOpInterfaceImpl",
"//mlir:TensorBufferizableOpInterfaceImpl",
"//mlir:TensorDialect",
"//mlir:TransformUtils",
More information about the Mlir-commits
mailing list