[Mlir-commits] [mlir] 00767cb - [mlir] Delete dup code and use unified methods.
Hanhan Wang
llvmlistbot at llvm.org
Fri Oct 21 16:51:58 PDT 2022
Author: Hanhan Wang
Date: 2022-10-21T16:51:44-07:00
New Revision: 00767cb45225e142ce5b5cf6312f3e689d53bb82
URL: https://github.com/llvm/llvm-project/commit/00767cb45225e142ce5b5cf6312f3e689d53bb82
DIFF: https://github.com/llvm/llvm-project/commit/00767cb45225e142ce5b5cf6312f3e689d53bb82.diff
LOG: [mlir] Delete dup code and use unified methods.
The foldMemRefCast method is defined in the memref namespace, and the
foldTensorCast method is defined in the tensor namespace. This revision
deletes the duplicated file-local copies and uses the unified methods instead.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D136379
Added:
Modified:
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1d05601f4fbea..01cc8b31f7c21 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1325,26 +1325,6 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// Common canonicalization pattern support logic
-//===----------------------------------------------------------------------===//
-
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto cast = operand.get().getDefiningOp<memref::CastOp>();
- if (cast && operand.get() != ignore &&
- !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
- operand.set(cast.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//
@@ -1511,7 +1491,7 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -1589,7 +1569,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -2821,7 +2801,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
/// load(memrefcast) -> load
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
// Fold load from a global constant memref.
@@ -2939,7 +2919,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this, getValueToStore());
+ return memref::foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
@@ -3392,7 +3372,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// prefetch(memrefcast) -> prefetch
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bfdcedfa8d771..4f5cfb65c924a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1266,29 +1266,14 @@ LogicalResult SubgroupMmaComputeOp::verify() {
return success();
}
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
- if (cast) {
- operand.set(cast.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 82e5024cf58bf..5d6dd379b2e40 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -254,23 +254,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
// Region is elided.
}
-/// This is a common class used for patterns of the form
-/// ```
-/// someop(memrefcast(%src)) -> someop(%src)
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<memref::CastOp>();
- if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
@@ -1290,7 +1273,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 230c58522a022..c25bfc674cc95 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3085,35 +3085,6 @@ LogicalResult TransferReadOp::verify() {
[&](Twine t) { return emitOpError(t); });
}
-/// This is a common class used for patterns of the form
-/// ```
-/// someop(memrefcast) -> someop
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<memref::CastOp>();
- if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
-static LogicalResult foldTensorCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
- if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
@@ -3198,9 +3169,9 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
/// transfer_read(memrefcast) -> transfer_read
if (succeeded(foldTransferInBoundsAttribute(*this)))
return getResult();
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
- if (succeeded(foldTensorCast(*this)))
+ if (succeeded(tensor::foldTensorCast(*this)))
return getResult();
return OpFoldResult();
}
@@ -3648,7 +3619,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
return success();
if (succeeded(foldTransferInBoundsAttribute(*this)))
return success();
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
@@ -3948,7 +3919,7 @@ LogicalResult vector::LoadOp::verify() {
}
OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
}
@@ -3982,7 +3953,7 @@ LogicalResult vector::StoreOp::verify() {
LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -4034,7 +4005,7 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
}
@@ -4086,7 +4057,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index d8e10efaaba04..d531d8b45160a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -655,7 +655,7 @@ ArrayAttr {0}::getIndexingMaps() {{
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
More information about the Mlir-commits
mailing list