[Mlir-commits] [mlir] 1defec8 - [mlir][tensor][bufferize][NFC] Remove duplicate code

Matthias Springer llvmlistbot at llvm.org
Mon Jul 25 03:34:28 PDT 2022


Author: Matthias Springer
Date: 2022-07-25T12:34:16+02:00
New Revision: 1defec87306593e71057de7baf9bd8e2389c2419

URL: https://github.com/llvm/llvm-project/commit/1defec87306593e71057de7baf9bd8e2389c2419
DIFF: https://github.com/llvm/llvm-project/commit/1defec87306593e71057de7baf9bd8e2389c2419.diff

LOG: [mlir][tensor][bufferize][NFC] Remove duplicate code

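InsertSliceOp and ParallelInsertSliceOp are very similar and can share some of the bufferization analysis code. Note: the removed InsertSliceOp overload of `areEquivalentExtractSliceOps` tested `sti != sti`. That condition is always false, so its source/dest equivalence check never ran. The shared template instead tests `extractSliceOp != insertSliceOp`, as the ParallelInsertSliceOp overload already did, so the check is now active for tensor.insert_slice too. This may be a (likely benign) behavior change rather than strictly NFC.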

Differential Revision: https://reviews.llvm.org/D130465

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 44ac40abfb65..38044da2e4db 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -552,29 +552,30 @@ struct InsertOpInterface
 
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
 /// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
+template <typename OpTy> // e.g. InsertSliceOp, ParallelInsertSliceOp
 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st, InsertSliceOp sti) {
-  if (!st || !sti)
+                                         ExtractSliceOp extractSliceOp,
+                                         OpTy insertSliceOp) {
+  if (!extractSliceOp || !insertSliceOp)
     return false;
-  if (sti != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
+  if (extractSliceOp != insertSliceOp &&
+      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
+                                           insertSliceOp.getDest()))
     return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                  isEqualConstantIntOrValue))
     return false;
   return true;
 }
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
+template <typename OpTy> // e.g. InsertSliceOp, ParallelInsertSliceOp
 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      InsertSliceOp insertOp) {
+                                      OpTy insertSliceOp) {
   auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
         return true;
     return false;
   };
@@ -583,6 +584,83 @@ static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                       condition);
 }
 
+template <typename OpTy> // e.g. InsertSliceOp, ParallelInsertSliceOp
+static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
+                                              OpOperand *uConflictingWrite,
+                                              const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/OpTy pairs (OpTy is an
+  // InsertSliceOp-like op). If uRead is an operand of an OpTy...
+  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+    // %1 = linalg.fill %cst, %0 {inplace = [true]}
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace = [true]}
+
+    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                  insertSliceOp))
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+
+    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+  }
+
+  // If uConflictingWrite is an operand of an OpTy (InsertSliceOp-like op)...
+  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+    // %1 = linalg.fill %cst, %0 {inplace = [true]}
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace = [true]}
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // lastWrite         = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        state.areEquivalentBufferizedValues(uRead->get(),
+                                            insertSliceOp.getSource()) &&
+        hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                  insertSliceOp))
+      return true;
+
+  return false;
+}
+
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 struct InsertSliceOpInterface
@@ -613,77 +691,8 @@ struct InsertSliceOpInterface
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
+    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
+        op, uRead, uConflictingWrite, state); // Same rules as parallel variant.
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -805,36 +814,6 @@ struct ReshapeOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st,
-                                         ParallelInsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (st != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      ParallelInsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
 /// Analysis of ParallelInsertSliceOp.
 struct ParallelInsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<
@@ -978,83 +957,11 @@ struct ParallelInsertSliceOpInterface
     return success();
   }
 
-  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
-  // the code.
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp =
-            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
+    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
+        op, uRead, uConflictingWrite, state); // Same rules as InsertSliceOp.
   }
 };
 


        


More information about the Mlir-commits mailing list