[Mlir-commits] [mlir] 2e3c62b - [mlir][tensor][NFC] Simplify `SubsetInsertionOpInterface` implementation (#69999)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 24 16:45:44 PDT 2023


Author: Matthias Springer
Date: 2023-10-25T08:45:39+09:00
New Revision: 2e3c62b15d49fbe11967de7719f4b9d70c5493e4

URL: https://github.com/llvm/llvm-project/commit/2e3c62b15d49fbe11967de7719f4b9d70c5493e4
DIFF: https://github.com/llvm/llvm-project/commit/2e3c62b15d49fbe11967de7719f4b9d70c5493e4.diff

LOG: [mlir][tensor][NFC] Simplify `SubsetInsertionOpInterface` implementation (#69999)

`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the
same implementation.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index f4f46d54d78e59f..85f7796096a42ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
 
 namespace {
 
-/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-/// equivalence of tensors.
 template <typename OpTy>
-bool isSubsetEquivalentToInsertSliceLikeOp(
-    OpTy insertSliceOp, Value candidate,
-    function_ref<bool(Value, Value)> equivalenceFn) {
-  // Look for a matching tensor.extract_slice op.
-  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
-  if (!extractSliceOp)
-    return false;
-  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
-    return false;
-  return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                    isEqualConstantIntOrValue);
-}
-
-template <typename OpTy>
-Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
-                                               OpTy insertSliceOp) {
-  auto extractOp = b.create<tensor::ExtractSliceOp>(
-      loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
-      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-      insertSliceOp.getMixedStrides());
-  return extractOp.getResult();
-}
-
-template <typename OpTy>
-SmallVector<Value>
-getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
-  SmallVector<Value> neededValues;
-  // Collect all values that are needed to construct the replacement op.
-  neededValues.append(insertSliceOp.getOffsets().begin(),
-                      insertSliceOp.getOffsets().end());
-  neededValues.append(insertSliceOp.getSizes().begin(),
-                      insertSliceOp.getSizes().end());
-  neededValues.append(insertSliceOp.getStrides().begin(),
-                      insertSliceOp.getStrides().end());
-  neededValues.push_back(insertSliceOp.getDest());
-  return neededValues;
-}
-
-struct InsertSliceOpInterface
-    : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                       tensor::InsertSliceOp> {
-  OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::InsertSliceOp>(op).getSourceMutable();
-  }
-
-  bool
-  isEquivalentSubset(Operation *op, Value candidate,
-                     function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
-  }
-
-  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
-                              Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::InsertSliceOp>(op));
-  }
-
-  SmallVector<Value>
-  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::InsertSliceOp>(op));
-  }
-};
-
-struct ParallelInsertSliceOpInterface
+struct InsertSliceLikeOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
-          ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+          InsertSliceLikeOpInterface<OpTy>, OpTy> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
+    return cast<OpTy>(op).getSourceMutable();
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
+    return cast<OpTy>(op).getDestMutable();
   }
 
+  /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+  /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+  /// equivalence of tensors.
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
+    auto insertSliceOp = cast<OpTy>(op);
+    // Look for a matching tensor.extract_slice op.
+    auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractSliceOp)
+      return false;
+    if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+      return false;
+    return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                      isEqualConstantIntOrValue);
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    auto extractOp = builder.create<tensor::ExtractSliceOp>(
+        loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return extractOp.getResult();
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    SmallVector<Value> neededValues;
+    // Collect all values that are needed to construct the replacement op.
+    neededValues.append(insertSliceOp.getOffsets().begin(),
+                        insertSliceOp.getOffsets().end());
+    neededValues.append(insertSliceOp.getSizes().begin(),
+                        insertSliceOp.getSizes().end());
+    neededValues.append(insertSliceOp.getStrides().begin(),
+                        insertSliceOp.getStrides().end());
+    neededValues.push_back(insertSliceOp.getDest());
+    return neededValues;
   }
 };
 
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
-    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
-    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+    InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
         *ctx);
+    ParallelInsertSliceOp::attachInterface<
+        InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
   });
 }


        


More information about the Mlir-commits mailing list