[Mlir-commits] [mlir] 75044e9 - [mlir] Flipping vector dialect to both prefixed form.
Jacques Pienaar
llvmlistbot at llvm.org
Tue Feb 15 09:48:56 PST 2022
Author: Jacques Pienaar
Date: 2022-02-15T09:48:51-08:00
New Revision: 75044e9b4f20d025295dbd56284435937cfb4de5
URL: https://github.com/llvm/llvm-project/commit/75044e9b4f20d025295dbd56284435937cfb4de5
DIFF: https://github.com/llvm/llvm-project/commit/75044e9b4f20d025295dbd56284435937cfb4de5.diff
LOG: [mlir] Flipping vector dialect to both prefixed form.
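Following the PSA at
https://discourse.llvm.org/t/psa-ods-generated-accessors-will-change-to-have-a-get-prefix-update-you-apis/4476
this sets the Vector dialect's `emitAccessorPrefix` to `kEmitAccessorPrefix_Both`.
The change is mostly mechanical. Hand-written static attribute-name helpers (e.g. `getKindAttrName()`) are renamed to `get*AttrStrName()` to avoid function name conflicts.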
Differential Revision: https://reviews.llvm.org/D119607
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 009df114ec2c2..66d4a69593358 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -22,6 +22,7 @@ def Vector_Dialect : Dialect {
let cppNamespace = "::mlir::vector";
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithmeticDialect"];
+ let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
// Base class for Vector dialect ops.
@@ -63,6 +64,15 @@ def Vector_CombiningKindAttr : DialectAttr<
"::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
}
+def Vector_AffineMapArrayAttr : TypedArrayAttrBase<AffineMapAttr,
+ "AffineMap array attribute"> {
+ let returnType = [{ ::llvm::SmallVector<::mlir::AffineMap, 4> }];
+ let convertFromStorage = [{
+ llvm::to_vector<4>($_self.getAsValueRange<::mlir::AffineMapAttr>());
+ }];
+ let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
+}
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -75,7 +85,8 @@ def Vector_ContractionOp :
]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
- AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types,
+ Vector_AffineMapArrayAttr:$indexing_maps,
+ ArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
"CombiningKind::ADD">:$kind)>,
Results<(outs AnyType)> {
@@ -223,7 +234,6 @@ def Vector_ContractionOp :
}
Type getResultType() { return getResult().getType(); }
ArrayRef<StringRef> getTraitAttrNames();
- SmallVector<AffineMap, 4> getIndexingMaps();
static unsigned getAccOperandIndex() { return 2; }
// Returns the bounds of each dimension in the iteration space spanned
@@ -240,7 +250,7 @@ def Vector_ContractionOp :
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
- static constexpr StringRef getKindAttrName() { return "kind"; }
+ static constexpr StringRef getKindAttrStrName() { return "kind"; }
static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
@@ -327,8 +337,8 @@ def Vector_MultiDimReductionOp :
"CombiningKind":$kind)>
];
let extraClassDeclaration = [{
- static StringRef getKindAttrName() { return "kind"; }
- static StringRef getReductionDimsAttrName() { return "reduction_dims"; }
+ static StringRef getKindAttrStrName() { return "kind"; }
+ static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; }
VectorType getSourceVectorType() {
return source().getType().cast<VectorType>();
@@ -474,7 +484,7 @@ def Vector_ShuffleOp :
];
let hasFolder = 1;
let extraClassDeclaration = [{
- static StringRef getMaskAttrName() { return "mask"; }
+ static StringRef getMaskAttrStrName() { return "mask"; }
VectorType getV1VectorType() {
return v1().getType().cast<VectorType>();
}
@@ -561,7 +571,7 @@ def Vector_ExtractOp :
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
];
let extraClassDeclaration = [{
- static StringRef getPositionAttrName() { return "position"; }
+ static StringRef getPositionAttrStrName() { return "position"; }
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
}
@@ -754,7 +764,7 @@ def Vector_InsertOp :
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
];
let extraClassDeclaration = [{
- static StringRef getPositionAttrName() { return "position"; }
+ static StringRef getPositionAttrStrName() { return "position"; }
Type getSourceType() { return source().getType(); }
VectorType getDestVectorType() {
return dest().getType().cast<VectorType>();
@@ -873,15 +883,15 @@ def Vector_InsertStridedSliceOp :
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
];
let extraClassDeclaration = [{
- static StringRef getOffsetsAttrName() { return "offsets"; }
- static StringRef getStridesAttrName() { return "strides"; }
+ static StringRef getOffsetsAttrStrName() { return "offsets"; }
+ static StringRef getStridesAttrStrName() { return "strides"; }
VectorType getSourceVectorType() {
return source().getType().cast<VectorType>();
}
VectorType getDestVectorType() {
return dest().getType().cast<VectorType>();
}
- bool hasNonUnitStrides() {
+ bool hasNonUnitStrides() {
return llvm::any_of(strides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
});
@@ -970,7 +980,7 @@ def Vector_OuterProductOp :
VectorType getVectorType() {
return getResult().getType().cast<VectorType>();
}
- static constexpr StringRef getKindAttrName() {
+ static constexpr StringRef getKindAttrStrName() {
return "kind";
}
static CombiningKind getDefaultKind() {
@@ -1089,11 +1099,11 @@ def Vector_ReshapeOp :
void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
- static StringRef getFixedVectorSizesAttrName() {
+ static StringRef getFixedVectorSizesAttrStrName() {
return "fixed_vector_sizes";
}
- static StringRef getInputShapeAttrName() { return "input_shape"; }
- static StringRef getOutputShapeAttrName() { return "output_shape"; }
+ static StringRef getInputShapeAttrStrName() { return "input_shape"; }
+ static StringRef getOutputShapeAttrStrName() { return "output_shape"; }
}];
let assemblyFormat = [{
@@ -1140,12 +1150,12 @@ def Vector_ExtractStridedSliceOp :
"ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides)>
];
let extraClassDeclaration = [{
- static StringRef getOffsetsAttrName() { return "offsets"; }
- static StringRef getSizesAttrName() { return "sizes"; }
- static StringRef getStridesAttrName() { return "strides"; }
+ static StringRef getOffsetsAttrStrName() { return "offsets"; }
+ static StringRef getSizesAttrStrName() { return "sizes"; }
+ static StringRef getStridesAttrStrName() { return "strides"; }
VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
void getOffsets(SmallVectorImpl<int64_t> &results);
- bool hasNonUnitStrides() {
+ bool hasNonUnitStrides() {
return llvm::any_of(strides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
});
@@ -2190,7 +2200,7 @@ def Vector_ConstantMaskOp :
}];
let extraClassDeclaration = [{
- static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
+ static StringRef getMaskDimSizesAttrStrName() { return "mask_dim_sizes"; }
}];
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
let hasVerifier = 1;
@@ -2276,7 +2286,7 @@ def Vector_TransposeOp :
return result().getType().cast<VectorType>();
}
void getTransp(SmallVectorImpl<int64_t> &results);
- static StringRef getTranspAttrName() { return "transp"; }
+ static StringRef getTranspAttrStrName() { return "transp"; }
}];
let assemblyFormat = [{
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
@@ -2537,8 +2547,8 @@ def Vector_ScanOp :
CArg<"bool", "true">:$inclusive)>
];
let extraClassDeclaration = [{
- static StringRef getKindAttrName() { return "kind"; }
- static StringRef getReductionDimAttrName() { return "reduction_dim"; }
+ static StringRef getKindAttrStrName() { return "kind"; }
+ static StringRef getReductionDimAttrStrName() { return "reduction_dim"; }
VectorType getSourceType() {
return source().getType().cast<VectorType>();
}
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 68b88860b2ff3..ee6c638d402c5 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -55,7 +55,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
StaticInterfaceMethod<
/*desc=*/"Return the `in_bounds` attribute name.",
/*retTy=*/"::mlir::StringRef",
- /*methodName=*/"getInBoundsAttrName",
+ /*methodName=*/"getInBoundsAttrStrName",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/ [{ return "in_bounds"; }]
@@ -63,7 +63,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
StaticInterfaceMethod<
/*desc=*/"Return the `permutation_map` attribute name.",
/*retTy=*/"::mlir::StringRef",
- /*methodName=*/"getPermutationMapAttrName",
+ /*methodName=*/"getPermutationMapAttrStrName",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/ [{ return "permutation_map"; }]
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index dec9eec703884..8650d574de289 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -318,7 +318,7 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write,
write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
} else {
newForOp.getResult(initArgNumber)
- .replaceAllUsesWith(write.transferWriteOp.getResult(0));
+ .replaceAllUsesWith(write.transferWriteOp.getResult());
write.transferWriteOp.sourceMutable().assign(
newForOp.getResult(initArgNumber));
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2d504cb0029c4..4db150927fae5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -347,9 +347,9 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- result.addAttribute(getReductionDimsAttrName(),
+ result.addAttribute(getReductionDimsAttrStrName(),
builder.getI64ArrayAttr(reductionDims));
- result.addAttribute(getKindAttrName(),
+ result.addAttribute(getKindAttrStrName(),
CombiningKindAttr::get(kind, builder.getContext()));
}
@@ -491,10 +491,10 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<StringRef> iteratorTypes) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
- result.addAttribute(getIndexingMapsAttrName(),
+ result.addAttribute(::mlir::getIndexingMapsAttrName(),
builder.getAffineMapArrayAttr(
AffineMap::inferFromExprList(indexingExprs)));
- result.addAttribute(getIteratorTypesAttrName(),
+ result.addAttribute(::mlir::getIteratorTypesAttrName(),
builder.getStrArrayAttr(iteratorTypes));
}
@@ -512,9 +512,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
ArrayAttr iteratorTypes, CombiningKind kind) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
- result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
- result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
- result.addAttribute(ContractionOp::getKindAttrName(),
+ result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
+ result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
+ result.addAttribute(ContractionOp::getKindAttrStrName(),
CombiningKindAttr::get(kind, builder.getContext()));
}
@@ -543,8 +543,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
- if (!result.attributes.get(ContractionOp::getKindAttrName())) {
- result.addAttribute(ContractionOp::getKindAttrName(),
+ if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
+ result.addAttribute(ContractionOp::getKindAttrStrName(),
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
result.getContext()));
}
@@ -698,7 +698,7 @@ LogicalResult ContractionOp::verify() {
unsigned numIterators = iterator_types().getValue().size();
for (const auto &it : llvm::enumerate(indexing_maps())) {
auto index = it.index();
- auto map = it.value().cast<AffineMapAttr>().getValue();
+ auto map = it.value();
if (map.getNumSymbols() != 0)
return emitOpError("expected indexing map ")
<< index << " to have no symbols";
@@ -759,9 +759,9 @@ LogicalResult ContractionOp::verify() {
}
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
- static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
- getIteratorTypesAttrName(),
- ContractionOp::getKindAttrName()};
+ static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
+ ::mlir::getIteratorTypesAttrName(),
+ ContractionOp::getKindAttrStrName()};
return llvm::makeArrayRef(names);
}
@@ -817,11 +817,11 @@ void ContractionOp::getIterationBounds(
void ContractionOp::getIterationIndexMap(
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
- unsigned numMaps = indexing_maps().getValue().size();
+ unsigned numMaps = indexing_maps().size();
iterationIndexMap.resize(numMaps);
for (const auto &it : llvm::enumerate(indexing_maps())) {
auto index = it.index();
- auto map = it.value().cast<AffineMapAttr>().getValue();
+ auto map = it.value();
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
auto dim = map.getResult(i).cast<AffineDimExpr>();
iterationIndexMap[index][dim.getPosition()] = i;
@@ -841,13 +841,6 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
getParallelIteratorTypeName(), getContext());
}
-SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
- return llvm::to_vector<4>(
- llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
- return mapAttr.cast<AffineMapAttr>().getValue();
- }));
-}
-
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
SmallVector<int64_t, 4> shape;
getIterationBounds(shape);
@@ -961,7 +954,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
auto positionAttr = getVectorSubscriptAttr(builder, position);
result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
positionAttr));
- result.addAttribute(getPositionAttrName(), positionAttr);
+ result.addAttribute(getPositionAttrStrName(), positionAttr);
}
// Convenience builder which assumes the values are constant indices.
@@ -1053,7 +1046,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
- extractOp->setAttr(ExtractOp::getPositionAttrName(),
+ extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(globalPosition));
return success();
}
@@ -1295,7 +1288,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
extractOp.setOperand(source);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp->setAttr(ExtractOp::getPositionAttrName(),
+ extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(extractPos));
return extractOp.getResult();
}
@@ -1355,7 +1348,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp->setAttr(ExtractOp::getPositionAttrName(),
+ extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(newPosition));
extractOp.setOperand(shapeCastOp.source());
return extractOp.getResult();
@@ -1396,7 +1389,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp->setAttr(ExtractOp::getPositionAttrName(),
+ extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(extractedPos));
return extractOp.getResult();
}
@@ -1453,7 +1446,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
op.vectorMutable().assign(insertOp.source());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- op->setAttr(ExtractOp::getPositionAttrName(),
+ op->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(offsetDiffs));
return op.getResult();
}
@@ -1736,7 +1729,7 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
auto shape = llvm::to_vector<4>(v1Type.getShape());
shape[0] = mask.size();
result.addTypes(VectorType::get(shape, v1Type.getElementType()));
- result.addAttribute(getMaskAttrName(), maskAttr);
+ result.addAttribute(getMaskAttrStrName(), maskAttr);
}
void ShuffleOp::print(OpAsmPrinter &p) {
@@ -1784,7 +1777,7 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
VectorType v1Type, v2Type;
if (parser.parseOperand(v1) || parser.parseComma() ||
parser.parseOperand(v2) ||
- parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
+ parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(v1Type) || parser.parseComma() ||
@@ -1877,7 +1870,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
result.addOperands({source, dest});
auto positionAttr = getVectorSubscriptAttr(builder, position);
result.addTypes(dest.getType());
- result.addAttribute(getPositionAttrName(), positionAttr);
+ result.addAttribute(getPositionAttrStrName(), positionAttr);
}
// Convenience builder which assumes the values are constant indices.
@@ -1995,8 +1988,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
auto stridesAttr = getVectorSubscriptAttr(builder, strides);
result.addTypes(dest.getType());
- result.addAttribute(getOffsetsAttrName(), offsetsAttr);
- result.addAttribute(getStridesAttrName(), stridesAttr);
+ result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
+ result.addAttribute(getStridesAttrStrName(), stridesAttr);
}
// TODO: Should be moved to Tablegen Confined attributes.
@@ -2172,9 +2165,9 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
vLHS.getElementType())
: VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
- if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
+ if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
result.attributes.append(
- OuterProductOp::getKindAttrName(),
+ OuterProductOp::getKindAttrStrName(),
CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
result.getContext()));
}
@@ -2322,9 +2315,9 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(
inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
offsetsAttr, sizesAttr, stridesAttr));
- result.addAttribute(getOffsetsAttrName(), offsetsAttr);
- result.addAttribute(getSizesAttrName(), sizesAttr);
- result.addAttribute(getStridesAttrName(), stridesAttr);
+ result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
+ result.addAttribute(getSizesAttrStrName(), sizesAttr);
+ result.addAttribute(getStridesAttrStrName(), stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
@@ -2412,7 +2405,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
op.setOperand(insertOp.source());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
+ op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
b.getI64ArrayAttr(offsetDiffs));
return success();
}
@@ -2765,7 +2758,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
if (op.permutation_map().isMinorIdentity())
- elidedAttrs.push_back(op.getPermutationMapAttrName());
+ elidedAttrs.push_back(op.getPermutationMapAttrStrName());
bool elideInBounds = true;
if (auto inBounds = op.in_bounds()) {
for (auto attr : *inBounds) {
@@ -2776,7 +2769,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
}
}
if (elideInBounds)
- elidedAttrs.push_back(op.getInBoundsAttrName());
+ elidedAttrs.push_back(op.getInBoundsAttrStrName());
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
@@ -2817,7 +2810,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
VectorType vectorType = types[1].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
- auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
+ auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
Attribute mapAttr = result.attributes.get(permutationAttrName);
if (!mapAttr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
@@ -2963,7 +2956,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
return failure();
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- op->setAttr(TransferOp::getInBoundsAttrName(),
+ op->setAttr(TransferOp::getInBoundsAttrStrName(),
b.getBoolArrayAttr(newInBounds));
return success();
}
@@ -3193,7 +3186,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
ShapedType shapedType = types[1].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
- auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
+ auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
@@ -4151,7 +4144,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(vector);
result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
- result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
+ result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
}
// Eliminates transpose operations, which produce values identical to their
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index f574713ffb2a4..48470f7b059d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -514,7 +514,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
auto inBoundsAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
return success();
}
@@ -585,7 +585,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
- xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 226faccaba96c..f9413a7468187 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1050,7 +1050,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
bindDims(rew.getContext(), m, n, k);
// LHS must be A(m, k) or A(k, m).
Value lhs = op.lhs();
- auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
+ auto lhsMap = op.indexing_maps()[0];
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
@@ -1058,7 +1058,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
// RHS must be B(k, n) or B(n, k).
Value rhs = op.rhs();
- auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
+ auto rhsMap = op.indexing_maps()[1];
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
@@ -1088,7 +1088,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
mul);
// ACC must be C(m, n) or C(n, m).
- auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
+ auto accMap = op.indexing_maps()[2];
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
More information about the Mlir-commits
mailing list