[Mlir-commits] [mlir] 436ce71 - [mlir][Vector][NFC] Migrate Vector dialect to the new fold API
Markus Böck
llvmlistbot at llvm.org
Thu Jan 12 00:52:22 PST 2023
Author: Markus Böck
Date: 2023-01-12T09:52:14+01:00
New Revision: 436ce713e34b6e3b20b800b1db5e651dc9cea14e
URL: https://github.com/llvm/llvm-project/commit/436ce713e34b6e3b20b800b1db5e651dc9cea14e
DIFF: https://github.com/llvm/llvm-project/commit/436ce713e34b6e3b20b800b1db5e651dc9cea14e.diff
LOG: [mlir][Vector][NFC] Migrate Vector dialect to the new fold API
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context
Differential Revision: https://reviews.llvm.org/D141526
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8a5d1025167be..df5b7c597f6ab 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -30,6 +30,7 @@ def Vector_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
+ let useFoldAPI = kEmitFoldAdaptorFolder;
}
// Base class for Vector dialect ops.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3885d57de30b7..470b57b887735 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -308,7 +308,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
builder.getI64ArrayAttr(reductionDims));
}
-OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
// Single parallel dim, this is a noop.
if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
return getSource();
@@ -1035,13 +1035,13 @@ LogicalResult vector::ExtractElementOp::verify() {
return success();
}
-OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// Skip the 0-D vector here now.
- if (operands.size() < 2)
+ if (!adaptor.getPosition())
return {};
- Attribute src = operands[0];
- Attribute pos = operands[1];
+ Attribute src = adaptor.getVector();
+ Attribute pos = adaptor.getPosition();
// Fold extractelement (splat X) -> X.
if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
@@ -1587,7 +1587,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
return Value();
}
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractOp::fold(FoldAdaptor) {
if (getPosition().empty())
return getVector();
if (succeeded(foldExtractOpFromExtractChain(*this)))
@@ -1918,15 +1918,15 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getVectorType())
return getSource();
- if (!operands[0])
+ if (!adaptor.getSource())
return {};
auto vectorType = getVectorType();
- if (operands[0].isa<IntegerAttr, FloatAttr>())
- return DenseElementsAttr::get(vectorType, operands[0]);
- if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+ if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+ return DenseElementsAttr::get(vectorType, adaptor.getSource());
+ if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
}
@@ -2034,7 +2034,7 @@ static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
});
}
-OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
VectorType v1Type = getV1VectorType();
// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
// but must be a canonicalization into a vector.broadcast.
@@ -2051,7 +2051,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
getV2VectorType().getDimSize(0)))
return getV2();
- Attribute lhs = operands.front(), rhs = operands.back();
+ Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
if (!lhs || !rhs)
return {};
@@ -2154,14 +2154,14 @@ LogicalResult InsertElementOp::verify() {
return success();
}
-OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// Skip the 0-D vector here.
- if (operands.size() < 3)
+ if (!adaptor.getPosition())
return {};
- Attribute src = operands[0];
- Attribute dst = operands[1];
- Attribute pos = operands[2];
+ Attribute src = adaptor.getSource();
+ Attribute dst = adaptor.getDest();
+ Attribute pos = adaptor.getPosition();
if (!src || !dst || !pos)
return {};
@@ -2335,7 +2335,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
-OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (getPosition().empty())
return getSource();
return {};
@@ -2621,7 +2621,7 @@ void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
InsertStridedSliceConstantFolder>(context);
}
-OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
if (getSourceVectorType() == getDestVectorType())
return getSource();
return {};
@@ -2929,7 +2929,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
return failure();
}
-OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
if (getVectorType() == getResult().getType())
return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
@@ -3564,7 +3564,7 @@ static Value foldRAW(TransferReadOp readOp) {
return {};
}
-OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransferReadOp::fold(FoldAdaptor) {
if (Value vec = foldRAW(*this))
return vec;
/// transfer_read(memrefcast) -> transfer_read
@@ -4039,9 +4039,9 @@ static LogicalResult foldWAR(TransferWriteOp write,
return success();
}
-LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
+LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- if (succeeded(foldReadInitWrite(*this, operands, results)))
+ if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
return success();
if (succeeded(foldWAR(*this, results)))
return success();
@@ -4346,7 +4346,7 @@ LogicalResult vector::LoadOp::verify() {
return success();
}
-OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult LoadOp::fold(FoldAdaptor) {
if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
@@ -4379,7 +4379,7 @@ LogicalResult vector::StoreOp::verify() {
return success();
}
-LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return memref::foldMemRefCast(*this);
}
@@ -4432,7 +4432,7 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<MaskedLoadFolder>(context);
}
-OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
@@ -4483,7 +4483,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<MaskedStoreFolder>(context);
}
-LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return memref::foldMemRefCast(*this);
}
@@ -4754,7 +4754,7 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
if (getSource().getType() == getResult().getType())
return getSource();
@@ -4888,7 +4888,7 @@ LogicalResult BitCastOp::verify() {
return success();
}
-OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
// Nop cast.
if (getSource().getType() == getResult().getType())
return getSource();
@@ -4902,7 +4902,7 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
return getResult();
}
- Attribute sourceConstant = operands.front();
+ Attribute sourceConstant = adaptor.getSource();
if (!sourceConstant)
return {};
@@ -4995,9 +4995,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
}
-OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
- if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
+ if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
if (attr.isSplat())
return attr.reshape(getResultType());
@@ -5495,8 +5495,8 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
// SplatOp
//===----------------------------------------------------------------------===//
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+ auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};
More information about the Mlir-commits mailing list