[Mlir-commits] [mlir] [mlir][tensor][NFC] Simplify `SubsetInsertionOpInterface` implementation (PR #69999)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 23 19:47:22 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
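`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation: the two separate external models are merged into a single class template, `InsertSliceLikeOpInterface<OpTy>`, which is attached to both ops.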
---
Full diff: https://github.com/llvm/llvm-project/pull/69999.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (+36-82)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index f4f46d54d78e59f..85f7796096a42ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
namespace {
-/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-/// equivalence of tensors.
template <typename OpTy>
-bool isSubsetEquivalentToInsertSliceLikeOp(
- OpTy insertSliceOp, Value candidate,
- function_ref<bool(Value, Value)> equivalenceFn) {
- // Look for a matching tensor.extract_slice op.
- auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
- if (!extractSliceOp)
- return false;
- if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
- return false;
- return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
- isEqualConstantIntOrValue);
-}
-
-template <typename OpTy>
-Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
- OpTy insertSliceOp) {
- auto extractOp = b.create<tensor::ExtractSliceOp>(
- loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides());
- return extractOp.getResult();
-}
-
-template <typename OpTy>
-SmallVector<Value>
-getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
- SmallVector<Value> neededValues;
- // Collect all values that are needed to construct the replacement op.
- neededValues.append(insertSliceOp.getOffsets().begin(),
- insertSliceOp.getOffsets().end());
- neededValues.append(insertSliceOp.getSizes().begin(),
- insertSliceOp.getSizes().end());
- neededValues.append(insertSliceOp.getStrides().begin(),
- insertSliceOp.getStrides().end());
- neededValues.push_back(insertSliceOp.getDest());
- return neededValues;
-}
-
-struct InsertSliceOpInterface
- : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
- tensor::InsertSliceOp> {
- OpOperand &getSourceOperand(Operation *op) const {
- return cast<tensor::InsertSliceOp>(op).getSourceMutable();
- }
-
- bool
- isEquivalentSubset(Operation *op, Value candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
- equivalenceFn);
- }
-
- Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
- Location loc) const {
- return buildSubsetExtractionOfInsertSliceLikeOp(
- builder, loc, cast<tensor::InsertSliceOp>(op));
- }
-
- SmallVector<Value>
- getValuesNeededToBuildSubsetExtraction(Operation *op) const {
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
- cast<tensor::InsertSliceOp>(op));
- }
-};
-
-struct ParallelInsertSliceOpInterface
+struct InsertSliceLikeOpInterface
: public SubsetInsertionOpInterface::ExternalModel<
- ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+ InsertSliceLikeOpInterface<OpTy>, OpTy> {
OpOperand &getSourceOperand(Operation *op) const {
- return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
+ return cast<OpTy>(op).getSourceMutable();
}
OpOperand &getDestinationOperand(Operation *op) const {
- return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
+ return cast<OpTy>(op).getDestMutable();
}
+ /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+ /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+ /// equivalence of tensors.
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
- return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
- equivalenceFn);
+ auto insertSliceOp = cast<OpTy>(op);
+ // Look for a matching tensor.extract_slice op.
+ auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return false;
+ if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+ return false;
+ return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+ isEqualConstantIntOrValue);
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
- return buildSubsetExtractionOfInsertSliceLikeOp(
- builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+ auto insertSliceOp = cast<OpTy>(op);
+ auto extractOp = builder.create<tensor::ExtractSliceOp>(
+ loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides());
+ return extractOp.getResult();
}
SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
- cast<tensor::ParallelInsertSliceOp>(op));
+ auto insertSliceOp = cast<OpTy>(op);
+ SmallVector<Value> neededValues;
+ // Collect all values that are needed to construct the replacement op.
+ neededValues.append(insertSliceOp.getOffsets().begin(),
+ insertSliceOp.getOffsets().end());
+ neededValues.append(insertSliceOp.getSizes().begin(),
+ insertSliceOp.getSizes().end());
+ neededValues.append(insertSliceOp.getStrides().begin(),
+ insertSliceOp.getStrides().end());
+ neededValues.push_back(insertSliceOp.getDest());
+ return neededValues;
}
};
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
- InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
- ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+ InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
*ctx);
+ ParallelInsertSliceOp::attachInterface<
+ InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
});
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/69999
More information about the Mlir-commits
mailing list