[Mlir-commits] [mlir] f36e193 - [mlir][bufferization] Improve `bufferizesToElementwiseAccess`

Matthias Springer llvmlistbot at llvm.org
Tue Aug 22 00:00:42 PDT 2023


Author: Matthias Springer
Date: 2023-08-22T09:00:17+02:00
New Revision: f36e19347fca388b80a890ca1b1e785920536289

URL: https://github.com/llvm/llvm-project/commit/f36e19347fca388b80a890ca1b1e785920536289
DIFF: https://github.com/llvm/llvm-project/commit/f36e19347fca388b80a890ca1b1e785920536289.diff

LOG: [mlir][bufferization] Improve `bufferizesToElementwiseAccess`

The operands for which elementwise access is relevant can now be specified; all other operands are ignored. This is useful because only the two operands that participate in a potential RaW conflict matter. Furthermore, the two tensors no longer have to be equivalent to rule out a conflict due to elementwise access: equivalent tensor sets may be formed only after an inplace bufferization decision has been made, and equivalence is not actually required. The only important thing is that the two tensors have "equivalent" indexing into the same base buffer.

Differential Revision: https://reviews.llvm.org/D158428
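
In short, the interface method gains an `opOperands` parameter. A condensed
before/after view of the signature change (see the full diff below):

    // Before: the element-wise property had to hold for all tensor operands.
    bool bufferizesToElementwiseAccess(const AnalysisState &state);

    // After: only the listed operands must be accessed element-wise. A
    // conservative implementation may still always return "false".
    bool bufferizesToElementwiseAccess(const AnalysisState &state,
                                       ArrayRef<OpOperand *> opOperands);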

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index bd7a2d8b3f1eac..42aff77303e0d1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -93,10 +93,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Return `true` if the operation bufferizes to IR that performs only
-          element-wise accesses on all tensor operands. (All operands must have
-          the same shape.) The `bufferize` method must be implemented in such a
-          way that it is free of loop-carried dependences. I.e., all loads at a
-          position appear before all stores at the same position.
+          element-wise accesses on the specified tensor operands. (The operands
+          must have the same shape.) The `bufferize` method must be implemented
+          in such a way that it is free of loop-carried dependences. I.e., all
+          loads at a position appear before all stores at the same position.
 
          Example: Consider a hypothetical element-wise op, where the "ins"
           bufferize to a memory read and the "outs" bufferize to a memory write.
@@ -130,10 +130,15 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           equivalent tensors. (It is not possible, if %0 and %1 are merely
           aliasing. It is not necessary if %0 and %1 are not aliasing at all,
           because there would be no conflict anyway.)
+
+          Note: Tensor operands that are not included in `opOperands` can be
+          ignored. A conservative implementation of this interface method may
+          always return "false".
         }],
         /*retType=*/"bool",
         /*methodName=*/"bufferizesToElementwiseAccess",
-        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state,
+                      "ArrayRef<OpOperand *>":$opOperands),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // It is always safe to assume that the op is not element-wise.

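As the updated documentation notes, a conservative implementation may ignore
`opOperands` entirely and always return "false". A minimal sketch of such an
implementation for a hypothetical op `MyOp` (not part of this patch):

    struct MyOpInterface
        : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
      bool bufferizesToElementwiseAccess(Operation *op,
                                         const AnalysisState &state,
                                         ArrayRef<OpOperand *> opOperands) const {
        // Always safe: this merely disables the element-wise RaW-conflict
        // exemption for MyOp; it never produces incorrect bufferization.
        return false;
      }
    };
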
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 7f7ab9c21041c5..ba595bec0e6bdc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -446,6 +446,26 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
   }
 }
 
+/// Return 'true' if a tensor that is equivalent to `other` can be found in the
+/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
+/// place along that use-def chain, the two tensors may not materialize as
+/// equivalent buffers (but as separate allocations).
+///
+/// Note: This function also requires that the two tensors have equivalent
+/// indexing. I.e., the tensor types do not change along the use-def chain,
+/// apart from static <-> dynamic dim casts.
+static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
+                                                   Value start, Value other) {
+  TraversalConfig config;
+  config.followEquivalentOnly = true;
+  config.alwaysIncludeLeaves = false;
+  config.followSameTypeOrCastsOnly = true;
+  return !state
+              .findValueInReverseUseDefChain(
+                  start, [&](Value v) { return v == other; }, config)
+              .empty();
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -545,15 +565,19 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       // Two equivalent operands of the same op are not conflicting if the op
       // bufferizes to element-wise access. I.e., all loads at a position happen
       // before all stores to the same position.
-      if (conflictingWritingOp == readingOp &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              uConflictingWrite->get())) {
+      if (conflictingWritingOp == readingOp) {
         if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
-          if (bufferizableOp.bufferizesToElementwiseAccess(state)) {
-            LLVM_DEBUG(
-                llvm::dbgs()
-                << "  no conflict: op bufferizes to element-wise access\n");
-            continue;
+          if (bufferizableOp.bufferizesToElementwiseAccess(
+                  state, {uRead, uConflictingWrite})) {
+            if (hasEquivalentValueInReverseUseDefChain(
+                    state, uRead->get(), uConflictingWrite->get()) ||
+                hasEquivalentValueInReverseUseDefChain(
+                    state, uConflictingWrite->get(), uRead->get())) {
+              LLVM_DEBUG(
+                  llvm::dbgs()
+                  << "  no conflict: op bufferizes to element-wise access\n");
+              continue;
+            }
           }
         }
       }

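Note that the new check is symmetric: the read tensor may be derived from the
written tensor along its reverse use-def chain, or vice versa. Condensed from
the change above, with `read` and `write` standing for `uRead->get()` and
`uConflictingWrite->get()`:

    bool equivalentIndexing =
        hasEquivalentValueInReverseUseDefChain(state, read, write) ||
        hasEquivalentValueInReverseUseDefChain(state, write, read);
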
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 7264f831b7ef8a..0577441bdd28d2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -106,8 +106,8 @@ struct LinalgOpInterface
     return dpsOp.isDpsInit(&opOperand);
   }
 
-  bool bufferizesToElementwiseAccess(Operation *op,
-                                     const AnalysisState &state) const {
+  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
+                                     ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
     // All loops must be parallel.
@@ -119,10 +119,13 @@ struct LinalgOpInterface
     assert(linalgOp->getNumOperands() == indexingMaps.size() &&
            "unexpected number of indexing maps");
     for (auto [operand, map] :
-         llvm::zip(linalgOp->getOperands(), indexingMaps)) {
+         llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
       // Non-tensors do not participate in bufferization, so they can be
       // ignored.
-      if (!isa<RankedTensorType, MemRefType>(operand.getType()))
+      if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
+        continue;
+      // Only consider operands in `opOperands`.
+      if (llvm::find(opOperands, &operand) == opOperands.end())
         continue;
       // TODO: This could be generalized to other indexing maps. (All indexing
       // must be the same.)

diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
index 88d3e303887ab4..b4230314302f6e 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
@@ -57,3 +57,53 @@ func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
     } -> tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: @elementwise_no_conflict_4
+func.func @elementwise_no_conflict_4(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32xf32>) -> tensor<8x32x32x32xf32> {
+  %cst = arith.constant dense<3.000000e-02> : tensor<32x32x32xf32>
+  %cst_0 = arith.constant dense<6.000000e-01> : tensor<32xf32>
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %r = scf.forall (%arg2, %arg3) in (8, 32) shared_outs(%arg4 = %arg0) -> (tensor<8x32x32x32xf32>) {
+    // CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "none"]}
+    %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32xf32>
+
+    // CHECK: linalg.fill
+    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
+    %4 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+
+    // CHECK: linalg.batch_reduce_matmul
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
+    %5 = linalg.batch_reduce_matmul ins(%arg1, %cst : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%4 : tensor<32x32xf32>) -> tensor<32x32xf32>
+
+    // CHECK: linalg.generic
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
+    // %cst_0 has a non-identity indexing map, but %5 and %extracted_slice still
+    // bufferize to element-wise access.
+    %6 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_0 : tensor<32x32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+    ^bb0(%in: f32, %in_4: f32, %out: f32):
+      %8 = arith.addf %in, %in_4 : f32
+      linalg.yield %8 : f32
+    } -> tensor<32x32xf32>
+
+    // CHECK: linalg.generic
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+    // They are different SSA values, but %6 and %extracted_slice are equivalent.
+    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %8 = arith.maxf %in, %cst_1 : f32
+      linalg.yield %8 : f32
+    } -> tensor<32x32xf32>
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice
+      // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
+      tensor.parallel_insert_slice %7 into %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x32x32x32xf32>
+    }
+  }
+  return %r : tensor<8x32x32x32xf32>
+}