[Mlir-commits] [mlir] 6c654b5 - [mlir][linalg][bufferize] Support std.select bufferization

Matthias Springer llvmlistbot at llvm.org
Wed Jan 12 00:51:22 PST 2022


Author: Matthias Springer
Date: 2022-01-12T17:46:44+09:00
New Revision: 6c654b51983543dbd65736e992bb1cbe50b4336e

URL: https://github.com/llvm/llvm-project/commit/6c654b51983543dbd65736e992bb1cbe50b4336e
DIFF: https://github.com/llvm/llvm-project/commit/6c654b51983543dbd65736e992bb1cbe50b4336e.diff

LOG: [mlir][linalg][bufferize] Support std.select bufferization

This op is an example of how to deal with ops whose OpResult may alias with one of multiple OpOperands.

Differential Revision: https://reviews.llvm.org/D116868

Added: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
index d3009ad251785..10afeea50b694 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
@@ -1,4 +1,4 @@
-//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index d67c4125a665d..0bbe3fcb6e6b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -346,10 +346,10 @@ class BufferizationState {
   /// In the above example, Values with a star satisfy the condition. When
   /// starting the traversal from Value 1, the resulting SetVector is:
   /// { 2, 7, 8, 5 }
-  llvm::SetVector<Value> findValueInReverseUseDefChain(
+  SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition) const;
 
-  /// Find the Value of the last preceding write of a given Value.
+  /// Find the Values of the last preceding write of a given Value.
   ///
   /// Note: Unknown ops are handled conservatively and assumed to be writes.
   /// Furthermore, BlockArguments are also assumed to be writes. There is no
@@ -357,7 +357,7 @@ class BufferizationState {
   ///
   /// Note: When reaching an end of the reverse SSA use-def chain, that value
   /// is returned regardless of whether it is a memory write or not.
-  Value findLastPrecedingWrite(Value value) const;
+  SetVector<Value> findLastPrecedingWrite(Value value) const;
 
   /// Creates a memref allocation.
   FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index 7d155512ef87d..f881a964b905e 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -31,7 +31,7 @@ runComprehensiveBufferize(ModuleOp moduleOp,
 
 namespace std_ext {
 
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+void registerModuleBufferizationExternalModels(DialectRegistry &registry);
 
 } // namespace std_ext
 } // namespace comprehensive_bufferize

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
new file mode 100644
index 0000000000000..ae3b3db23e648
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
@@ -0,0 +1,29 @@
+//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace std_ext {
+
+/// Register the BufferizableOpInterface external models for ops of the
+/// standard dialect in `registry`.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index fe5fc26c3d2ba..aa9c7bd6806c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -305,26 +305,18 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
   return result;
 }
 
-// Find the Value of the last preceding write of a given Value.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    findLastPrecedingWrite(Value value) const {
-  SetVector<Value> result =
-      findValueInReverseUseDefChain(value, [&](Value value) {
-        Operation *op = value.getDefiningOp();
-        if (!op)
-          return true;
-        auto bufferizableOp = options.dynCastBufferizableOp(op);
-        if (!bufferizableOp)
-          return true;
-        return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
-      });
-
-  // To simplify the analysis, `scf.if` ops are considered memory writes. There
-  // are currently no other ops where one OpResult may alias with multiple
-  // OpOperands. Therefore, this function should return exactly one result at
-  // the moment.
-  assert(result.size() == 1 && "expected exactly one result");
-  return result.front();
+// Find the Values of the last preceding write of a given Value.
+llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
+    BufferizationState::findLastPrecedingWrite(Value value) const {
+  return findValueInReverseUseDefChain(value, [&](Value value) {
+    Operation *op = value.getDefiningOp();
+    if (!op)
+      return true;
+    auto bufferizableOp = options.dynCastBufferizableOp(op);
+    if (!bufferizableOp)
+      return true;
+    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
+  });
 }
 
 mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
@@ -404,15 +396,19 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
       createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
   if (failed(resultBuffer))
     return failure();
-  // Do not copy if the last preceding write of `operand` is an op that does
+  // Do not copy if the last preceding writes of `operand` are ops that do
   // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
   // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
   // use-def chain, it returns that value, regardless of whether it is a
   // memory write or not.
-  Value lastWrite = findLastPrecedingWrite(operand);
-  if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-    if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
-      return resultBuffer;
+  SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
+  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
+        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
+          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
+                                              *this);
+        return true;
+      }))
+    return resultBuffer;
   // Do not copy if the copied data is never read.
   OpResult aliasingOpResult = getAliasingOpResult(opOperand);
   if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 9cabc9ac05244..99c28fe124f76 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
   LinalgInterfaceImpl.cpp
   ModuleBufferization.cpp
   SCFInterfaceImpl.cpp
+  StdInterfaceImpl.cpp
   TensorInterfaceImpl.cpp
   VectorInterfaceImpl.cpp
 )
@@ -61,6 +62,14 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
   MLIRSCF
 )
 
+add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
+  StdInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizableOpInterface
+  MLIRStandard
+)
+
 add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
   TensorInterfaceImpl.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 41d4e4fe62e9d..ffebdfc665061 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -219,7 +219,8 @@ static bool hasReadAfterWriteInterference(
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
 
-    // Find most recent write of uRead by following the SSA use-def chain. E.g.:
+    // Find most recent writes of uRead by following the SSA use-def chain.
+    // E.g.:
     //
     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
@@ -228,7 +229,7 @@ static bool hasReadAfterWriteInterference(
     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
     // is %0. Note that operations that create an alias but do not write (such
     // as ExtractSliceOp) are skipped.
-    Value lastWrite = state.findLastPrecedingWrite(uRead->get());
+    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
@@ -265,35 +266,38 @@ static bool hasReadAfterWriteInterference(
       if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
         continue;
 
-      // No conflict if the conflicting write happens before the last
-      // write.
-      if (Operation *writingOp = lastWrite.getDefiningOp()) {
-        if (happensBefore(conflictingWritingOp, writingOp, domInfo))
-          // conflictingWritingOp happens before writingOp. No conflict.
-          continue;
-        // No conflict if conflictingWritingOp is contained in writingOp.
-        if (writingOp->isProperAncestor(conflictingWritingOp))
-          continue;
-      } else {
-        auto bbArg = lastWrite.cast<BlockArgument>();
-        Block *block = bbArg.getOwner();
-        if (!block->findAncestorOpInBlock(*conflictingWritingOp))
-          // conflictingWritingOp happens outside of the block. No
-          // conflict.
-          continue;
-      }
+      // Check all possible last writes.
+      for (Value lastWrite : lastWrites) {
+        // No conflict if the conflicting write happens before the last
+        // write.
+        if (Operation *writingOp = lastWrite.getDefiningOp()) {
+          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
+            // conflictingWritingOp happens before writingOp. No conflict.
+            continue;
+          // No conflict if conflictingWritingOp is contained in writingOp.
+          if (writingOp->isProperAncestor(conflictingWritingOp))
+            continue;
+        } else {
+          auto bbArg = lastWrite.cast<BlockArgument>();
+          Block *block = bbArg.getOwner();
+          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
+            // conflictingWritingOp happens outside of the block. No
+            // conflict.
+            continue;
+        }
 
-      // No conflict if the conflicting write and the last write are the same
-      // use.
-      if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
-        continue;
+        // No conflict if the conflicting write and the last write are the same
+        // use.
+        if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
+          continue;
 
-      // All requirements are met. Conflict found!
+        // All requirements are met. Conflict found!
 
-      if (options.printConflicts)
-        annotateConflict(uRead, uConflictingWrite, lastWrite);
+        if (options.printConflicts)
+          annotateConflict(uRead, uConflictingWrite, lastWrite);
 
-      return true;
+        return true;
+      }
     }
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 5bf26365caa6e..ec2d6d758e0e0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -938,7 +938,7 @@ struct FuncOpInterface
 } // namespace mlir
 
 void mlir::linalg::comprehensive_bufferize::std_ext::
-    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+    registerModuleBufferizationExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
   registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
new file mode 100644
index 0000000000000..1f8cee5cf7adb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
@@ -0,0 +1,98 @@
+//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace std_ext {
+
+/// Bufferization of std.select. The op is replaced with a std.select on the
+/// buffers of its `true_value` and `false_value` operands.
+///
+/// Which of the two operands the result aliases with is only known at
+/// runtime, so both are reported as aliasing OpOperands.
+struct SelectOpInterface
+    : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
+                                                    SelectOp> {
+  /// The select itself does not read from the memory of its operands.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  /// The select itself does not write to the memory of its operands.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// Every queried OpOperand aliases with the single OpResult. (Presumably
+  /// only the tensor operands are queried; the i1 condition cannot alias.)
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return op->getOpResult(0) /*result*/;
+  }
+
+  /// The result may alias with either the true value or the false value.
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    return {&op->getOpOperand(1) /*true_value*/,
+            &op->getOpOperand(2) /*false_value*/};
+  }
+
+  /// Replace the op with a std.select on the operand buffers. Fails if a
+  /// buffer for either operand cannot be obtained.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto selectOp = cast<SelectOp>(op);
+    // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
+    // TODO: It would be more efficient to copy the result of the `select` op
+    // instead of its OpOperands. In the worst case, 2 copies are inserted at
+    // the moment (one for each tensor). When copying the op result, only one
+    // copy would be needed.
+    // `getBuffer` returns a failure if the allocation of an out-of-place copy
+    // fails. Propagate it instead of dereferencing an empty FailureOr.
+    FailureOr<Value> trueBuffer =
+        state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
+    if (failed(trueBuffer))
+      return failure();
+    FailureOr<Value> falseBuffer =
+        state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
+    if (failed(falseBuffer))
+      return failure();
+    replaceOpWithNewBufferizedOp<SelectOp>(
+        rewriter, op, selectOp.getCondition(), *trueBuffer, *falseBuffer);
+    return success();
+  }
+
+  /// The result buffer is, in general, not equivalent to either operand
+  /// buffer (it depends on the runtime condition).
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo,
+                                const BufferizationState &state) const {
+    return BufferRelation::None;
+  }
+};
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::std_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 024499a01209c..118a9436609b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRSCF
   MLIRSCFBufferizableOpInterfaceImpl
   MLIRSCFTransforms
+  MLIRStdBufferizableOpInterfaceImpl
   MLIRPass
   MLIRStandard
   MLIRStandardOpsTransforms

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index a368f4e1653c7..2da7df6da834d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -51,6 +52,7 @@ struct LinalgComprehensiveModuleBufferize
     bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    std_ext::registerModuleBufferizationExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 929fc150f8946..9953900699367 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1710,3 +1710,84 @@ func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf
   }
   return %1: tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @write_after_select_read_one
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_read_one(
+    %t1 : tensor<?xf32> {linalg.inplaceable = true},
+    %t2 : tensor<?xf32> {linalg.inplaceable = true},
+    %c : i1)
+  -> (f32, tensor<?xf32>)
+{
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+
+  //      CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "false", "true"]}
+  %s = std.select %c, %t1, %t2 : tensor<?xf32>
+  //      CHECK: tensor.insert
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+  //      CHECK: tensor.extract
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["true", "none"]}
+  %f = tensor.extract %t1[%idx] : tensor<?xf32>
+
+  return %f, %w : f32, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @write_after_select_read_both
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_read_both(
+    %t1 : tensor<?xf32> {linalg.inplaceable = true},
+    %t2 : tensor<?xf32> {linalg.inplaceable = true},
+    %c : i1)
+  -> (f32, f32, tensor<?xf32>)
+{
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+
+  //      CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "false", "false"]}
+  %s = std.select %c, %t1, %t2 : tensor<?xf32>
+  //      CHECK: tensor.insert
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+  //      CHECK: tensor.extract
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["true", "none"]}
+  %f = tensor.extract %t1[%idx] : tensor<?xf32>
+  //      CHECK: tensor.extract
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["true", "none"]}
+  %f2 = tensor.extract %t2[%idx] : tensor<?xf32>
+
+  return %f, %f2, %w : f32, f32, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @write_after_select_no_conflict
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32> {{.*}}, %[[t2:.*]]: tensor<?xf32>
+func @write_after_select_no_conflict(
+    %t1 : tensor<?xf32> {linalg.inplaceable = true},
+    %t2 : tensor<?xf32> {linalg.inplaceable = true},
+    %c : i1)
+  -> (f32, tensor<?xf32>)
+{
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+
+  //      CHECK: select %{{.*}}, %[[t1]], %[[t2]]
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "true", "true"]}
+  %s = std.select %c, %t1, %t2 : tensor<?xf32>
+  //      CHECK: tensor.insert
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+  //      CHECK: tensor.extract
+  // CHECK-SAME:   {__inplace_operands_attr__ = ["true", "none"]}
+  %f = tensor.extract %w[%idx] : tensor<?xf32>
+
+  return %f, %w : f32, tensor<?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index a9c2bcba865e6..753e099354cfa 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1227,7 +1227,7 @@ func @op_is_reading_but_following_ops_are_not(
 // InitTensorOp elimination would produce SSA violations for the example below.
 //===----------------------------------------------------------------------===//
 
-func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>) 
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
     -> tensor<?x1x6x8xf32> {
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
@@ -1243,3 +1243,54 @@ func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32
   }
   return %3 : tensor<?x1x6x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @write_to_select_op_source
+//  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+func @write_to_select_op_source(
+    %t1 : tensor<?xf32> {linalg.inplaceable = true},
+    %t2 : tensor<?xf32> {linalg.inplaceable = true},
+    %c : i1)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: memref.store %{{.*}}, %[[alloc]]
+  %w = tensor.insert %cst into %t1[%idx] : tensor<?xf32>
+  // CHECK: %[[select:.*]] = select %{{.*}}, %[[t1]], %[[t2]]
+  %s = std.select %c, %t1, %t2 : tensor<?xf32>
+  // CHECK: return %[[select]], %[[alloc]]
+  return %s, %w : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @write_after_select_read_one
+//  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>
+func @write_after_select_read_one(
+    %t1 : tensor<?xf32> {linalg.inplaceable = true},
+    %t2 : tensor<?xf32> {linalg.inplaceable = true},
+    %c : i1)
+  -> (f32, tensor<?xf32>)
+{
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]]
+  %s = std.select %c, %t1, %t2 : tensor<?xf32>
+
+  // CHECK: memref.store %{{.*}}, %[[select]]
+  %w = tensor.insert %cst into %s[%idx] : tensor<?xf32>
+
+  // CHECK: %[[f:.*]] = memref.load %[[t1]]
+  %f = tensor.extract %t1[%idx] : tensor<?xf32>
+
+  // CHECK: return %[[f]], %[[select]]
+  return %f, %w : f32, tensor<?xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 4a6c9d67c0d6f..fad6ec91f7c5e 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -28,6 +28,7 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRPass
   MLIRSCF
   MLIRSCFBufferizableOpInterfaceImpl
+  MLIRStdBufferizableOpInterfaceImpl
   MLIRStandard
   MLIRTensor
   MLIRTensorBufferizableOpInterfaceImpl

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 5ae4efba9e1ac..506b99e5d655b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -55,13 +56,14 @@ struct TestComprehensiveFunctionBufferize
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
                     memref::MemRefDialect, tensor::TensorDialect,
-                    vector::VectorDialect, scf::SCFDialect,
+                    vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
                     arith::ArithmeticDialect, AffineDialect>();
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
     bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    std_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0a997d17257af..c7a51a02b4c58 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6484,6 +6484,24 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "StdBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BufferizableOpInterface",
+        ":IR",
+        ":StandardOps",
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "TensorBufferizableOpInterfaceImpl",
     srcs = [
@@ -6743,6 +6761,7 @@ cc_library(
         ":SCFTransforms",
         ":StandardOps",
         ":StandardOpsTransforms",
+        ":StdBufferizableOpInterfaceImpl",
         ":Support",
         ":TensorBufferizableOpInterfaceImpl",
         ":TensorDialect",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 3d3d023610f33..62450c85030a5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -399,6 +399,7 @@ cc_library(
         "//mlir:SCFDialect",
         "//mlir:SCFTransforms",
         "//mlir:StandardOps",
+        "//mlir:StdBufferizableOpInterfaceImpl",
         "//mlir:TensorBufferizableOpInterfaceImpl",
         "//mlir:TensorDialect",
         "//mlir:TransformUtils",


        


More information about the Mlir-commits mailing list