[Mlir-commits] [mlir] 75044e9 - [mlir] Flipping vector dialect to both prefixed form.

Jacques Pienaar llvmlistbot at llvm.org
Tue Feb 15 09:48:56 PST 2022


Author: Jacques Pienaar
Date: 2022-02-15T09:48:51-08:00
New Revision: 75044e9b4f20d025295dbd56284435937cfb4de5

URL: https://github.com/llvm/llvm-project/commit/75044e9b4f20d025295dbd56284435937cfb4de5
DIFF: https://github.com/llvm/llvm-project/commit/75044e9b4f20d025295dbd56284435937cfb4de5.diff

LOG: [mlir] Flipping vector dialect to both prefixed form.

Following
https://discourse.llvm.org/t/psa-ods-generated-accessors-will-change-to-have-a-get-prefix-update-you-apis/4476

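Mostly mechanical. The changes avoid function name conflicts with the now-generated prefixed accessors: hand-written `get*AttrName()` helpers are renamed to `get*AttrStrName()`, the hand-written `ContractionOp::getIndexingMaps()` is removed, and calls to the free `getIndexingMapsAttrName()`/`getIteratorTypesAttrName()` helpers are qualified with `::mlir::`.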

Differential Revision: https://reviews.llvm.org/D119607

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 009df114ec2c2..66d4a69593358 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -22,6 +22,7 @@ def Vector_Dialect : Dialect {
   let cppNamespace = "::mlir::vector";
   let hasConstantMaterializer = 1;
   let dependentDialects = ["arith::ArithmeticDialect"];
+  let emitAccessorPrefix = kEmitAccessorPrefix_Both;
 }
 
 // Base class for Vector dialect ops.
@@ -63,6 +64,15 @@ def Vector_CombiningKindAttr : DialectAttr<
           "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
 }
 
+def Vector_AffineMapArrayAttr : TypedArrayAttrBase<AffineMapAttr,
+                                      "AffineMap array attribute"> {
+  let returnType = [{ ::llvm::SmallVector<::mlir::AffineMap, 4> }];
+  let convertFromStorage = [{
+    llvm::to_vector<4>($_self.getAsValueRange<::mlir::AffineMapAttr>());
+  }];
+  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
+}
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -75,7 +85,8 @@ def Vector_ContractionOp :
     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
-               AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types,
+               Vector_AffineMapArrayAttr:$indexing_maps,
+	       ArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
                                  "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyType)> {
@@ -223,7 +234,6 @@ def Vector_ContractionOp :
     }
     Type getResultType() { return getResult().getType(); }
     ArrayRef<StringRef> getTraitAttrNames();
-    SmallVector<AffineMap, 4> getIndexingMaps();
     static unsigned getAccOperandIndex() { return 2; }
 
     // Returns the bounds of each dimension in the iteration space spanned
@@ -240,7 +250,7 @@ def Vector_ContractionOp :
     std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
     std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
 
-    static constexpr StringRef getKindAttrName() { return "kind"; }
+    static constexpr StringRef getKindAttrStrName() { return "kind"; }
 
     static CombiningKind getDefaultKind() {
       return CombiningKind::ADD;
@@ -327,8 +337,8 @@ def Vector_MultiDimReductionOp :
                    "CombiningKind":$kind)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getKindAttrName() { return "kind"; }
-    static StringRef getReductionDimsAttrName() { return "reduction_dims"; }
+    static StringRef getKindAttrStrName() { return "kind"; }
+    static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; }
 
     VectorType getSourceVectorType() {
       return source().getType().cast<VectorType>();
@@ -474,7 +484,7 @@ def Vector_ShuffleOp :
   ];
   let hasFolder = 1;
   let extraClassDeclaration = [{
-    static StringRef getMaskAttrName() { return "mask"; }
+    static StringRef getMaskAttrStrName() { return "mask"; }
     VectorType getV1VectorType() {
       return v1().getType().cast<VectorType>();
     }
@@ -561,7 +571,7 @@ def Vector_ExtractOp :
     OpBuilder<(ins "Value":$source, "ValueRange":$position)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getPositionAttrName() { return "position"; }
+    static StringRef getPositionAttrStrName() { return "position"; }
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
     }
@@ -754,7 +764,7 @@ def Vector_InsertOp :
     OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getPositionAttrName() { return "position"; }
+    static StringRef getPositionAttrStrName() { return "position"; }
     Type getSourceType() { return source().getType(); }
     VectorType getDestVectorType() {
       return dest().getType().cast<VectorType>();
@@ -873,15 +883,15 @@ def Vector_InsertStridedSliceOp :
       "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getOffsetsAttrName() { return "offsets"; }
-    static StringRef getStridesAttrName() { return "strides"; }
+    static StringRef getOffsetsAttrStrName() { return "offsets"; }
+    static StringRef getStridesAttrStrName() { return "strides"; }
     VectorType getSourceVectorType() {
       return source().getType().cast<VectorType>();
     }
     VectorType getDestVectorType() {
       return dest().getType().cast<VectorType>();
     }
-    bool hasNonUnitStrides() { 
+    bool hasNonUnitStrides() {
       return llvm::any_of(strides(), [](Attribute attr) {
         return attr.cast<IntegerAttr>().getInt() != 1;
       });
@@ -970,7 +980,7 @@ def Vector_OuterProductOp :
     VectorType getVectorType() {
       return getResult().getType().cast<VectorType>();
     }
-    static constexpr StringRef getKindAttrName() {
+    static constexpr StringRef getKindAttrStrName() {
       return "kind";
     }
     static CombiningKind getDefaultKind() {
@@ -1089,11 +1099,11 @@ def Vector_ReshapeOp :
 
     void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
 
-    static StringRef getFixedVectorSizesAttrName() {
+    static StringRef getFixedVectorSizesAttrStrName() {
       return "fixed_vector_sizes";
     }
-    static StringRef getInputShapeAttrName() { return "input_shape"; }
-    static StringRef getOutputShapeAttrName() { return "output_shape"; }
+    static StringRef getInputShapeAttrStrName() { return "input_shape"; }
+    static StringRef getOutputShapeAttrStrName() { return "output_shape"; }
   }];
 
   let assemblyFormat = [{
@@ -1140,12 +1150,12 @@ def Vector_ExtractStridedSliceOp :
       "ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getOffsetsAttrName() { return "offsets"; }
-    static StringRef getSizesAttrName() { return "sizes"; }
-    static StringRef getStridesAttrName() { return "strides"; }
+    static StringRef getOffsetsAttrStrName() { return "offsets"; }
+    static StringRef getSizesAttrStrName() { return "sizes"; }
+    static StringRef getStridesAttrStrName() { return "strides"; }
     VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
     void getOffsets(SmallVectorImpl<int64_t> &results);
-    bool hasNonUnitStrides() { 
+    bool hasNonUnitStrides() {
       return llvm::any_of(strides(), [](Attribute attr) {
         return attr.cast<IntegerAttr>().getInt() != 1;
       });
@@ -2190,7 +2200,7 @@ def Vector_ConstantMaskOp :
   }];
 
   let extraClassDeclaration = [{
-    static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
+    static StringRef getMaskDimSizesAttrStrName() { return "mask_dim_sizes"; }
   }];
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
@@ -2276,7 +2286,7 @@ def Vector_TransposeOp :
       return result().getType().cast<VectorType>();
     }
     void getTransp(SmallVectorImpl<int64_t> &results);
-    static StringRef getTranspAttrName() { return "transp"; }
+    static StringRef getTranspAttrStrName() { return "transp"; }
   }];
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
@@ -2537,8 +2547,8 @@ def Vector_ScanOp :
                    CArg<"bool", "true">:$inclusive)>
   ];
   let extraClassDeclaration = [{
-    static StringRef getKindAttrName() { return "kind"; }
-    static StringRef getReductionDimAttrName() { return "reduction_dim"; }
+    static StringRef getKindAttrStrName() { return "kind"; }
+    static StringRef getReductionDimAttrStrName() { return "reduction_dim"; }
     VectorType getSourceType() {
       return source().getType().cast<VectorType>();
     }

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 68b88860b2ff3..ee6c638d402c5 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -55,7 +55,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
     StaticInterfaceMethod<
       /*desc=*/"Return the `in_bounds` attribute name.",
       /*retTy=*/"::mlir::StringRef",
-      /*methodName=*/"getInBoundsAttrName",
+      /*methodName=*/"getInBoundsAttrStrName",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/ [{ return "in_bounds"; }]
@@ -63,7 +63,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
     StaticInterfaceMethod<
       /*desc=*/"Return the `permutation_map` attribute name.",
       /*retTy=*/"::mlir::StringRef",
-      /*methodName=*/"getPermutationMapAttrName",
+      /*methodName=*/"getPermutationMapAttrStrName",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/ [{ return "permutation_map"; }]

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index dec9eec703884..8650d574de289 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -318,7 +318,7 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write,
     write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
   } else {
     newForOp.getResult(initArgNumber)
-        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
+        .replaceAllUsesWith(write.transferWriteOp.getResult());
     write.transferWriteOp.sourceMutable().assign(
         newForOp.getResult(initArgNumber));
   }

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2d504cb0029c4..4db150927fae5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -347,9 +347,9 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  result.addAttribute(getReductionDimsAttrName(),
+  result.addAttribute(getReductionDimsAttrStrName(),
                       builder.getI64ArrayAttr(reductionDims));
-  result.addAttribute(getKindAttrName(),
+  result.addAttribute(getKindAttrStrName(),
                       CombiningKindAttr::get(kind, builder.getContext()));
 }
 
@@ -491,10 +491,10 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                   ArrayRef<StringRef> iteratorTypes) {
   result.addOperands({lhs, rhs, acc});
   result.addTypes(acc.getType());
-  result.addAttribute(getIndexingMapsAttrName(),
+  result.addAttribute(::mlir::getIndexingMapsAttrName(),
                       builder.getAffineMapArrayAttr(
                           AffineMap::inferFromExprList(indexingExprs)));
-  result.addAttribute(getIteratorTypesAttrName(),
+  result.addAttribute(::mlir::getIteratorTypesAttrName(),
                       builder.getStrArrayAttr(iteratorTypes));
 }
 
@@ -512,9 +512,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                   ArrayAttr iteratorTypes, CombiningKind kind) {
   result.addOperands({lhs, rhs, acc});
   result.addTypes(acc.getType());
-  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
-  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
-  result.addAttribute(ContractionOp::getKindAttrName(),
+  result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
+  result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
+  result.addAttribute(ContractionOp::getKindAttrStrName(),
                       CombiningKindAttr::get(kind, builder.getContext()));
 }
 
@@ -543,8 +543,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
-  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
-    result.addAttribute(ContractionOp::getKindAttrName(),
+  if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
+    result.addAttribute(ContractionOp::getKindAttrStrName(),
                         CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                                result.getContext()));
   }
@@ -698,7 +698,7 @@ LogicalResult ContractionOp::verify() {
   unsigned numIterators = iterator_types().getValue().size();
   for (const auto &it : llvm::enumerate(indexing_maps())) {
     auto index = it.index();
-    auto map = it.value().cast<AffineMapAttr>().getValue();
+    auto map = it.value();
     if (map.getNumSymbols() != 0)
       return emitOpError("expected indexing map ")
              << index << " to have no symbols";
@@ -759,9 +759,9 @@ LogicalResult ContractionOp::verify() {
 }
 
 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
-  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
-                                         getIteratorTypesAttrName(),
-                                         ContractionOp::getKindAttrName()};
+  static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
+                                         ::mlir::getIteratorTypesAttrName(),
+                                         ContractionOp::getKindAttrStrName()};
   return llvm::makeArrayRef(names);
 }
 
@@ -817,11 +817,11 @@ void ContractionOp::getIterationBounds(
 
 void ContractionOp::getIterationIndexMap(
     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
-  unsigned numMaps = indexing_maps().getValue().size();
+  unsigned numMaps = indexing_maps().size();
   iterationIndexMap.resize(numMaps);
   for (const auto &it : llvm::enumerate(indexing_maps())) {
     auto index = it.index();
-    auto map = it.value().cast<AffineMapAttr>().getValue();
+    auto map = it.value();
     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
       auto dim = map.getResult(i).cast<AffineDimExpr>();
       iterationIndexMap[index][dim.getPosition()] = i;
@@ -841,13 +841,6 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
                    getParallelIteratorTypeName(), getContext());
 }
 
-SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
-  return llvm::to_vector<4>(
-      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
-        return mapAttr.cast<AffineMapAttr>().getValue();
-      }));
-}
-
 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
   SmallVector<int64_t, 4> shape;
   getIterationBounds(shape);
@@ -961,7 +954,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
   auto positionAttr = getVectorSubscriptAttr(builder, position);
   result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                            positionAttr));
-  result.addAttribute(getPositionAttrName(), positionAttr);
+  result.addAttribute(getPositionAttrStrName(), positionAttr);
 }
 
 // Convenience builder which assumes the values are constant indices.
@@ -1053,7 +1046,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   std::reverse(globalPosition.begin(), globalPosition.end());
-  extractOp->setAttr(ExtractOp::getPositionAttrName(),
+  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                      b.getI64ArrayAttr(globalPosition));
   return success();
 }
@@ -1295,7 +1288,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     extractOp.setOperand(source);
     // OpBuilder is only used as a helper to build an I64ArrayAttr.
     OpBuilder b(extractOp.getContext());
-    extractOp->setAttr(ExtractOp::getPositionAttrName(),
+    extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                        b.getI64ArrayAttr(extractPos));
     return extractOp.getResult();
   }
@@ -1355,7 +1348,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp->setAttr(ExtractOp::getPositionAttrName(),
+  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                      b.getI64ArrayAttr(newPosition));
   extractOp.setOperand(shapeCastOp.source());
   return extractOp.getResult();
@@ -1396,7 +1389,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
   extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp->setAttr(ExtractOp::getPositionAttrName(),
+  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                      b.getI64ArrayAttr(extractedPos));
   return extractOp.getResult();
 }
@@ -1453,7 +1446,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
       op.vectorMutable().assign(insertOp.source());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(op.getContext());
-      op->setAttr(ExtractOp::getPositionAttrName(),
+      op->setAttr(ExtractOp::getPositionAttrStrName(),
                   b.getI64ArrayAttr(offsetDiffs));
       return op.getResult();
     }
@@ -1736,7 +1729,7 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
   auto shape = llvm::to_vector<4>(v1Type.getShape());
   shape[0] = mask.size();
   result.addTypes(VectorType::get(shape, v1Type.getElementType()));
-  result.addAttribute(getMaskAttrName(), maskAttr);
+  result.addAttribute(getMaskAttrStrName(), maskAttr);
 }
 
 void ShuffleOp::print(OpAsmPrinter &p) {
@@ -1784,7 +1777,7 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
   VectorType v1Type, v2Type;
   if (parser.parseOperand(v1) || parser.parseComma() ||
       parser.parseOperand(v2) ||
-      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
+      parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(),
                             result.attributes) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(v1Type) || parser.parseComma() ||
@@ -1877,7 +1870,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
   result.addOperands({source, dest});
   auto positionAttr = getVectorSubscriptAttr(builder, position);
   result.addTypes(dest.getType());
-  result.addAttribute(getPositionAttrName(), positionAttr);
+  result.addAttribute(getPositionAttrStrName(), positionAttr);
 }
 
 // Convenience builder which assumes the values are constant indices.
@@ -1995,8 +1988,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
   result.addTypes(dest.getType());
-  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
-  result.addAttribute(getStridesAttrName(), stridesAttr);
+  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
+  result.addAttribute(getStridesAttrStrName(), stridesAttr);
 }
 
 // TODO: Should be moved to Tablegen Confined attributes.
@@ -2172,9 +2165,9 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
                              vLHS.getElementType())
            : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
 
-  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
+  if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
     result.attributes.append(
-        OuterProductOp::getKindAttrName(),
+        OuterProductOp::getKindAttrStrName(),
         CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
                                result.getContext()));
   }
@@ -2322,9 +2315,9 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
   result.addTypes(
       inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
                                     offsetsAttr, sizesAttr, stridesAttr));
-  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
-  result.addAttribute(getSizesAttrName(), sizesAttr);
-  result.addAttribute(getStridesAttrName(), stridesAttr);
+  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
+  result.addAttribute(getSizesAttrStrName(), sizesAttr);
+  result.addAttribute(getStridesAttrStrName(), stridesAttr);
 }
 
 LogicalResult ExtractStridedSliceOp::verify() {
@@ -2412,7 +2405,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
       op.setOperand(insertOp.source());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(op.getContext());
-      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
+      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
                   b.getI64ArrayAttr(offsetDiffs));
       return success();
     }
@@ -2765,7 +2758,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 3> elidedAttrs;
   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
   if (op.permutation_map().isMinorIdentity())
-    elidedAttrs.push_back(op.getPermutationMapAttrName());
+    elidedAttrs.push_back(op.getPermutationMapAttrStrName());
   bool elideInBounds = true;
   if (auto inBounds = op.in_bounds()) {
     for (auto attr : *inBounds) {
@@ -2776,7 +2769,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
     }
   }
   if (elideInBounds)
-    elidedAttrs.push_back(op.getInBoundsAttrName());
+    elidedAttrs.push_back(op.getInBoundsAttrStrName());
   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }
 
@@ -2817,7 +2810,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   VectorType vectorType = types[1].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
-  auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
+  auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
   Attribute mapAttr = result.attributes.get(permutationAttrName);
   if (!mapAttr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
@@ -2963,7 +2956,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
     return failure();
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(op.getContext());
-  op->setAttr(TransferOp::getInBoundsAttrName(),
+  op->setAttr(TransferOp::getInBoundsAttrStrName(),
               b.getBoolArrayAttr(newInBounds));
   return success();
 }
@@ -3193,7 +3186,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   ShapedType shapedType = types[1].dyn_cast<ShapedType>();
   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
-  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
+  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
@@ -4151,7 +4144,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
   result.addOperands(vector);
   result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
-  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
+  result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
 }
 
 // Eliminates transpose operations, which produce values identical to their

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index f574713ffb2a4..48470f7b059d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -514,7 +514,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
   auto inBoundsAttr = b.getBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
     return success();
   }
 
@@ -585,7 +585,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
       xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
 
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 226faccaba96c..f9413a7468187 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1050,7 +1050,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   bindDims(rew.getContext(), m, n, k);
   // LHS must be A(m, k) or A(k, m).
   Value lhs = op.lhs();
-  auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
+  auto lhsMap = op.indexing_maps()[0];
   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
@@ -1058,7 +1058,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
 
   // RHS must be B(k, n) or B(n, k).
   Value rhs = op.rhs();
-  auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
+  auto rhsMap = op.indexing_maps()[1];
   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
@@ -1088,7 +1088,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
       mul);
 
   // ACC must be C(m, n) or C(n, m).
-  auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
+  auto accMap = op.indexing_maps()[2];
   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))


        


More information about the Mlir-commits mailing list