[Mlir-commits] [mlir] 7039bd2 - [mlir][MemRef][NFC] Migrate MemRef dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 12:47:33 PST 2023


Author: Markus Böck
Date: 2023-01-11T21:47:25+01:00
New Revision: 7039bd25093fb73bff0426f5987b75006f65889b

URL: https://github.com/llvm/llvm-project/commit/7039bd25093fb73bff0426f5987b75006f65889b
DIFF: https://github.com/llvm/llvm-project/commit/7039bd25093fb73bff0426f5987b75006f65889b.diff

LOG: [mlir][MemRef][NFC] Migrate MemRef dialect to the new fold API

See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Differential Revision: https://reviews.llvm.org/D141529

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
index 3be84ae654f6a..1e7d5816550cc 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -21,6 +21,7 @@ def MemRef_Dialect : Dialect {
   }];
   let dependentDialects = ["arith::ArithDialect"];
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // MEMREF_BASE

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78a606de706d9..d3a1ae1663c01 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -808,7 +808,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return false;
 }
 
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
@@ -883,7 +883,7 @@ void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
 }
 
-LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
   /// copy(memrefcast) -> copy
   bool folded = false;
@@ -902,7 +902,7 @@ LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dealloc(memrefcast) -> dealloc
   return foldMemRefCast(*this);
@@ -1056,9 +1056,9 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
   return *unusedDims;
 }
 
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
   if (!index)
     return {};
 
@@ -1322,7 +1322,7 @@ LogicalResult DmaStartOp::verify() {
   return success();
 }
 
-LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
   return foldMemRefCast(*this);
@@ -1332,7 +1332,7 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
 // DmaWaitOp
 // ---------------------------------------------------------------------------
 
-LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
   return foldMemRefCast(*this);
@@ -1433,7 +1433,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
 }
 
 LogicalResult
-ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
+ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   OpBuilder builder(*this);
 
@@ -1677,7 +1677,7 @@ LogicalResult LoadOp::verify() {
   return success();
 }
 
-OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
+OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1747,7 +1747,7 @@ LogicalResult PrefetchOp::verify() {
   return success();
 }
 
-LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   // prefetch(memrefcast) -> prefetch
   return foldMemRefCast(*this);
@@ -1757,7 +1757,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
 // RankOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
   // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
   auto shapedType = type.dyn_cast<ShapedType>();
@@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
   return success();
 }
 
-OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
   Value src = getSource();
   auto getPrevSrc = [&]() -> Value {
     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
@@ -2465,12 +2465,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2522,7 +2524,7 @@ LogicalResult StoreOp::verify() {
   return success();
 }
 
-LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
   return foldMemRefCast(*this, getValueToStore());
@@ -3101,7 +3103,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
 }
 
-OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   auto resultShapedType = getResult().getType().cast<ShapedType>();
   auto sourceShapedType = getSource().getType().cast<ShapedType>();
 
@@ -3217,7 +3219,7 @@ LogicalResult TransposeOp::verify() {
   return success();
 }
 
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransposeOp::fold(FoldAdaptor) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return {};
@@ -3393,7 +3395,7 @@ LogicalResult AtomicRMWOp::verify() {
   return success();
 }
 
-OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   /// atomicrmw(memrefcast) -> atomicrmw
   if (succeeded(foldMemRefCast(*this, getValue())))
     return getResult();


        


More information about the Mlir-commits mailing list