[Mlir-commits] [mlir] 7df7612 - [mlir][NFC] Migrate rest of the dialects to the new fold API

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 12:47:35 PST 2023


Author: Markus Böck
Date: 2023-01-11T21:47:25+01:00
New Revision: 7df761217cd7d0026ffff23c4bdac846bb60f185

URL: https://github.com/llvm/llvm-project/commit/7df761217cd7d0026ffff23c4bdac846bb60f185
DIFF: https://github.com/llvm/llvm-project/commit/7df761217cd7d0026ffff23c4bdac846bb60f185.diff

LOG: [mlir][NFC] Migrate rest of the dialects to the new fold API

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
    mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td
    mlir/include/mlir/Dialect/Func/IR/FuncOps.td
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/IR/BuiltinDialect.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Dialect/Func/IR/FuncOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTraits.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d19e4d233c952..4a383ec3a6bf3 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -24,6 +24,7 @@ def Affine_Dialect : Dialect {
   let cppNamespace = "mlir";
   let hasConstantMaterializer = 1;
   let dependentDialects = ["arith::ArithDialect"];
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for Affine dialect ops.

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
index 280bfdb380177..ecb8ec918cf72 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
@@ -69,6 +69,7 @@ def Bufferization_Dialect : Dialect {
         kEscapeAttrName = "bufferization.escape";
   }];
   let hasOperationAttrVerify = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // BUFFERIZATION_BASE

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
index 31135fc8c8ce7..20bc712b003b3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -22,6 +22,7 @@ def Complex_Dialect : Dialect {
   let dependentDialects = ["arith::ArithDialect"];
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // COMPLEX_BASE

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td
index 375dbcbce1d03..19b2a3281086e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td
@@ -31,6 +31,7 @@ def EmitC_Dialect : Dialect {
   let hasConstantMaterializer = 1;
   let useDefaultTypePrinterParser = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE

diff  --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 4922689fbef60..58a33a124042b 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -23,6 +23,7 @@ def Func_Dialect : Dialect {
   let cppNamespace = "::mlir::func";
   let dependentDialects = ["cf::ControlFlowDialect"];
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for Func dialect ops.

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index d5b36e495493d..7b02ed7c707ee 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -56,6 +56,7 @@ def GPU_Dialect : Dialect {
   let dependentDialects = ["arith::ArithDialect"];
   let useDefaultAttributePrinterParser = 1;
   let useDefaultTypePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 def GPU_AsyncToken : DialectType<

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 80f82650a10f8..37365f648ccc4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -31,6 +31,7 @@ def LLVM_Dialect : Dialect {
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
   let hasOperationAttrVerify = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 
   let extraClassDeclaration = [{
     /// Name of the data layout attributes.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 1ee9d2ac241f6..706b2ffe0cee7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -46,6 +46,7 @@ def Linalg_Dialect : Dialect {
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
   let extraClassDeclaration = [{
     /// Attribute name used to to memoize indexing maps for named ops.
     constexpr const static ::llvm::StringLiteral

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
index 74d69c046f06b..4a2e11d299d89 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
+++ b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
@@ -20,6 +20,7 @@ def Quantization_Dialect : Dialect {
   let cppNamespace = "::mlir::quant";
 
   let useDefaultTypePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d5a150557163c..a610562b281ae 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -25,6 +25,7 @@ def SCF_Dialect : Dialect {
   let name = "scf";
   let cppNamespace = "::mlir::scf";
   let dependentDialects = ["arith::ArithDialect"];
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for SCF dialect ops.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
index 2506da4042bd3..12066f3096301 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
@@ -83,6 +83,7 @@ def SparseTensor_Dialect : Dialect {
 
   let useDefaultAttributePrinterParser = 1;
   let useDefaultTypePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // SPARSETENSOR_BASE

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index a7bb75e9c7e70..05b094cefd992 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -23,6 +23,8 @@ def Transform_Dialect : Dialect {
     "::mlir::pdl_interp::PDLInterpDialect",
   ];
 
+  let useFoldAPI = kEmitFoldAdaptorFolder;
+
   let extraClassDeclaration = [{
       /// Returns the named PDL constraint functions available in the dialect
       /// as a map from their name to the function.

diff  --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index a3d0f0cbfd9a4..e9d00968c4599 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -34,6 +34,8 @@ def Builtin_Dialect : Dialect {
 
   public:
   }];
+
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // BUILTIN_BASE

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index be58118d32474..f4c851c1823c6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -562,7 +562,7 @@ bool AffineApplyOp::isValidSymbol(Region *region) {
   });
 }
 
-OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
   auto map = getAffineMap();
 
   // Fold dims and symbols to existing values.
@@ -574,7 +574,7 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
 
   // Otherwise, default to folding the map.
   SmallVector<Attribute, 1> result;
-  if (failed(map.constantFold(operands, result)))
+  if (failed(map.constantFold(adaptor.getMapOperands(), result)))
     return {};
   return result[0];
 }
@@ -2135,7 +2135,7 @@ static bool hasTrivialZeroTripCount(AffineForOp op) {
   return tripCount && *tripCount == 0;
 }
 
-LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
+LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
   bool folded = succeeded(foldLoopBounds(*this));
   folded |= succeeded(canonicalizeLoopBounds(*this));
@@ -2723,7 +2723,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
+LogicalResult AffineIfOp::fold(FoldAdaptor,
                                SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
@@ -2858,7 +2858,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyAffineOp<AffineLoadOp>>(context);
 }
 
-OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
+OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
@@ -2975,7 +2975,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyAffineOp<AffineStoreOp>>(context);
 }
 
-LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
   return memref::foldMemRefCast(*this, getValueToStore());
@@ -3282,8 +3282,8 @@ struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
 //
 
-OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
-  return foldMinMaxOp(*this, operands);
+OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
+  return foldMinMaxOp(*this, adaptor.getOperands());
 }
 
 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -3310,8 +3310,8 @@ void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
 //
 
-OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
-  return foldMinMaxOp(*this, operands);
+OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
+  return foldMinMaxOp(*this, adaptor.getOperands());
 }
 
 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -3431,7 +3431,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
 }
 
-LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                      SmallVectorImpl<OpFoldResult> &results) {
   /// prefetch(memrefcast) -> prefetch
   return memref::foldMemRefCast(*this);
@@ -3705,7 +3705,7 @@ static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
   return success();
 }
 
-LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
+LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                      SmallVectorImpl<OpFoldResult> &results) {
   return canonicalizeLoopBounds(*this);
 }

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f021a593f2415..59263f699032b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -458,7 +458,7 @@ void CloneOp::getEffects(
                        SideEffects::DefaultResource::get());
 }
 
-OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
   return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
 }
 
@@ -560,7 +560,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ToTensorOp::fold(FoldAdaptor) {
   if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
     // Approximate alias analysis by conservatively folding only when no there
     // is no interleaved operation.
@@ -596,7 +596,7 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
   if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
     if (memrefToTensor.getMemref().getType() == getType())
       return memrefToTensor.getMemref();

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 44ebf69a65199..c71e2e08b6ffb 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -17,8 +17,7 @@ using namespace mlir::complex;
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
   return getValue();
 }
 
@@ -68,8 +67,7 @@ LogicalResult ConstantOp::verify() {
 // CreateOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary op takes two operands");
+OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
   // Fold complex.create(complex.re(op), complex.im(op)).
   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
@@ -85,9 +83,8 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
 // ImOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-  ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
+OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
+  ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[1];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -99,9 +96,8 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
 // ReOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-  ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
+OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
+  ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[0];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -113,9 +109,7 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
 // AddOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary op takes 2 operands");
-
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   // complex.add(complex.sub(a, b), b) -> a
   if (auto sub = getLhs().getDefiningOp<SubOp>())
     if (getRhs() == sub.getRhs())
@@ -142,9 +136,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
 // SubOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary op takes 2 operands");
-
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   // complex.sub(complex.add(a, b), b) -> a
   if (auto add = getLhs().getDefiningOp<AddOp>())
     if (getRhs() == add.getRhs())
@@ -166,9 +158,7 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
 // NegOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-
+OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
   // complex.neg(complex.neg(a)) -> a
   if (auto negOp = getOperand().getDefiningOp<NegOp>())
     return negOp.getOperand();
@@ -180,9 +170,7 @@ OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
 // LogOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-
+OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
   // complex.log(complex.exp(a)) -> a
   if (auto expOp = getOperand().getDefiningOp<ExpOp>())
     return expOp.getOperand();
@@ -194,9 +182,7 @@ OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
 // ExpOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-
+OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
   // complex.exp(complex.log(a)) -> a
   if (auto logOp = getOperand().getDefiningOp<LogOp>())
     return logOp.getOperand();
@@ -208,9 +194,7 @@ OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
 // ConjOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ConjOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary op takes 1 operand");
-
+OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
   // complex.conj(complex.conj(a)) -> a
   if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
     return conjOp.getOperand();

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 91046c7f70f21..75d92229c287b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -129,8 +129,7 @@ LogicalResult emitc::ConstantOp::verify() {
   return success();
 }
 
-OpFoldResult emitc::ConstantOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
+OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) {
   return getValue();
 }
 

diff  --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index 7c26e6ae4fe3a..de96e46c5376f 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -201,8 +201,7 @@ LogicalResult ConstantOp::verify() {
   return success();
 }
 
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }
 

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8ef5483fb7ca9..e55e9a60fa67d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1286,12 +1286,12 @@ LogicalResult SubgroupMmaComputeOp::verify() {
   return success();
 }
 
-LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }
 
-LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e00c6887844ef..bb47aeb040662 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1441,7 +1441,7 @@ static Type getInsertExtractValueElementType(Type llvmType,
   return llvmType;
 }
 
-OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
   auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
   OpFoldResult result = {};
   while (insertValueOp) {
@@ -2275,7 +2275,7 @@ LogicalResult LLVM::ConstantOp::verify() {
 }
 
 // Constant op constant-folds to its value.
-OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
+OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
 
 //===----------------------------------------------------------------------===//
 // Utility functions for parsing atomic ops
@@ -2513,7 +2513,7 @@ LogicalResult FenceOp::verify() {
 // Folder for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
   // bitcast(x : T0, T0) -> x
   if (getArg().getType() == getType())
     return getArg();
@@ -2528,7 +2528,7 @@ OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
 // Folder for LLVM::AddrSpaceCastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
   // addrcast(x : T0, T0) -> x
   if (getArg().getType() == getType())
     return getArg();
@@ -2543,9 +2543,9 @@ OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
 // Folder for LLVM::GEPOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
   GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
-                                                 operands.drop_front());
+                                                 adaptor.getDynamicIndices());
 
   // gep %x:T, 0 -> %x
   if (getBase().getType() == getType() && indices.size() == 1)

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 308d003f4b9e3..7f5bccecd3707 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -980,8 +980,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<EraseIdentityGenericOp>(context);
 }
 
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
+LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
 

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index fcb97ae603429..c9a6bbc9ceeea 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -36,7 +36,7 @@ void QuantizationDialect::initialize() {
   addBytecodeInterface(this);
 }
 
-OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
   // Matches x -> [scast -> scast] -> y, replacing the second scast with the
   // value of x if the casts invert each other.
   auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 25f953c22ac36..b98f9df6ad987 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1598,7 +1598,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
   regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
 }
 
-LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
+LogicalResult IfOp::fold(FoldAdaptor adaptor,
                          SmallVectorImpl<OpFoldResult> &results) {
   // if (!c) then A() else B() -> if c then B() else A()
   if (getElseRegion().empty())

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4d2c4fb22b58d..59a4b1360b473 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -467,7 +467,7 @@ LogicalResult ConvertOp::verify() {
   return emitError("unexpected type in convert");
 }
 
-OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
   Type dstType = getType();
   // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
   // convert for codegen to remove. This is because we use trivial
@@ -531,7 +531,7 @@ static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
 }
 
-OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
   StorageSpecifierKind kind = getSpecifierKind();
   std::optional<APInt> dim = getDim();
   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 20052bd3ac15a..0b3391ea448e2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -463,7 +463,7 @@ void transform::MergeHandlesOp::getEffects(
   // manipulation.
 }
 
-OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
   if (getDeduplicate() || getHandles().size() != 1)
     return {};
 

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index b66346ffa39c0..0668459d622b1 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -190,7 +190,7 @@ LogicalResult ModuleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
+UnrealizedConversionCastOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &foldResults) {
   OperandRange operands = getInputs();
   ResultRange results = getOutputs();

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 0f48bfcf6b229..7f0ad3b69c5b9 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1099,32 +1099,31 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns(
   results.add<TestRemoveOpWithInnerOps>(context);
 }
 
-OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
   return getOperand();
 }
 
-OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
   return getValue();
 }
 
 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
-    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+    FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
   for (Value input : this->getOperands()) {
     results.push_back(input);
   }
   return success();
 }
 
-OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1);
-  if (operands.front()) {
-    (*this)->setAttr("attr", operands.front());
+OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
+  if (adaptor.getOp()) {
+    (*this)->setAttr("attr", adaptor.getOp());
     return getResult();
   }
   return {};
 }
 
-OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
   return getOperand();
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 9ec12749bfcc3..9dc4203915392 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -23,6 +23,7 @@ def Test_Dialect : Dialect {
   let hasNonDefaultDestructor = 1;
   let useDefaultTypePrinterParser = 0;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
   let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7d661368c0e82..e9816ebcc13e8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1290,7 +1290,7 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
   let results = (outs Variadic<I1>);
   let hasFolder = 1;
   let extraClassDefinition = [{
-    ::mlir::LogicalResult $cppClass::fold(ArrayRef<Attribute> operands,
+    ::mlir::LogicalResult $cppClass::fold(FoldAdaptor adaptor,
         SmallVectorImpl<OpFoldResult> &results) {
       return success();
     }
@@ -1315,11 +1315,7 @@ def TestOpFoldWithFoldAdaptor
     $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
   }];
 
-  let hasFolder = 0;
-
-  let extraClassDeclaration = [{
-    ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
-  }];
+  let hasFolder = 1;
 }
 
 // An op that always fold itself.

diff  --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index 0ccffc7a47a02..d9b67ef95ace8 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -18,13 +18,13 @@ using namespace test;
 //===----------------------------------------------------------------------===//
 
 OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
-    ArrayRef<Attribute> operands) {
+    FoldAdaptor adaptor) {
   // This failure should cause the trait fold to run instead.
   return {};
 }
 
 OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
-    ArrayRef<Attribute> operands) {
+    FoldAdaptor adaptor) {
   auto argumentOp = getOperand();
   // The success case should cause the trait fold to be supressed.
   return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 9b7816c952278..d92c36795b2bd 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -654,7 +654,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 // Parameters:
 // {0}: Class name
 const char structuredOpFoldersFormat[] = R"FMT(
-LogicalResult {0}::fold(ArrayRef<Attribute>,
+LogicalResult {0}::fold(FoldAdaptor,
                         SmallVectorImpl<OpFoldResult> &) {{
   return memref::foldMemRefCast(*this);
 }


        


More information about the Mlir-commits mailing list