[Mlir-commits] [mlir] [mlir][tensor][NFC] Simplify `SubsetInsertionOpInterface` implementation (PR #69999)

Matthias Springer llvmlistbot at llvm.org
Mon Oct 23 19:46:10 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/69999

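`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation. This patch replaces their two duplicated external models with a single `InsertSliceLikeOpInterface<OpTy>` template.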

>From ff9ab72f7fdbfee185671d98a91bf4a2a1fb8d85 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 24 Oct 2023 11:43:44 +0900
Subject: [PATCH] [mlir][tensor][NFC] Simplify `SubsetInsertionOpInterface`
 implementation

`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation.
---
 .../SubsetInsertionOpInterfaceImpl.cpp        | 118 ++++++------------
 1 file changed, 36 insertions(+), 82 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index f4f46d54d78e59f..85f7796096a42ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
 
 namespace {
 
-/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-/// equivalence of tensors.
 template <typename OpTy>
-bool isSubsetEquivalentToInsertSliceLikeOp(
-    OpTy insertSliceOp, Value candidate,
-    function_ref<bool(Value, Value)> equivalenceFn) {
-  // Look for a matching tensor.extract_slice op.
-  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
-  if (!extractSliceOp)
-    return false;
-  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
-    return false;
-  return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                    isEqualConstantIntOrValue);
-}
-
-template <typename OpTy>
-Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
-                                               OpTy insertSliceOp) {
-  auto extractOp = b.create<tensor::ExtractSliceOp>(
-      loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
-      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-      insertSliceOp.getMixedStrides());
-  return extractOp.getResult();
-}
-
-template <typename OpTy>
-SmallVector<Value>
-getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
-  SmallVector<Value> neededValues;
-  // Collect all values that are needed to construct the replacement op.
-  neededValues.append(insertSliceOp.getOffsets().begin(),
-                      insertSliceOp.getOffsets().end());
-  neededValues.append(insertSliceOp.getSizes().begin(),
-                      insertSliceOp.getSizes().end());
-  neededValues.append(insertSliceOp.getStrides().begin(),
-                      insertSliceOp.getStrides().end());
-  neededValues.push_back(insertSliceOp.getDest());
-  return neededValues;
-}
-
-struct InsertSliceOpInterface
-    : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                       tensor::InsertSliceOp> {
-  OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::InsertSliceOp>(op).getSourceMutable();
-  }
-
-  bool
-  isEquivalentSubset(Operation *op, Value candidate,
-                     function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
-  }
-
-  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
-                              Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::InsertSliceOp>(op));
-  }
-
-  SmallVector<Value>
-  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::InsertSliceOp>(op));
-  }
-};
-
-struct ParallelInsertSliceOpInterface
+struct InsertSliceLikeOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
-          ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+          InsertSliceLikeOpInterface<OpTy>, OpTy> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
+    return cast<OpTy>(op).getSourceMutable();
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
+    return cast<OpTy>(op).getDestMutable();
   }
 
+  /// Return "true" if `candidate` is produced by a tensor.extract_slice that
+  /// reads the same subset `op` writes: same offsets/sizes/strides, and its
+  /// source is equivalent (per `equivalenceFn`) to the destination of `op`.
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
+    auto insertSliceOp = cast<OpTy>(op);
+    // Look for a matching tensor.extract_slice op.
+    auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractSliceOp)
+      return false;
+    if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+      return false;
+    return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                      isEqualConstantIntOrValue);
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    auto extractOp = builder.create<tensor::ExtractSliceOp>(
+        loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return extractOp.getResult();
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    SmallVector<Value> neededValues;
+    // Collect the values the extract_slice built by buildSubsetExtraction uses.
+    neededValues.append(insertSliceOp.getOffsets().begin(),
+                        insertSliceOp.getOffsets().end());
+    neededValues.append(insertSliceOp.getSizes().begin(),
+                        insertSliceOp.getSizes().end());
+    neededValues.append(insertSliceOp.getStrides().begin(),
+                        insertSliceOp.getStrides().end());
+    neededValues.push_back(insertSliceOp.getDest());
+    return neededValues;
   }
 };
 
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
-    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
-    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+    InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
         *ctx);
+    ParallelInsertSliceOp::attachInterface<
+        InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
   });
 }



More information about the Mlir-commits mailing list