[Mlir-commits] [mlir] 00767cb - [mlir] Delete dup code and use unified methods.

Hanhan Wang llvmlistbot at llvm.org
Fri Oct 21 16:51:58 PDT 2022


Author: Hanhan Wang
Date: 2022-10-21T16:51:44-07:00
New Revision: 00767cb45225e142ce5b5cf6312f3e689d53bb82

URL: https://github.com/llvm/llvm-project/commit/00767cb45225e142ce5b5cf6312f3e689d53bb82
DIFF: https://github.com/llvm/llvm-project/commit/00767cb45225e142ce5b5cf6312f3e689d53bb82.diff

LOG: [mlir] Delete dup code and use unified methods.

The foldMemRefCast method is defined in the memref namespace and the
foldTensorCast method is defined in the tensor namespace. This revision
deletes the duplicated local copies and uses the unified methods instead.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D136379
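
For context, a minimal sketch (not part of this commit) of the pattern the
diff below adopts. MyOp is a hypothetical op, and the sketch assumes the
surrounding file already includes the MemRef and Tensor dialect headers that
declare memref::foldMemRefCast and tensor::foldTensorCast:

    // Hypothetical op reusing the unified helpers instead of defining
    // a local static foldMemRefCast/foldTensorCast copy.
    LogicalResult MyOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult> &results) {
      // myop(memref.cast(%src)) -> myop(%src)
      if (succeeded(memref::foldMemRefCast(*this)))
        return success();
      // myop(tensor.cast(%src)) -> myop(%src)
      return tensor::foldTensorCast(*this);
    }

As in the deleted helpers below, the unified methods rewrite the op's
operands in place and return success() only when at least one operand was
folded.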

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1d05601f4fbea..01cc8b31f7c21 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1325,26 +1325,6 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyAffineOp<AffineApplyOp>>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// Common canonicalization pattern support logic
-//===----------------------------------------------------------------------===//
-
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = operand.get().getDefiningOp<memref::CastOp>();
-    if (cast && operand.get() != ignore &&
-        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
-      operand.set(cast.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 //===----------------------------------------------------------------------===//
 // AffineDmaStartOp
 //===----------------------------------------------------------------------===//
@@ -1511,7 +1491,7 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                      SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1589,7 +1569,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2821,7 +2801,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
   /// load(memrefcast) -> load
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
 
   // Fold load from a global constant memref.
@@ -2939,7 +2919,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this, getValueToStore());
+  return memref::foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3392,7 +3372,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                      SmallVectorImpl<OpFoldResult> &results) {
   /// prefetch(memrefcast) -> prefetch
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bfdcedfa8d771..4f5cfb65c924a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1266,29 +1266,14 @@ LogicalResult SubgroupMmaComputeOp::verify() {
   return success();
 }
 
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
-    if (cast) {
-      operand.set(cast.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 82e5024cf58bf..5d6dd379b2e40 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -254,23 +254,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
   // Region is elided.
 }
 
-/// This is a common class used for patterns of the form
-/// ```
-///    someop(memrefcast(%src)) -> someop(%src)
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
-    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -1290,7 +1273,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 230c58522a022..c25bfc674cc95 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3085,35 +3085,6 @@ LogicalResult TransferReadOp::verify() {
                               [&](Twine t) { return emitOpError(t); });
 }
 
-/// This is a common class used for patterns of the form
-/// ```
-///    someop(memrefcast) -> someop
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
-    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
-static LogicalResult foldTensorCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
-    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -3198,9 +3169,9 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return getResult();
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
-  if (succeeded(foldTensorCast(*this)))
+  if (succeeded(tensor::foldTensorCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -3648,7 +3619,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
     return success();
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return success();
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
@@ -3948,7 +3919,7 @@ LogicalResult vector::LoadOp::verify() {
 }
 
 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -3982,7 +3953,7 @@ LogicalResult vector::StoreOp::verify() {
 
 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4034,7 +4005,7 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -4086,7 +4057,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                   SmallVectorImpl<OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index d8e10efaaba04..d531d8b45160a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -655,7 +655,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 const char structuredOpFoldersFormat[] = R"FMT(
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{

