[Mlir-commits] [mlir] 436ce71 - [mlir][Vector][NFC] Migrate Vector dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Thu Jan 12 00:52:22 PST 2023


Author: Markus Böck
Date: 2023-01-12T09:52:14+01:00
New Revision: 436ce713e34b6e3b20b800b1db5e651dc9cea14e

URL: https://github.com/llvm/llvm-project/commit/436ce713e34b6e3b20b800b1db5e651dc9cea14e
DIFF: https://github.com/llvm/llvm-project/commit/436ce713e34b6e3b20b800b1db5e651dc9cea14e.diff

LOG: [mlir][Vector][NFC] Migrate Vector dialect to the new fold API

See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Differential Revision: https://reviews.llvm.org/D141526

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8a5d1025167be..df5b7c597f6ab 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -30,6 +30,7 @@ def Vector_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
   let dependentDialects = ["arith::ArithDialect"];
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for Vector dialect ops.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3885d57de30b7..470b57b887735 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -308,7 +308,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
         builder.getI64ArrayAttr(reductionDims));
 }
 
-OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
@@ -1035,13 +1035,13 @@ LogicalResult vector::ExtractElementOp::verify() {
   return success();
 }
 
-OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   // Skip the 0-D vector here now.
-  if (operands.size() < 2)
+  if (!adaptor.getPosition())
     return {};
 
-  Attribute src = operands[0];
-  Attribute pos = operands[1];
+  Attribute src = adaptor.getVector();
+  Attribute pos = adaptor.getPosition();
 
   // Fold extractelement (splat X) -> X.
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
@@ -1587,7 +1587,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
   return Value();
 }
 
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractOp::fold(FoldAdaptor) {
   if (getPosition().empty())
     return getVector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
@@ -1918,15 +1918,15 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getSourceType() == getVectorType())
     return getSource();
-  if (!operands[0])
+  if (!adaptor.getSource())
     return {};
   auto vectorType = getVectorType();
-  if (operands[0].isa<IntegerAttr, FloatAttr>())
-    return DenseElementsAttr::get(vectorType, operands[0]);
-  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+  if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+    return DenseElementsAttr::get(vectorType, adaptor.getSource());
+  if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
 }
@@ -2034,7 +2034,7 @@ static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
                       });
 }
 
-OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   VectorType v1Type = getV1VectorType();
   // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
   // but must be a canonicalization into a vector.broadcast.
@@ -2051,7 +2051,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
                        getV2VectorType().getDimSize(0)))
     return getV2();
 
-  Attribute lhs = operands.front(), rhs = operands.back();
+  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
   if (!lhs || !rhs)
     return {};
 
@@ -2154,14 +2154,14 @@ LogicalResult InsertElementOp::verify() {
   return success();
 }
 
-OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
   // Skip the 0-D vector here.
-  if (operands.size() < 3)
+  if (!adaptor.getPosition())
     return {};
 
-  Attribute src = operands[0];
-  Attribute dst = operands[1];
-  Attribute pos = operands[2];
+  Attribute src = adaptor.getSource();
+  Attribute dst = adaptor.getDest();
+  Attribute pos = adaptor.getPosition();
   if (!src || !dst || !pos)
     return {};
 
@@ -2335,7 +2335,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Eliminates insert operations that produce values identical to their source
 // value. This happens when the source and destination vectors have identical
 // sizes.
-OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (getPosition().empty())
     return getSource();
   return {};
@@ -2621,7 +2621,7 @@ void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
               InsertStridedSliceConstantFolder>(context);
 }
 
-OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getSourceVectorType() == getDestVectorType())
     return getSource();
   return {};
@@ -2929,7 +2929,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
   return failure();
 }
 
-OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getVectorType() == getResult().getType())
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
@@ -3564,7 +3564,7 @@ static Value foldRAW(TransferReadOp readOp) {
   return {};
 }
 
-OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransferReadOp::fold(FoldAdaptor) {
   if (Value vec = foldRAW(*this))
     return vec;
   /// transfer_read(memrefcast) -> transfer_read
@@ -4039,9 +4039,9 @@ static LogicalResult foldWAR(TransferWriteOp write,
   return success();
 }
 
-LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
+LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
-  if (succeeded(foldReadInitWrite(*this, operands, results)))
+  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
     return success();
   if (succeeded(foldWAR(*this, results)))
     return success();
@@ -4346,7 +4346,7 @@ LogicalResult vector::LoadOp::verify() {
   return success();
 }
 
-OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult LoadOp::fold(FoldAdaptor) {
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
@@ -4379,7 +4379,7 @@ LogicalResult vector::StoreOp::verify() {
   return success();
 }
 
-LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }
@@ -4432,7 +4432,7 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MaskedLoadFolder>(context);
 }
 
-OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
@@ -4483,7 +4483,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MaskedStoreFolder>(context);
 }
 
-LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }
@@ -4754,7 +4754,7 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   // No-op shape cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
@@ -4888,7 +4888,7 @@ LogicalResult BitCastOp::verify() {
   return success();
 }
 
-OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
   // Nop cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
@@ -4902,7 +4902,7 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
     return getResult();
   }
 
-  Attribute sourceConstant = operands.front();
+  Attribute sourceConstant = adaptor.getSource();
   if (!sourceConstant)
     return {};
 
@@ -4995,9 +4995,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
   result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
 }
 
-OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
+  if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
     if (attr.isSplat())
       return attr.reshape(getResultType());
 
@@ -5495,8 +5495,8 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
 // SplatOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 


        


More information about the Mlir-commits mailing list