[Mlir-commits] [mlir] 7039bd2 - [mlir][MemRef][NFC] Migrate MemRef dialect to the new fold API
Markus Böck
llvmlistbot at llvm.org
Wed Jan 11 12:47:33 PST 2023
Author: Markus Böck
Date: 2023-01-11T21:47:25+01:00
New Revision: 7039bd25093fb73bff0426f5987b75006f65889b
URL: https://github.com/llvm/llvm-project/commit/7039bd25093fb73bff0426f5987b75006f65889b
DIFF: https://github.com/llvm/llvm-project/commit/7039bd25093fb73bff0426f5987b75006f65889b.diff
LOG: [mlir][MemRef][NFC] Migrate MemRef dialect to the new fold API
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context
Differential Revision: https://reviews.llvm.org/D141529
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
index 3be84ae654f6a..1e7d5816550cc 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -21,6 +21,7 @@ def MemRef_Dialect : Dialect {
}];
let dependentDialects = ["arith::ArithDialect"];
let hasConstantMaterializer = 1;
+ let useFoldAPI = kEmitFoldAdaptorFolder;
}
#endif // MEMREF_BASE
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78a606de706d9..d3a1ae1663c01 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -808,7 +808,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
@@ -883,7 +883,7 @@ void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}
-LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult CopyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// copy(memrefcast) -> copy
bool folded = false;
@@ -902,7 +902,7 @@ LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
// DeallocOp
//===----------------------------------------------------------------------===//
-LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dealloc(memrefcast) -> dealloc
return foldMemRefCast(*this);
@@ -1056,9 +1056,9 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
return *unusedDims;
}
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+ auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
if (!index)
return {};
@@ -1322,7 +1322,7 @@ LogicalResult DmaStartOp::verify() {
return success();
}
-LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
return foldMemRefCast(*this);
@@ -1332,7 +1332,7 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
// DmaWaitOp
// ---------------------------------------------------------------------------
-LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
return foldMemRefCast(*this);
@@ -1433,7 +1433,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
}
LogicalResult
-ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
+ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
OpBuilder builder(*this);
@@ -1677,7 +1677,7 @@ LogicalResult LoadOp::verify() {
return success();
}
-OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
+OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
/// load(memrefcast) -> load
if (succeeded(foldMemRefCast(*this)))
return getResult();
@@ -1747,7 +1747,7 @@ LogicalResult PrefetchOp::verify() {
return success();
}
-LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// prefetch(memrefcast) -> prefetch
return foldMemRefCast(*this);
@@ -1757,7 +1757,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
// RankOp
//===----------------------------------------------------------------------===//
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
auto shapedType = type.dyn_cast<ShapedType>();
@@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
return success();
}
-OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
Value src = getSource();
auto getPrevSrc = [&]() -> Value {
// reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
@@ -2465,12 +2465,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseShapeOpMemRefCastFolder>(context);
}
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+ adaptor.getOperands());
}
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+ adaptor.getOperands());
}
//===----------------------------------------------------------------------===//
@@ -2522,7 +2524,7 @@ LogicalResult StoreOp::verify() {
return success();
}
-LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
return foldMemRefCast(*this, getValueToStore());
@@ -3101,7 +3103,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
-OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto resultShapedType = getResult().getType().cast<ShapedType>();
auto sourceShapedType = getSource().getType().cast<ShapedType>();
@@ -3217,7 +3219,7 @@ LogicalResult TransposeOp::verify() {
return success();
}
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransposeOp::fold(FoldAdaptor) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
@@ -3393,7 +3395,7 @@ LogicalResult AtomicRMWOp::verify() {
return success();
}
-OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
/// atomicrmw(memrefcast) -> atomicrmw
if (succeeded(foldMemRefCast(*this, getValue())))
return getResult();
More information about the Mlir-commits
mailing list