[Mlir-commits] [mlir] 7c38fd6 - [mlir] Flip Vector dialect accessors used to prefixed form.
Jacques Pienaar
llvmlistbot at llvm.org
Mon Mar 28 11:24:53 PDT 2022
Author: Jacques Pienaar
Date: 2022-03-28T11:24:47-07:00
New Revision: 7c38fd605ba85657a0ecbea75a8e3a68174d3dff
URL: https://github.com/llvm/llvm-project/commit/7c38fd605ba85657a0ecbea75a8e3a68174d3dff
DIFF: https://github.com/llvm/llvm-project/commit/7c38fd605ba85657a0ecbea75a8e3a68174d3dff.diff
LOG: [mlir] Flip Vector dialect accessors used to prefixed form.
The Vector dialect's accessor-prefixing flag has been set to _Both for a
couple of weeks. This change flips the usages in core to the prefixed form,
with the intention of flipping the flag to _Prefixed in a follow-up. A couple
of helper methods had to be added in AffineOps and Linalg to enable a pure
flag flip in that follow-up, since some of those classes are used in templates
and are therefore sensitive to Vector dialect changes.
Differential Revision: https://reviews.llvm.org/D122151
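For readers following the migration, here is a minimal sketch (not part of
the commit) of what the flip means at a call site. It assumes a
vector.transfer_read op is in scope; the function and its readOp parameter
are hypothetical, but the accessor names match the TableGen-generated
prefixed form used throughout the diff below.

  #include "mlir/Dialect/Vector/IR/VectorOps.h"

  using namespace mlir;

  void prefixedCallSite(vector::TransferReadOp readOp) {
    // Raw form, valid only while the dialect's accessor flag is _Both:
    //   Value src = readOp.source();
    //   AffineMap map = readOp.permutation_map();
    // Prefixed form, which this commit flips in-tree call sites to:
    Value src = readOp.getSource();             // operand `source`
    AffineMap map = readOp.getPermutationMap(); // attribute `permutation_map`
    (void)src;
    (void)map;
  }

The helper methods added in AffineOps and Linalg below simply forward the
prefixed name to the raw accessor (e.g. operand_range getIndices() { return
indices(); }), so templated code that calls getIndices() compiles regardless
of which form a given dialect generates during the transition.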
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f13014396db40..3880ad88c1ee0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -493,6 +493,9 @@ class AffineLoadOpBase<string mnemonic, list<Trait> traits = []> :
}
static StringRef getMapAttrName() { return "map"; }
+
+ // TODO: Remove once prefixing is flipped.
+ operand_range getIndices() { return indices(); }
}];
}
@@ -856,6 +859,9 @@ class AffineStoreOpBase<string mnemonic, list<Trait> traits = []> :
}
static StringRef getMapAttrName() { return "map"; }
+
+ // TODO: Remove once prefixing is flipped.
+ operand_range getIndices() { return indices(); }
}];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index a551f40141a00..439538378dd71 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1131,6 +1131,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+ // TODO: Remove once prefixing is flipped.
+ ArrayAttr getIteratorTypes() { return iterator_types(); }
+
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 824bfeb326366..779983fedbf5d 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -163,7 +163,7 @@ class StructuredGenerator {
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
- iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+ iterators(op.getIteratorTypes()), maps(op.getIndexingMaps()), op(op) {}
bool iters(ArrayRef<IteratorType> its) {
if (its.size() != iterators.size())
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 69c2c929e42ff..005db9abafd03 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -219,18 +219,18 @@ def Vector_ContractionOp :
];
let extraClassDeclaration = [{
VectorType getLhsType() {
- return lhs().getType().cast<VectorType>();
+ return getLhs().getType().cast<VectorType>();
}
VectorType getRhsType() {
- return rhs().getType().cast<VectorType>();
+ return getRhs().getType().cast<VectorType>();
}
- Type getAccType() { return acc().getType(); }
+ Type getAccType() { return getAcc().getType(); }
VectorType getLHSVectorMaskType() {
- if (llvm::size(masks()) != 2) return VectorType();
+ if (llvm::size(getMasks()) != 2) return VectorType();
return getOperand(3).getType().cast<VectorType>();
}
VectorType getRHSVectorMaskType() {
- if (llvm::size(masks()) != 2) return VectorType();
+ if (llvm::size(getMasks()) != 2) return VectorType();
return getOperand(4).getType().cast<VectorType>();
}
Type getResultType() { return getResult().getType(); }
@@ -296,7 +296,7 @@ def Vector_ReductionOp :
}];
let extraClassDeclaration = [{
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
}];
let builders = [
@@ -347,10 +347,10 @@ def Vector_MultiDimReductionOp :
static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; }
VectorType getSourceVectorType() {
- return source().getType().cast<VectorType>();
+ return getSource().getType().cast<VectorType>();
}
Type getDestType() {
- return dest().getType();
+ return getDest().getType();
}
bool isReducedDim(int64_t d) {
@@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :
SmallVector<bool> getReductionMask() {
SmallVector<bool> res(getSourceVectorType().getRank(), false);
- for (auto ia : reduction_dims().getAsRange<IntegerAttr>())
+ for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
res[ia.getInt()] = true;
return res;
}
@@ -415,9 +415,9 @@ def Vector_BroadcastOp :
```
}];
let extraClassDeclaration = [{
- Type getSourceType() { return source().getType(); }
+ Type getSourceType() { return getSource().getType(); }
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
@@ -472,13 +472,13 @@ def Vector_ShuffleOp :
let extraClassDeclaration = [{
static StringRef getMaskAttrStrName() { return "mask"; }
VectorType getV1VectorType() {
- return v1().getType().cast<VectorType>();
+ return getV1().getType().cast<VectorType>();
}
VectorType getV2VectorType() {
- return v2().getType().cast<VectorType>();
+ return getV2().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
}];
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
@@ -526,7 +526,7 @@ def Vector_ExtractElementOp :
];
let extraClassDeclaration = [{
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
}];
let hasVerifier = 1;
@@ -560,7 +560,7 @@ def Vector_ExtractOp :
let extraClassDeclaration = [{
static StringRef getPositionAttrStrName() { return "position"; }
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
@@ -623,7 +623,7 @@ def Vector_ExtractMapOp :
"AffineMap":$map)>];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
VectorType getResultType() {
return getResult().getType().cast<VectorType>();
@@ -664,7 +664,7 @@ def Vector_FMAOp :
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
let extraClassDeclaration = [{
- VectorType getVectorType() { return lhs().getType().cast<VectorType>(); }
+ VectorType getVectorType() { return getLhs().getType().cast<VectorType>(); }
}];
}
@@ -707,9 +707,9 @@ def Vector_InsertElementOp :
OpBuilder<(ins "Value":$source, "Value":$dest)>,
];
let extraClassDeclaration = [{
- Type getSourceType() { return source().getType(); }
+ Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
- return dest().getType().cast<VectorType>();
+ return getDest().getType().cast<VectorType>();
}
}];
let hasVerifier = 1;
@@ -747,9 +747,9 @@ def Vector_InsertOp :
];
let extraClassDeclaration = [{
static StringRef getPositionAttrStrName() { return "position"; }
- Type getSourceType() { return source().getType(); }
+ Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
- return dest().getType().cast<VectorType>();
+ return getDest().getType().cast<VectorType>();
}
}];
@@ -809,7 +809,7 @@ def Vector_InsertMapOp :
}];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
VectorType getResultType() {
return getResult().getType().cast<VectorType>();
@@ -866,13 +866,13 @@ def Vector_InsertStridedSliceOp :
static StringRef getOffsetsAttrStrName() { return "offsets"; }
static StringRef getStridesAttrStrName() { return "strides"; }
VectorType getSourceVectorType() {
- return source().getType().cast<VectorType>();
+ return getSource().getType().cast<VectorType>();
}
VectorType getDestVectorType() {
- return dest().getType().cast<VectorType>();
+ return getDest().getType().cast<VectorType>();
}
bool hasNonUnitStrides() {
- return llvm::any_of(strides(), [](Attribute attr) {
+ return llvm::any_of(getStrides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
});
}
@@ -947,15 +947,15 @@ def Vector_OuterProductOp :
];
let extraClassDeclaration = [{
VectorType getOperandVectorTypeLHS() {
- return lhs().getType().cast<VectorType>();
+ return getLhs().getType().cast<VectorType>();
}
Type getOperandTypeRHS() {
- return rhs().getType();
+ return getRhs().getType();
}
VectorType getOperandVectorTypeACC() {
- return (llvm::size(acc()) == 0)
+ return (llvm::size(getAcc()) == 0)
? VectorType()
- : (*acc().begin()).getType().cast<VectorType>();
+ : (*getAcc().begin()).getType().cast<VectorType>();
}
VectorType getVectorType() {
return getResult().getType().cast<VectorType>();
@@ -1065,17 +1065,17 @@ def Vector_ReshapeOp :
let extraClassDeclaration = [{
VectorType getInputVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
VectorType getOutputVectorType() {
return getResult().getType().cast<VectorType>();
}
/// Returns as integer value the number of input shape operands.
- int64_t getNumInputShapeSizes() { return input_shape().size(); }
+ int64_t getNumInputShapeSizes() { return getInputShape().size(); }
/// Returns as integer value the number of output shape operands.
- int64_t getNumOutputShapeSizes() { return output_shape().size(); }
+ int64_t getNumOutputShapeSizes() { return getOutputShape().size(); }
void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
@@ -1133,10 +1133,10 @@ def Vector_ExtractStridedSliceOp :
static StringRef getOffsetsAttrStrName() { return "offsets"; }
static StringRef getSizesAttrStrName() { return "sizes"; }
static StringRef getStridesAttrStrName() { return "strides"; }
- VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
+ VectorType getVectorType(){ return getVector().getType().cast<VectorType>(); }
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
- return llvm::any_of(strides(), [](Attribute attr) {
+ return llvm::any_of(getStrides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
});
}
@@ -1558,11 +1558,11 @@ def Vector_LoadOp : Vector_Op<"load"> {
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getVectorType() {
- return result().getType().cast<VectorType>();
+ return getResult().getType().cast<VectorType>();
}
}];
@@ -1635,11 +1635,11 @@ def Vector_StoreOp : Vector_Op<"store"> {
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getVectorType() {
- return valueToStore().getType().cast<VectorType>();
+ return getValueToStore().getType().cast<VectorType>();
}
}];
@@ -1688,16 +1688,16 @@ def Vector_MaskedLoadOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
- return pass_thru().getType().cast<VectorType>();
+ return getPassThru().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return result().getType().cast<VectorType>();
+ return getResult().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
@@ -1744,13 +1744,13 @@ def Vector_MaskedStoreOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return valueToStore().getType().cast<VectorType>();
+ return getValueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat =
@@ -1803,19 +1803,19 @@ def Vector_GatherOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getIndexVectorType() {
- return index_vec().getType().cast<VectorType>();
+ return getIndexVec().getType().cast<VectorType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
- return pass_thru().getType().cast<VectorType>();
+ return getPassThru().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return result().getType().cast<VectorType>();
+ return getResult().getType().cast<VectorType>();
}
}];
let assemblyFormat =
@@ -1870,16 +1870,16 @@ def Vector_ScatterOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getIndexVectorType() {
- return index_vec().getType().cast<VectorType>();
+ return getIndexVec().getType().cast<VectorType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return valueToStore().getType().cast<VectorType>();
+ return getValueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat =
@@ -1931,16 +1931,16 @@ def Vector_ExpandLoadOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
- return pass_thru().getType().cast<VectorType>();
+ return getPassThru().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return result().getType().cast<VectorType>();
+ return getResult().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
@@ -1989,13 +1989,13 @@ def Vector_CompressStoreOp :
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return base().getType().cast<MemRefType>();
+ return getBase().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
- return mask().getType().cast<VectorType>();
+ return getMask().getType().cast<VectorType>();
}
VectorType getVectorType() {
- return valueToStore().getType().cast<VectorType>();
+ return getValueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat =
@@ -2045,7 +2045,7 @@ def Vector_ShapeCastOp :
}];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return source().getType().cast<VectorType>();
+ return getSource().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return getResult().getType().cast<VectorType>();
@@ -2086,7 +2086,7 @@ def Vector_BitCastOp :
}];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return source().getType().cast<VectorType>();
+ return getSource().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return getResult().getType().cast<VectorType>();
@@ -2129,13 +2129,13 @@ def Vector_TypeCastOp :
let extraClassDeclaration = [{
MemRefType getMemRefType() {
- return memref().getType().cast<MemRefType>();
+ return getMemref().getType().cast<MemRefType>();
}
MemRefType getResultMemRefType() {
return getResult().getType().cast<MemRefType>();
}
// Implement ViewLikeOpInterface.
- Value getViewSource() { return memref(); }
+ Value getViewSource() { return getMemref(); }
}];
let assemblyFormat = [{
@@ -2260,10 +2260,10 @@ def Vector_TransposeOp :
];
let extraClassDeclaration = [{
VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
VectorType getResultType() {
- return result().getType().cast<VectorType>();
+ return getResult().getType().cast<VectorType>();
}
void getTransp(SmallVectorImpl<int64_t> &results);
static StringRef getTranspAttrStrName() { return "transp"; }
@@ -2303,7 +2303,7 @@ def Vector_PrintOp :
}];
let extraClassDeclaration = [{
Type getPrintType() {
- return source().getType();
+ return getSource().getType();
}
}];
let assemblyFormat = "$source attr-dict `:` type($source)";
@@ -2530,16 +2530,16 @@ def Vector_ScanOp :
static StringRef getKindAttrStrName() { return "kind"; }
static StringRef getReductionDimAttrStrName() { return "reduction_dim"; }
VectorType getSourceType() {
- return source().getType().cast<VectorType>();
+ return getSource().getType().cast<VectorType>();
}
VectorType getDestType() {
- return dest().getType().cast<VectorType>();
+ return getDest().getType().cast<VectorType>();
}
VectorType getAccumulatorType() {
- return accumulated_value().getType().cast<VectorType>();
+ return getAccumulatedValue().getType().cast<VectorType>();
}
VectorType getInitialValueType() {
- return initial_value().getType().cast<VectorType>();
+ return getInitialValue().getType().cast<VectorType>();
}
}];
let assemblyFormat =
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index ee6c638d402c5..00e69e2b89c1b 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -77,8 +77,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.isBroadcastDim(dim)
- || ($_op.in_bounds()
- && $_op.in_bounds()->template cast<::mlir::ArrayAttr>()[dim]
+ || ($_op.getInBounds()
+ && $_op.getInBounds()->template cast<::mlir::ArrayAttr>()[dim]
.template cast<::mlir::BoolAttr>().getValue());
}]
>,
@@ -87,7 +87,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*retTy=*/"::mlir::Value",
/*methodName=*/"source",
/*args=*/(ins),
- /*methodBody=*/"return $_op.source();"
+ /*methodBody=*/"return $_op.getSource();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -95,7 +95,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*retTy=*/"::mlir::Value",
/*methodName=*/"vector",
/*args=*/(ins),
- /*methodBody=*/"return $_op.vector();"
+ /*methodBody=*/"return $_op.getVector();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -103,7 +103,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*retTy=*/"::mlir::ValueRange",
/*methodName=*/"indices",
/*args=*/(ins),
- /*methodBody=*/"return $_op.indices();"
+ /*methodBody=*/"return $_op.getIndices();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -111,7 +111,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*retTy=*/"::mlir::AffineMap",
/*methodName=*/"permutation_map",
/*args=*/(ins),
- /*methodBody=*/"return $_op.permutation_map();"
+ /*methodBody=*/"return $_op.getPermutationMap();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -121,7 +121,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto expr = $_op.permutation_map().getResult(idx);
+ auto expr = $_op.getPermutationMap().getResult(idx);
return expr.template isa<::mlir::AffineConstantExpr>() &&
expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0;
}]
@@ -146,7 +146,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*retTy=*/"::mlir::Optional<::mlir::ArrayAttr>",
/*methodName=*/"in_bounds",
/*args=*/(ins),
- /*methodBody=*/"return $_op.in_bounds();"
+ /*methodBody=*/"return $_op.getInBounds();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -156,7 +156,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
- "return $_op.source().getType().template cast<::mlir::ShapedType>();"
+ "return $_op.getSource().getType().template cast<::mlir::ShapedType>();"
>,
InterfaceMethod<
/*desc=*/"Return the VectorType.",
@@ -165,7 +165,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.vector().getType().template dyn_cast<::mlir::VectorType>();
+ return $_op.getVector().getType().template dyn_cast<::mlir::VectorType>();
}]
>,
InterfaceMethod<
@@ -175,9 +175,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.mask()
+ return $_op.getMask()
? ::mlir::vector::detail::transferMaskType(
- $_op.getVectorType(), $_op.permutation_map())
+ $_op.getVectorType(), $_op.getPermutationMap())
: ::mlir::VectorType();
}]
>,
@@ -189,7 +189,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
- "return $_op.permutation_map().getNumResults();"
+ "return $_op.getPermutationMap().getNumResults();"
>,
InterfaceMethod<
/*desc=*/[{ Return the number of leading shaped dimensions that do not
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 196e4f6c2a952..9ed1c3483c11d 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -32,14 +32,14 @@ using namespace mlir;
// Return true if the contract op can be convert to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
- if (llvm::size(contract.masks()) != 0)
+ if (llvm::size(contract.getMasks()) != 0)
return false;
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr m, n, k;
bindDims(contract.getContext(), m, n, k);
- auto iteratorTypes = contract.iterator_types().getValue();
+ auto iteratorTypes = contract.getIteratorTypes().getValue();
if (!(isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])))
@@ -76,12 +76,12 @@ getMemrefConstantHorizontalStride(ShapedType type) {
// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
- if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
+ if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
readOp.getVectorType().getRank() != 2)
return false;
if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
return false;
- AffineMap map = readOp.permutation_map();
+ AffineMap map = readOp.getPermutationMap();
OpBuilder b(readOp.getContext());
AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
AffineExpr zero = b.getAffineConstantExpr(0);
@@ -99,13 +99,13 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
if (writeOp.getTransferRank() == 0)
return false;
- if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
+ if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
writeOp.getVectorType().getRank() != 2)
return false;
if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
return false;
// TODO: Support transpose once it is added to GPU dialect ops.
- if (!writeOp.permutation_map().isMinorIdentity())
+ if (!writeOp.getPermutationMap().isMinorIdentity())
return false;
return true;
}
@@ -122,7 +122,7 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
return broadcastOp.getVectorType().getRank() == 2 &&
- broadcastOp.source().getType().isa<FloatType>();
+ broadcastOp.getSource().getType().isa<FloatType>();
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -240,7 +240,7 @@ struct PrepareContractToGPUMMA
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+ Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
// Set up the parallel/reduction structure in right form.
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
@@ -248,7 +248,7 @@ struct PrepareContractToGPUMMA
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
static constexpr std::array<int64_t, 2> perm = {1, 0};
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypes().getValue();
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
if (!(isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
@@ -286,7 +286,7 @@ struct PrepareContractToGPUMMA
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
op, lhs, rhs, res,
rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
- op.iterator_types());
+ op.getIteratorTypes());
return success();
}
};
@@ -299,7 +299,8 @@ struct CombineTransferReadOpTranspose final
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
- auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
+ auto transferReadOp =
+ op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
@@ -307,7 +308,7 @@ struct CombineTransferReadOpTranspose final
if (transferReadOp.getTransferRank() == 0)
return failure();
- if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
+ if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
return failure();
SmallVector<int64_t, 2> perm;
op.getTransp(perm);
@@ -316,11 +317,13 @@ struct CombineTransferReadOpTranspose final
permU.push_back(unsigned(o));
AffineMap permutationMap =
AffineMap::getPermutationMap(permU, op.getContext());
- AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
+ AffineMap newMap =
+ permutationMap.compose(transferReadOp.getPermutationMap());
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
- AffineMapAttr::get(newMap), transferReadOp.padding(),
- transferReadOp.mask(), transferReadOp.in_boundsAttr());
+ op, op.getType(), transferReadOp.getSource(),
+ transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+ transferReadOp.getPadding(), transferReadOp.getMask(),
+ transferReadOp.getInBoundsAttr());
return success();
}
};
@@ -337,9 +340,9 @@ static const char *inferFragType(OpTy op) {
auto contract = dyn_cast<vector::ContractionOp>(users);
if (!contract)
continue;
- if (contract.lhs() == op.getResult())
+ if (contract.getLhs() == op.getResult())
return "AOp";
- if (contract.rhs() == op.getResult())
+ if (contract.getRhs() == op.getResult())
return "BOp";
}
return "COp";
@@ -351,7 +354,7 @@ static void convertTransferReadOp(vector::TransferReadOp op,
assert(transferReadSupportsMMAMatrixType(op));
Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
- AffineMap map = op.permutation_map();
+ AffineMap map = op.getPermutationMap();
// Handle broadcast by setting the stride to 0.
if (map.getResult(0).isa<AffineConstantExpr>()) {
assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
@@ -364,7 +367,8 @@ static void convertTransferReadOp(vector::TransferReadOp op,
op.getVectorType().getElementType(), fragType);
OpBuilder b(op);
Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
- op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
+ op.getLoc(), type, op.getSource(), op.getIndices(),
+ b.getIndexAttr(*stride));
valueMapping[op.getResult()] = load;
}
@@ -375,18 +379,19 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
getMemrefConstantHorizontalStride(op.getShapedType());
assert(stride);
OpBuilder b(op);
- Value matrix = valueMapping.find(op.vector())->second;
- b.create<gpu::SubgroupMmaStoreMatrixOp>(
- op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
+ Value matrix = valueMapping.find(op.getVector())->second;
+ b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
+ op.getIndices(),
+ b.getIndexAttr(*stride));
op.erase();
}
static void convertContractOp(vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op);
- Value opA = valueMapping.find(op.lhs())->second;
- Value opB = valueMapping.find(op.rhs())->second;
- Value opC = valueMapping.find(op.acc())->second;
+ Value opA = valueMapping.find(op.getLhs())->second;
+ Value opB = valueMapping.find(op.getRhs())->second;
+ Value opC = valueMapping.find(op.getAcc())->second;
Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
opA, opB, opC);
valueMapping[op.getResult()] = matmul;
@@ -420,7 +425,7 @@ static void convertBroadcastOp(vector::BroadcastOp op,
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
- op.source());
+ op.getSource());
valueMapping[op.getResult()] = matrix;
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 20e51008c52b1..3f6b3524b8965 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -155,9 +155,9 @@ class VectorMatmulOpConversion
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- matmulOp, typeConverter->convertType(matmulOp.res().getType()),
- adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
- matmulOp.lhs_columns(), matmulOp.rhs_columns());
+ matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
+ adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
+ matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
return success();
}
};
@@ -173,8 +173,8 @@ class VectorFlatTransposeOpConversion
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
- transOp, typeConverter->convertType(transOp.res().getType()),
- adaptor.matrix(), transOp.rows(), transOp.columns());
+ transOp, typeConverter->convertType(transOp.getRes().getType()),
+ adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
return success();
}
};
@@ -194,14 +194,14 @@ static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
- loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+ loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
vector::StoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
ptr, align);
}
@@ -210,7 +210,7 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
- storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+ storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
@@ -240,8 +240,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve address.
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
.template cast<VectorType>();
- Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
@@ -269,16 +269,16 @@ class VectorGatherOpConversion
// Resolve address.
Value ptrs;
VectorType vType = gather.getVectorType();
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
- adaptor.index_vec(), memRefType, vType, ptrs)))
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
return failure();
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
- adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
};
@@ -303,15 +303,15 @@ class VectorScatterOpConversion
// Resolve address.
Value ptrs;
VectorType vType = scatter.getVectorType();
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
- adaptor.index_vec(), memRefType, vType, ptrs)))
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
- scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
+ scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
rewriter.getI32IntegerAttr(align));
return success();
}
@@ -331,11 +331,11 @@ class VectorExpandLoadOpConversion
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
+ expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
return success();
}
};
@@ -353,11 +353,11 @@ class VectorCompressStoreOpConversion
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- compress, adaptor.valueToStore(), ptr, adaptor.mask());
+ compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
return success();
}
};
@@ -374,8 +374,8 @@ class VectorReductionOpConversion
LogicalResult
matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto kind = reductionOp.kind();
- Type eltType = reductionOp.dest().getType();
+ auto kind = reductionOp.getKind();
+ Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
Value operand = adaptor.getOperands()[0];
if (eltType.isIntOrIndex()) {
@@ -468,7 +468,7 @@ class VectorShuffleOpConversion
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
Type llvmType = typeConverter->convertType(vectorType);
- auto maskArrayAttr = shuffleOp.mask();
+ auto maskArrayAttr = shuffleOp.getMask();
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -484,7 +484,7 @@ class VectorShuffleOpConversion
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
+ loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
@@ -499,10 +499,10 @@ class VectorShuffleOpConversion
int64_t insPos = 0;
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = en.value().cast<IntegerAttr>().getInt();
- Value value = adaptor.v1();
+ Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
- value = adaptor.v2();
+ value = adaptor.getV2();
}
Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
eltType, rank, extPos);
@@ -537,12 +537,12 @@ class VectorExtractElementOpConversion
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.vector(), zero);
+ extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.vector(), adaptor.position());
+ extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
return success();
}
};
@@ -559,7 +559,7 @@ class VectorExtractOpConversion
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
- auto positionArrayAttr = extractOp.position();
+ auto positionArrayAttr = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -567,21 +567,21 @@ class VectorExtractOpConversion
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArrayAttr.empty()) {
- rewriter.replaceOp(extractOp, adaptor.vector());
+ rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmResultType, adaptor.vector(), positionArrayAttr);
+ loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
auto *context = extractOp->getContext();
- Value extracted = adaptor.vector();
+ Value extracted = adaptor.getVector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
@@ -628,8 +628,8 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
- adaptor.rhs(), adaptor.acc());
+ rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
+ fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
@@ -656,13 +656,13 @@ class VectorInsertElementOpConversion
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+ insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
- adaptor.position());
+ insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
+ adaptor.getPosition());
return success();
}
};
@@ -679,7 +679,7 @@ class VectorInsertOpConversion
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
- auto positionArrayAttr = insertOp.position();
+ auto positionArrayAttr = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -688,14 +688,14 @@ class VectorInsertOpConversion
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArrayAttr.empty()) {
- rewriter.replaceOp(insertOp, adaptor.source());
+ rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, llvmResultType, adaptor.dest(), adaptor.source(),
+ loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
positionArrayAttr);
rewriter.replaceOp(insertOp, inserted);
return success();
@@ -703,7 +703,7 @@ class VectorInsertOpConversion
// Potential extraction of 1-D vector from array.
auto *context = insertOp->getContext();
- Value extracted = adaptor.dest();
+ Value extracted = adaptor.getDest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
auto oneDVectorType = destVectorType;
@@ -721,15 +721,15 @@ class VectorInsertOpConversion
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
- adaptor.source(), constant);
+ adaptor.getSource(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
auto nMinusOnePositionAttrs =
ArrayAttr::get(context, positionAttrs.drop_back());
- inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
- adaptor.dest(), inserted,
- nMinusOnePositionAttrs);
+ inserted = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmResultType, adaptor.getDest(), inserted,
+ nMinusOnePositionAttrs);
}
rewriter.replaceOp(insertOp, inserted);
@@ -780,9 +780,9 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
loc, elemType, rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
- Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
- Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
- Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
+ Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
+ Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
+ Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
}
@@ -1009,7 +1009,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
Type type = vectorType ? vectorType : eltType;
- emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
+ emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
@@ -1119,13 +1119,13 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- splatOp, vectorType, undef, adaptor.input(), zero);
+ splatOp, vectorType, undef, adaptor.getInput(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
- splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
+ splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
@@ -1170,7 +1170,7 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.input(), zero);
+ adaptor.getInput(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index fc906bfcd16cf..a8492466162bf 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -44,7 +44,7 @@ static LogicalResult replaceTransferOpWithMubuf(
Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
Value &glc, Value &slc) {
auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
- rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
+ rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.getVector(),
dwordConfig, vindex,
offsetSizeInBytes, glc, slc);
return success();
@@ -68,10 +68,10 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
return failure();
if (xferOp.getVectorType().getRank() > 1 ||
- llvm::size(xferOp.indices()) == 0)
+ llvm::size(xferOp.getIndices()) == 0)
return failure();
- if (!xferOp.permutation_map().isMinorIdentity())
+ if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Have it handled in vector->llvm conversion pass.
@@ -105,7 +105,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = this->getStridedElementPtr(
- loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
+ loc, memRefType, adaptor.getSource(), adaptor.getIndices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 0c79403ed4bdc..0f57c72f3d3f4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -53,7 +53,7 @@ template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
// TODO: support 0-d corner case.
assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
- auto map = xferOp.permutation_map();
+ auto map = xferOp.getPermutationMap();
if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
return expr.getPosition();
}
@@ -69,7 +69,7 @@ template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
// TODO: support 0-d corner case.
assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
- auto map = xferOp.permutation_map();
+ auto map = xferOp.getPermutationMap();
return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
b.getContext());
}
@@ -86,7 +86,7 @@ static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
typename OpTy::Adaptor adaptor(xferOp);
// Corresponding memref dim of the vector dim that is unpacked.
auto dim = unpackedDim(xferOp);
- auto prevIndices = adaptor.indices();
+ auto prevIndices = adaptor.getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
Location loc = xferOp.getLoc();
@@ -94,7 +94,7 @@ static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
if (!isBroadcast) {
AffineExpr d0, d1;
bindDims(xferOp.getContext(), d0, d1);
- Value offset = adaptor.indices()[dim.getValue()];
+ Value offset = adaptor.getIndices()[dim.getValue()];
indices[dim.getValue()] =
makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
}
@@ -118,7 +118,7 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
- if (!xferOp.mask())
+ if (!xferOp.getMask())
return Value();
if (xferOp.getMaskType().getRank() != 1)
return Value();
@@ -126,7 +126,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
+ return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -167,10 +167,11 @@ static Value generateInBoundsCheck(
Location loc = xferOp.getLoc();
ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
if (!xferOp.isDimInBounds(0) && !isBroadcast) {
- Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
+ Value memrefDim =
+ vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
AffineExpr d0, d1;
bindDims(xferOp.getContext(), d0, d1);
- Value base = xferOp.indices()[dim.getValue()];
+ Value base = xferOp.getIndices()[dim.getValue()];
Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
memrefIdx);
@@ -289,11 +290,11 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
auto bufferType = MemRefType::get({}, xferOp.getVectorType());
result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
- if (xferOp.mask()) {
- auto maskType = MemRefType::get({}, xferOp.mask().getType());
+ if (xferOp.getMask()) {
+ auto maskType = MemRefType::get({}, xferOp.getMask().getType());
auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
b.setInsertionPoint(xferOp);
- b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
+ b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
}
@@ -319,8 +320,8 @@ static MemRefType unpackOneDim(MemRefType type) {
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
- assert(xferOp.mask() && "Expected that transfer op has mask");
- auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
+ assert(xferOp.getMask() && "Expected that transfer op has mask");
+ auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
assert(loadOp && "Expected transfer op mask produced by LoadOp");
return loadOp.getMemRef();
}
@@ -401,15 +402,15 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = buffer.getType().dyn_cast<ShapedType>();
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
- auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+ auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
- loc, vecType, xferOp.source(), xferIndices,
- AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
- Value(), inBoundsAttr);
+ loc, vecType, xferOp.getSource(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+ xferOp.getPadding(), Value(), inBoundsAttr);
maybeApplyPassLabel(b, newXferOp, options.targetRank);
- b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
+ b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
return newXferOp;
}
@@ -425,7 +426,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = buffer.getType().dyn_cast<ShapedType>();
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
- auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
+ auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
@@ -453,7 +454,7 @@ struct Strategy<TransferWriteOp> {
/// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
/// ```
static Value getBuffer(TransferWriteOp xferOp) {
- auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
+ auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
assert(loadOp && "Expected transfer op vector produced by LoadOp");
return loadOp.getMemRef();
}
@@ -461,7 +462,7 @@ struct Strategy<TransferWriteOp> {
/// Retrieve the indices of the current LoadOp that loads from the buffer.
static void getBufferIndices(TransferWriteOp xferOp,
SmallVector<Value, 8> &indices) {
- auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
+ auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -488,8 +489,8 @@ struct Strategy<TransferWriteOp> {
Location loc = xferOp.getLoc();
auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
- auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
- auto source = loopState.empty() ? xferOp.source() : loopState[0];
+ auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
+ auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
auto newXferOp = b.create<vector::TransferWriteOp>(
loc, type, vec, source, xferIndices,
@@ -521,7 +522,7 @@ struct Strategy<TransferWriteOp> {
/// Return the initial loop state for the generated scf.for loop.
static Value initialLoopState(TransferWriteOp xferOp) {
- return isTensorOp(xferOp) ? xferOp.source() : Value();
+ return isTensorOp(xferOp) ? xferOp.getSource() : Value();
}
};
@@ -576,8 +577,8 @@ struct PrepareTransferReadConversion
auto buffers = allocBuffers(rewriter, xferOp);
auto *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
- if (xferOp.mask()) {
- dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
+ if (xferOp.getMask()) {
+ dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
buffers.maskBuffer);
}
@@ -624,16 +625,18 @@ struct PrepareTransferWriteConversion
Location loc = xferOp.getLoc();
auto buffers = allocBuffers(rewriter, xferOp);
- rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
+ rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
+ buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
rewriter.updateRootInPlace(xferOp, [&]() {
- xferOp.vectorMutable().assign(loadedVec);
+ xferOp.getVectorMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});
- if (xferOp.mask()) {
- rewriter.updateRootInPlace(
- xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
+ if (xferOp.getMask()) {
+ rewriter.updateRootInPlace(xferOp, [&]() {
+ xferOp.getMaskMutable().assign(buffers.maskBuffer);
+ });
}
return success();
@@ -694,7 +697,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
- if (xferOp.mask()) {
+ if (xferOp.getMask()) {
auto maskBuffer = getMaskBuffer(xferOp);
auto maskBufferType =
maskBuffer.getType().template dyn_cast<MemRefType>();
@@ -741,8 +744,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// the
// unpacked dim is not a broadcast, no mask is
// needed on the new transfer op.
- if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
- xferOp.getMaskType().getRank() > 1)) {
+ if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
+ xferOp.getMaskType().getRank() > 1)) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(newXfer); // Insert load before newXfer.
@@ -755,8 +758,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
- rewriter.updateRootInPlace(
- newXfer, [&]() { newXfer.maskMutable().assign(mask); });
+ rewriter.updateRootInPlace(newXfer, [&]() {
+ newXfer.getMaskMutable().assign(mask);
+ });
}
return loopState.empty() ? Value() : newXfer->getResult(0);
@@ -784,13 +788,13 @@ namespace lowering_n_d_unrolled {
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
int64_t i) {
- if (!xferOp.mask())
+ if (!xferOp.getMask())
return;
if (xferOp.isBroadcastDim(0)) {
// To-be-unpacked dimension is a broadcast, which does not have a
// corresponding mask dimension. Mask attribute remains unchanged.
- newXferOp.maskMutable().assign(xferOp.mask());
+ newXferOp.getMaskMutable().assign(xferOp.getMask());
return;
}
@@ -801,8 +805,8 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
llvm::SmallVector<int64_t, 1> indices({i});
Location loc = xferOp.getLoc();
- auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
- newXferOp.maskMutable().assign(newMask);
+ auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
+ newXferOp.getMaskMutable().assign(newMask);
}
// If we end up here: The mask of the old transfer op is 1D and the unpacked
@@ -853,10 +857,10 @@ struct UnrollTransferReadConversion
Value getResultVector(TransferReadOp xferOp,
PatternRewriter &rewriter) const {
if (auto insertOp = getInsertOp(xferOp))
- return insertOp.dest();
+ return insertOp.getDest();
Location loc = xferOp.getLoc();
return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.padding());
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -876,7 +880,7 @@ struct UnrollTransferReadConversion
void getInsertionIndices(TransferReadOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto insertOp = getInsertOp(xferOp)) {
- llvm::for_each(insertOp.position(), [&](Attribute attr) {
+ llvm::for_each(insertOp.getPosition(), [&](Attribute attr) {
indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
});
}
@@ -921,11 +925,11 @@ struct UnrollTransferReadConversion
getInsertionIndices(xferOp, insertionIndices);
insertionIndices.push_back(i);
- auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+ auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.source(), xferIndices,
+ loc, newXferVecType, xferOp.getSource(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
- xferOp.padding(), Value(), inBoundsAttr);
+ xferOp.getPadding(), Value(), inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
return b.create<vector::InsertOp>(loc, newXferOp, vec,
insertionIndices);
@@ -988,13 +992,13 @@ struct UnrollTransferWriteConversion
/// Return the vector from which newly generated ExtractOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
if (auto extractOp = getExtractOp(xferOp))
- return extractOp.vector();
- return xferOp.vector();
+ return extractOp.getVector();
+ return xferOp.getVector();
}
/// If the input of the given TransferWriteOp is an ExtractOp, return it.
vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
- if (auto *op = xferOp.vector().getDefiningOp())
+ if (auto *op = xferOp.getVector().getDefiningOp())
return dyn_cast<vector::ExtractOp>(op);
return vector::ExtractOp();
}
@@ -1004,7 +1008,7 @@ struct UnrollTransferWriteConversion
void getExtractionIndices(TransferWriteOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto extractOp = getExtractOp(xferOp)) {
- llvm::for_each(extractOp.position(), [&](Attribute attr) {
+ llvm::for_each(extractOp.getPosition(), [&](Attribute attr) {
indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
});
}
@@ -1026,7 +1030,7 @@ struct UnrollTransferWriteConversion
auto vec = getDataVector(xferOp);
auto xferVecType = xferOp.getVectorType();
int64_t dimSize = xferVecType.getShape()[0];
- auto source = xferOp.source(); // memref or tensor to be written to.
+ auto source = xferOp.getSource(); // memref or tensor to be written to.
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
// Generate fully unrolled loop of transfer ops.
@@ -1050,7 +1054,7 @@ struct UnrollTransferWriteConversion
auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
- auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+ auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferWriteOp>(
loc, sourceType, extracted, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
@@ -1089,8 +1093,8 @@ template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
SmallVector<Value, 8> &memrefIndices) {
- auto indices = xferOp.indices();
- auto map = xferOp.permutation_map();
+ auto indices = xferOp.getIndices();
+ auto map = xferOp.getPermutationMap();
assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
memrefIndices.append(indices.begin(), indices.end());
@@ -1132,7 +1136,8 @@ struct Strategy1d<TransferReadOp> {
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
- Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
+ Value val =
+ b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
@@ -1144,7 +1149,7 @@ struct Strategy1d<TransferReadOp> {
// Initialize vector with padding value.
Location loc = xferOp.getLoc();
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.padding());
+ xferOp.getPadding());
}
};
@@ -1162,8 +1167,8 @@ struct Strategy1d<TransferWriteOp> {
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
- b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
+ b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
}
@@ -1221,7 +1226,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
// TODO: support 0-d corner case.
if (xferOp.getTransferRank() == 0)
return failure();
- auto map = xferOp.permutation_map();
+ auto map = xferOp.getPermutationMap();
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
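
The VectorToSCF patterns above are templated over the transfer op kind, so the flip relies on TransferReadOp and TransferWriteOp exposing identically named prefixed accessors. A sketch of that shared shape, with a hypothetical helper name:

    // Hypothetical helper; only illustrates the shared accessor surface.
    template <typename OpTy> // vector::TransferReadOp or TransferWriteOp
    static SmallVector<Value, 8> collectTransferIndices(OpTy xferOp) {
      return SmallVector<Value, 8>(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
    }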
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 4061f57909cac..5bdcc384165c9 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -44,11 +44,11 @@ struct VectorBitcastConvert final
if (!dstType)
return failure();
- if (dstType == adaptor.source().getType())
- rewriter.replaceOp(bitcastOp, adaptor.source());
+ if (dstType == adaptor.getSource().getType())
+ rewriter.replaceOp(bitcastOp, adaptor.getSource());
else
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
- adaptor.source());
+ adaptor.getSource());
return success();
}
@@ -61,11 +61,11 @@ struct VectorBroadcastConvert final
LogicalResult
matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (broadcastOp.source().getType().isa<VectorType>() ||
+ if (broadcastOp.getSource().getType().isa<VectorType>() ||
!spirv::CompositeType::isValid(broadcastOp.getVectorType()))
return failure();
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
- adaptor.source());
+ adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
broadcastOp, broadcastOp.getVectorType(), source);
return success();
@@ -88,14 +88,14 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();
- if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
- rewriter.replaceOp(extractOp, adaptor.vector());
+ if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
- int32_t id = getFirstIntValue(extractOp.position());
+ int32_t id = getFirstIntValue(extractOp.getPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, adaptor.vector(), id);
+ extractOp, adaptor.getVector(), id);
return success();
}
};
@@ -111,10 +111,9 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();
-
- uint64_t offset = getFirstIntValue(extractOp.offsets());
- uint64_t size = getFirstIntValue(extractOp.sizes());
- uint64_t stride = getFirstIntValue(extractOp.strides());
+ uint64_t offset = getFirstIntValue(extractOp.getOffsets());
+ uint64_t size = getFirstIntValue(extractOp.getSizes());
+ uint64_t stride = getFirstIntValue(extractOp.getStrides());
if (stride != 1)
return failure();
@@ -147,7 +146,8 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
return failure();
rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
- fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
+ fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(),
+ adaptor.getAcc());
return success();
}
};
@@ -162,16 +162,16 @@ struct VectorInsertOpConvert final
// Special case for inserting scalar values into size-1 vectors.
if (insertOp.getSourceType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
- rewriter.replaceOp(insertOp, adaptor.source());
+ rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
if (insertOp.getSourceType().isa<VectorType>() ||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
- int32_t id = getFirstIntValue(insertOp.position());
+ int32_t id = getFirstIntValue(insertOp.getPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.source(), adaptor.dest(), id);
+ insertOp, adaptor.getSource(), adaptor.getDest(), id);
return success();
}
};
@@ -186,8 +186,8 @@ struct VectorExtractElementOpConvert final
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
return failure();
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractElementOp, extractElementOp.getType(), adaptor.vector(),
- extractElementOp.position());
+ extractElementOp, extractElementOp.getType(), adaptor.getVector(),
+ extractElementOp.getPosition());
return success();
}
};
@@ -202,8 +202,8 @@ struct VectorInsertElementOpConvert final
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
return failure();
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
- adaptor.source(), insertElementOp.position());
+ insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
+ adaptor.getSource(), insertElementOp.getPosition());
return success();
}
};
@@ -218,10 +218,10 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back();
- uint64_t stride = getFirstIntValue(insertOp.strides());
+ uint64_t stride = getFirstIntValue(insertOp.getStrides());
if (stride != 1)
return failure();
- uint64_t offset = getFirstIntValue(insertOp.offsets());
+ uint64_t offset = getFirstIntValue(insertOp.getOffsets());
if (srcVector.getType().isa<spirv::ScalarType>()) {
assert(!dstVector.getType().isa<spirv::ScalarType>());
@@ -259,7 +259,8 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
VectorType dstVecType = op.getType();
if (!spirv::CompositeType::isValid(dstVecType))
return failure();
- SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+ SmallVector<Value, 4> source(dstVecType.getNumElements(),
+ adaptor.getInput());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
source);
return success();
@@ -281,19 +282,19 @@ struct VectorShuffleOpConvert final
auto oldSourceType = shuffleOp.getV1VectorType();
if (oldSourceType.getNumElements() > 1) {
SmallVector<int32_t, 4> components = llvm::to_vector<4>(
- llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t {
+ llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
return attr.cast<IntegerAttr>().getValue().getZExtValue();
}));
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
- shuffleOp, newResultType, adaptor.v1(), adaptor.v2(),
+ shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
rewriter.getI32ArrayAttr(components));
return success();
}
- SmallVector<Value, 2> oldOperands = {adaptor.v1(), adaptor.v2()};
+ SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
SmallVector<Value, 4> newOperands;
newOperands.reserve(oldResultType.getNumElements());
- for (const APInt &i : shuffleOp.mask().getAsValueRange<IntegerAttr>()) {
+ for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
newOperands.push_back(oldOperands[i.getZExtValue()]);
}
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
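
In these SPIR-V conversions the adaptor accessors are flipped alongside the op accessors; the adaptor mirrors the op's names but yields the type-converted operands. A sketch restating the broadcast lowering above under that assumption:

    // adaptor.getSource() is the converted operand; broadcastOp.getSource()
    // would still be the original vector-dialect value.
    SmallVector<Value, 4> replicated(
        broadcastOp.getVectorType().getNumElements(), adaptor.getSource());
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        broadcastOp, broadcastOp.getVectorType(), replicated);
    return success();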
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 578d956665b6f..570359f953e55 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -148,7 +148,7 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write,
LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
<< "\n");
auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
- if (read && read.indices() == write.transferWriteOp.indices() &&
+ if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
read.getVectorType() == write.transferWriteOp.getVectorType())
return HoistableRead{read, sliceOp};
}
@@ -223,7 +223,7 @@ getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
Value v = yieldOperand.get();
if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
// Indexing must not depend on `forOp`.
- for (Value operand : write.indices())
+ for (Value operand : write.getIndices())
if (!forOp.isDefinedOutsideOfLoop(operand))
return HoistableWrite();
@@ -286,7 +286,7 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write,
read.extractSliceOp.sourceMutable().assign(
forOp.getInitArgs()[initArgNumber]);
else
- read.transferReadOp.sourceMutable().assign(
+ read.transferReadOp.getSourceMutable().assign(
forOp.getInitArgs()[initArgNumber]);
// Hoist write after.
@@ -299,12 +299,12 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write,
if (write.insertSliceOp)
yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
else
- yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());
+ yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());
// Rewrite `loop` with additional new yields.
OpBuilder b(read.transferReadOp);
- auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
- write.transferWriteOp.vector());
+ auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.getVector(),
+ write.transferWriteOp.getVector());
// The transfer write has been hoisted; update the vector and tensor
// source, and replace the loop's result with the new tensor created
// outside the loop.
@@ -313,17 +313,18 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write,
if (write.insertSliceOp) {
newForOp.getResult(initArgNumber)
.replaceAllUsesWith(write.insertSliceOp.getResult());
- write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
+ write.transferWriteOp.getSourceMutable().assign(
+ read.extractSliceOp.result());
write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
} else {
newForOp.getResult(initArgNumber)
.replaceAllUsesWith(write.transferWriteOp.getResult());
- write.transferWriteOp.sourceMutable().assign(
+ write.transferWriteOp.getSourceMutable().assign(
newForOp.getResult(initArgNumber));
}
// Always update with the newly yielded tensor and vector.
- write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
+ write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
}
// To hoist transfer op on tensor the logic can be significantly simplified
@@ -355,7 +356,7 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
if (write.insertSliceOp)
LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
<< *write.insertSliceOp.getOperation() << "\n");
- if (llvm::any_of(write.transferWriteOp.indices(),
+ if (llvm::any_of(write.transferWriteOp.getIndices(),
[&forOp](Value index) {
return !forOp.isDefinedOutsideOfLoop(index);
}))
@@ -422,7 +423,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
vector::TransferWriteOp transferWrite;
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
- if (!candidateWrite || candidateWrite.source() != transferRead.source())
+ if (!candidateWrite ||
+ candidateWrite.getSource() != transferRead.getSource())
continue;
transferWrite = candidateWrite;
}
@@ -444,7 +446,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
// 2. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
- if (transferRead.indices() != transferWrite.indices() &&
+ if (transferRead.getIndices() != transferWrite.getIndices() &&
transferRead.getVectorType() == transferWrite.getVectorType())
return WalkResult::advance();
@@ -453,7 +455,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
- for (auto &use : transferRead.source().getUses()) {
+ for (auto &use : transferRead.getSource().getUses()) {
if (!loop->isAncestor(use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||
@@ -488,12 +490,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
// Rewrite `loop` with new yields by cloning and erase the original loop.
OpBuilder b(transferRead);
- auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
- transferWrite.vector());
+ auto newForOp = cloneWithNewYields(b, loop, transferRead.getVector(),
+ transferWrite.getVector());
// The transfer write has been hoisted; update the written value to
// the value yielded by the newForOp.
- transferWrite.vector().replaceAllUsesWith(
+ transferWrite.getVector().replaceAllUsesWith(
newForOp.getResults().take_back()[0]);
changed = true;
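
The *Mutable accessors renamed in this file return a MutableOperandRange, so assigning through them rewires an operand of the existing op in place. A sketch of the update performed above, in isolation:

    // Redirect the hoisted transfer_write at the value yielded by the
    // rebuilt loop; no new op is created, the operand is replaced in place.
    transferWrite.getVectorMutable().assign(newForOp.getResults().back());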
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a7ae79154ae51..4660baadf7e25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -846,15 +846,15 @@ struct PadOpVectorizationWithTransferReadPattern
if (!padValue)
return failure();
// Padding value of existing `xferOp` is unused.
- if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+ if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
return failure();
rewriter.updateRootInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
rewriter.getBoolArrayAttr(inBounds));
- xferOp.sourceMutable().assign(padOp.source());
- xferOp.paddingMutable().assign(padValue);
+ xferOp.getSourceMutable().assign(padOp.source());
+ xferOp.getPaddingMutable().assign(padValue);
});
return success();
@@ -929,8 +929,8 @@ struct PadOpVectorizationWithTransferWritePattern
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
- xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+ xferOp, padOp.source().getType(), xferOp.getVector(), padOp.source(),
+ xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getMask(),
rewriter.getBoolArrayAttr(inBounds));
rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
@@ -1174,11 +1174,11 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
// TODO: support mask.
- if (xferOp.mask())
+ if (xferOp.getMask())
return failure();
// Transfer into `view`.
- Value viewOrAlloc = xferOp.source();
+ Value viewOrAlloc = xferOp.getSource();
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
@@ -1226,7 +1226,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
}
}
// Ensure padding matches.
- if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
+ if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
return failure();
if (maybeFillOp)
LDBG("with maybeFillOp " << *maybeFillOp);
@@ -1239,8 +1239,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
// When forwarding to vector.transfer_read, the attribute must be reset
// conservatively.
Value res = rewriter.create<vector::TransferReadOp>(
- xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
- xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(),
+ xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
+ xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
// in_bounds is explicitly reset
/*inBoundsAttr=*/ArrayAttr());
@@ -1257,11 +1257,11 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// TODO: support mask.
- if (xferOp.mask())
+ if (xferOp.getMask())
return failure();
// Transfer into `viewOrAlloc`.
- Value viewOrAlloc = xferOp.source();
+ Value viewOrAlloc = xferOp.getSource();
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
@@ -1297,8 +1297,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
// When forwarding to vector.transfer_write, the attribute must be reset
// conservatively.
rewriter.create<vector::TransferWriteOp>(
- xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
- xferOp.permutation_mapAttr(), xferOp.mask(),
+ xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
+ xferOp.getPermutationMapAttr(), xferOp.getMask(),
// in_bounds is explicitly reset
/*inBoundsAttr=*/ArrayAttr());
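
Only the Vector dialect is flipped by this patch, so call sites can legitimately mix spellings during the transition; tensor::PadOp still uses the raw form, as in the hunks above:

    xferOp.getSourceMutable().assign(padOp.source()); // vector: prefixed,
                                                      // PadOp: still raw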
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index c54cda0031632..4e0438f46d164 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -96,10 +96,12 @@ static Value getMemRefOperand(LoadOrStoreOpTy op) {
return op.memref();
}
-static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
+static Value getMemRefOperand(vector::TransferReadOp op) {
+ return op.getSource();
+}
static Value getMemRefOperand(vector::TransferWriteOp op) {
- return op.source();
+ return op.getSource();
}
/// Given the permutation map of the original
@@ -175,9 +177,9 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
sourceIndices,
getPermutationMapAttr(rewriter.getContext(), subViewOp,
- transferReadOp.permutation_map()),
- transferReadOp.padding(),
- /*mask=*/Value(), transferReadOp.in_boundsAttr());
+ transferReadOp.getPermutationMap()),
+ transferReadOp.getPadding(),
+ /*mask=*/Value(), transferReadOp.getInBoundsAttr());
}
template <typename StoreOpTy>
@@ -196,11 +198,11 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
if (transferWriteOp.getTransferRank() == 0)
return;
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
+ transferWriteOp, transferWriteOp.getVector(), subViewOp.source(),
sourceIndices,
getPermutationMapAttr(rewriter.getContext(), subViewOp,
- transferWriteOp.permutation_map()),
- transferWriteOp.in_boundsAttr());
+ transferWriteOp.getPermutationMap()),
+ transferWriteOp.getInBoundsAttr());
}
} // namespace
@@ -215,7 +217,7 @@ LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
- loadOp.indices(), sourceIndices)))
+ loadOp.getIndices(), sourceIndices)))
return failure();
replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
@@ -233,7 +235,7 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
- storeOp.indices(), sourceIndices)))
+ storeOp.getIndices(), sourceIndices)))
return failure();
replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9cf1538dd8bc0..cddc034cf485d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -76,7 +76,7 @@ static MaskFormat get1DMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
- ArrayAttr masks = m.mask_dim_sizes();
+ ArrayAttr masks = m.getMaskDimSizes();
assert(masks.size() == 1);
int64_t i = masks[0].cast<IntegerAttr>().getInt();
int64_t u = m.getType().getDimSize(0);
@@ -140,18 +140,18 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
- return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
- defWrite.indices() == read.indices() &&
+ return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
+ !read.getMask() && defWrite.getIndices() == read.getIndices() &&
defWrite.getVectorType() == read.getVectorType() &&
- defWrite.permutation_map() == read.permutation_map();
+ defWrite.getPermutationMap() == read.getPermutationMap();
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite) {
- return priorWrite.indices() == write.indices() &&
- priorWrite.mask() == write.mask() &&
+ return priorWrite.getIndices() == write.getIndices() &&
+ priorWrite.getMask() == write.getMask() &&
priorWrite.getVectorType() == write.getVectorType() &&
- priorWrite.permutation_map() == write.permutation_map();
+ priorWrite.getPermutationMap() == write.getPermutationMap();
}
bool mlir::vector::isDisjointTransferIndices(
@@ -348,10 +348,10 @@ LogicalResult MultiDimReductionOp::inferReturnTypes(
DictionaryAttr attributes, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
MultiDimReductionOp::Adaptor op(operands, attributes);
- auto vectorType = op.source().getType().cast<VectorType>();
+ auto vectorType = op.getSource().getType().cast<VectorType>();
SmallVector<int64_t> targetShape;
for (auto it : llvm::enumerate(vectorType.getShape()))
- if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) {
+ if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) {
return attr.cast<IntegerAttr>().getValue() == it.index();
}))
targetShape.push_back(it.value());
@@ -367,7 +367,7 @@ LogicalResult MultiDimReductionOp::inferReturnTypes(
OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
// Single parallel dim, this is a noop.
if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
- return source();
+ return getSource();
return {};
}
@@ -397,17 +397,17 @@ LogicalResult ReductionOp::verify() {
return emitOpError("unsupported reduction rank: ") << rank;
// Verify supported reduction kind.
- Type eltType = dest().getType();
- if (!isSupportedCombiningKind(kind(), eltType))
+ Type eltType = getDest().getType();
+ if (!isSupportedCombiningKind(getKind(), eltType))
return emitOpError("unsupported reduction type '")
- << eltType << "' for kind '" << stringifyCombiningKind(kind())
+ << eltType << "' for kind '" << stringifyCombiningKind(getKind())
<< "'";
// Verify optional accumulator.
- if (acc()) {
- if (kind() != CombiningKind::ADD && kind() != CombiningKind::MUL)
+ if (getAcc()) {
+ if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL)
return emitOpError("no accumulator for reduction kind: ")
- << stringifyCombiningKind(kind());
+ << stringifyCombiningKind(getKind());
if (!eltType.isa<FloatType>())
return emitOpError("no accumulator for type: ") << eltType;
}
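
For enum-backed attributes such as the reduction kind, both the unwrapped and the wrapped accessor are renamed together. A sketch of the distinction, with the return types stated as an assumption from the usual ODS pattern:

    auto kind = redOp.getKind();         // unwrapped CombiningKind value
    auto kindAttr = redOp.getKindAttr(); // the attribute itself; printable,
                                         // as in ReductionOp::print below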
@@ -439,11 +439,11 @@ ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
void ReductionOp::print(OpAsmPrinter &p) {
p << " ";
- kindAttr().print(p);
- p << ", " << vector();
- if (acc())
- p << ", " << acc();
- p << " : " << vector().getType() << " into " << dest().getType();
+ getKindAttr().print(p);
+ p << ", " << getVector();
+ if (getAcc())
+ p << ", " << getAcc();
+ p << " : " << getVector().getType() << " into " << getDest().getType();
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -582,13 +582,13 @@ void ContractionOp::print(OpAsmPrinter &p) {
attrs.push_back(attr);
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
- p << " " << dictAttr << " " << lhs() << ", ";
- p << rhs() << ", " << acc();
- if (masks().size() == 2)
- p << ", " << masks();
+ p << " " << dictAttr << " " << getLhs() << ", ";
+ p << getRhs() << ", " << getAcc();
+ if (getMasks().size() == 2)
+ p << ", " << getMasks();
p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
- p << " : " << lhs().getType() << ", " << rhs().getType() << " into "
+ p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
<< getResultType();
}
@@ -696,14 +696,14 @@ LogicalResult ContractionOp::verify() {
auto resType = getResultType();
// Verify that an indexing map was specified for each vector operand.
- if (indexing_maps().size() != 3)
+ if (getIndexingMaps().size() != 3)
return emitOpError("expected an indexing map for each vector operand");
// Verify that each index map has 'numIterators' inputs, no symbols, and
// that the number of map outputs equals the rank of its associated
// vector operand.
- unsigned numIterators = iterator_types().getValue().size();
- for (const auto &it : llvm::enumerate(indexing_maps())) {
+ unsigned numIterators = getIteratorTypes().getValue().size();
+ for (const auto &it : llvm::enumerate(getIndexingMaps())) {
auto index = it.index();
auto map = it.value();
if (map.getNumSymbols() != 0)
@@ -759,7 +759,7 @@ LogicalResult ContractionOp::verify() {
// Verify supported combining kind.
auto vectorType = resType.dyn_cast<VectorType>();
auto elementType = vectorType ? vectorType.getElementType() : resType;
- if (!isSupportedCombiningKind(kind(), elementType))
+ if (!isSupportedCombiningKind(getKind(), elementType))
return emitOpError("unsupported contraction type");
return success();
@@ -803,7 +803,7 @@ void ContractionOp::getIterationBounds(
auto resVectorType = getResultType().dyn_cast<VectorType>();
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
SmallVector<int64_t, 2> iterationShape;
- for (const auto &it : llvm::enumerate(iterator_types())) {
+ for (const auto &it : llvm::enumerate(getIteratorTypes())) {
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), getContext());
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
@@ -824,9 +824,9 @@ void ContractionOp::getIterationBounds(
void ContractionOp::getIterationIndexMap(
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
- unsigned numMaps = indexing_maps().size();
+ unsigned numMaps = getIndexingMaps().size();
iterationIndexMap.resize(numMaps);
- for (const auto &it : llvm::enumerate(indexing_maps())) {
+ for (const auto &it : llvm::enumerate(getIndexingMaps())) {
auto index = it.index();
auto map = it.value();
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -838,13 +838,13 @@ void ContractionOp::getIterationIndexMap(
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
- return getDimMap(indexingMaps, iterator_types(),
+ return getDimMap(indexingMaps, getIteratorTypes(),
getReductionIteratorTypeName(), getContext());
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
- return getDimMap(indexingMaps, iterator_types(),
+ return getDimMap(indexingMaps, getIteratorTypes(),
getParallelIteratorTypeName(), getContext());
}
@@ -886,11 +886,11 @@ struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
if (!contractionOp)
return vector::ContractionOp();
if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
- contractionOp.acc().getDefiningOp())) {
+ contractionOp.getAcc().getDefiningOp())) {
if (maybeZero.getValue() ==
- rewriter.getZeroAttr(contractionOp.acc().getType())) {
+ rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
BlockAndValueMapping bvm;
- bvm.map(contractionOp.acc(), otherOperand);
+ bvm.map(contractionOp.getAcc(), otherOperand);
auto newContraction =
cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
rewriter.replaceOp(addOp, newContraction.getResult());
@@ -932,13 +932,13 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
LogicalResult vector::ExtractElementOp::verify() {
VectorType vectorType = getVectorType();
if (vectorType.getRank() == 0) {
- if (position())
+ if (getPosition())
return emitOpError("expected position to be empty with 0-D vector");
return success();
}
if (vectorType.getRank() != 1)
return emitOpError("unexpected >1 vector rank");
- if (!position())
+ if (!getPosition())
return emitOpError("expected position for 1-D vector");
return success();
}
@@ -968,11 +968,12 @@ ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes);
- auto vectorType = op.vector().getType().cast<VectorType>();
- if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
+ auto vectorType = op.getVector().getType().cast<VectorType>();
+ if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
- auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
+ auto n =
+ std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType()));
}
@@ -993,7 +994,7 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
- auto positionAttr = position().getValue();
+ auto positionAttr = getPosition().getValue();
if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
return emitOpError(
"expected position attribute of rank smaller than vector rank");
@@ -1019,19 +1020,19 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
- if (!extractOp.vector().getDefiningOp<ExtractOp>())
+ if (!extractOp.getVector().getDefiningOp<ExtractOp>())
return failure();
SmallVector<int64_t, 4> globalPosition;
ExtractOp currentOp = extractOp;
- auto extrPos = extractVector<int64_t>(currentOp.position());
+ auto extrPos = extractVector<int64_t>(currentOp.getPosition());
globalPosition.append(extrPos.rbegin(), extrPos.rend());
- while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
+ while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
- auto extrPos = extractVector<int64_t>(currentOp.position());
+ auto extrPos = extractVector<int64_t>(currentOp.getPosition());
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
- extractOp.setOperand(currentOp.vector());
+ extractOp.setOperand(currentOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
@@ -1143,12 +1144,12 @@ class ExtractFromInsertTransposeChainState {
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
ExtractOp e)
: extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
- extractedRank(extractOp.position().size()) {
+ extractedRank(extractOp.getPosition().size()) {
assert(vectorRank >= extractedRank && "extracted pos overflow");
sentinels.reserve(vectorRank - extractedRank);
for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
sentinels.push_back(-(i + 1));
- extractPosition = extractVector<int64_t>(extractOp.position());
+ extractPosition = extractVector<int64_t>(extractOp.getPosition());
llvm::append_range(extractPosition, sentinels);
}
@@ -1157,7 +1158,7 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
if (!nextTransposeOp)
return failure();
- auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
+ auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
AffineMap m = inversePermutation(
AffineMap::getPermutationMap(permutation, extractOp.getContext()));
extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
@@ -1168,12 +1169,12 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
Value &res) {
- auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
if (makeArrayRef(insertedPos) !=
llvm::makeArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
- res = nextInsertOp.source();
+ res = nextInsertOp.getSource();
// Case 2.b. if internal transposition is present, canFold will be false.
return success();
}
@@ -1183,7 +1184,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
- auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
if (!isContainedWithin(insertedPos, extractPosition))
return failure();
// Set leading dims to zero.
@@ -1193,7 +1194,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
extractPosition.begin() + insertedPos.size());
extractedRank = extractPosition.size() - sentinels.size();
// Case 3.a. early-exit fold (break and delegate to post-while path).
- res = nextInsertOp.source();
+ res = nextInsertOp.getSource();
// Case 3.b. if internal transposition is present, canFold will be false.
return success();
}
@@ -1204,28 +1205,28 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
Value source) {
// If we can't fold (either internal transposition, or nothing to fold), bail.
- bool nothingToFold = (source == extractOp.vector());
+ bool nothingToFold = (source == extractOp.getVector());
if (nothingToFold || !canFold())
return Value();
// Otherwise, fold by updating the op inplace and return its result.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(
- extractOp.positionAttrName(),
+ extractOp.getPositionAttrName(),
b.getI64ArrayAttr(
makeArrayRef(extractPosition).take_front(extractedRank)));
- extractOp.vectorMutable().assign(source);
+ extractOp.getVectorMutable().assign(source);
return extractOp.getResult();
}
/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
- Value valueToExtractFrom = extractOp.vector();
+ Value valueToExtractFrom = extractOp.getVector();
updateStateForNextIteration(valueToExtractFrom);
while (nextInsertOp || nextTransposeOp) {
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
if (succeeded(handleTransposeOp())) {
- valueToExtractFrom = nextTransposeOp.vector();
+ valueToExtractFrom = nextTransposeOp.getVector();
updateStateForNextIteration(valueToExtractFrom);
continue;
}
@@ -1242,13 +1243,13 @@ Value ExtractFromInsertTransposeChainState::fold() {
// Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
// values. This is a more difficult case and we bail.
- auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
if (isContainedWithin(extractPosition, insertedPos) ||
intersectsWhereNonNegative(extractPosition, insertedPos))
return Value();
// Case 5: No intersection, we forward the extract to insertOp.dest().
- valueToExtractFrom = nextInsertOp.dest();
+ valueToExtractFrom = nextInsertOp.getDest();
updateStateForNextIteration(valueToExtractFrom);
}
// If after all this we can fold, go for it.
@@ -1257,7 +1258,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.vector().getDefiningOp();
+ Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
Value source = defOp->getOperand(0);
@@ -1269,7 +1270,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank < broadcastSrcRank) {
- auto extractPos = extractVector<int64_t>(extractOp.position());
+ auto extractPos = extractVector<int64_t>(extractOp.getPosition());
unsigned rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(
extractPos.begin(),
@@ -1286,7 +1287,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
- auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
+ auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
// Get the nth dimension size starting from lowest dimension.
@@ -1312,7 +1313,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
- auto extractedPos = extractVector<int64_t>(extractOp.position());
+ auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
@@ -1339,14 +1340,14 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
OpBuilder b(extractOp.getContext());
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(newPosition));
- extractOp.setOperand(shapeCastOp.source());
+ extractOp.setOperand(shapeCastOp.getSource());
return extractOp.getResult();
}
/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
auto extractStridedSliceOp =
- extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
+ extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return Value();
// Return if 'extractStridedSliceOp' has non-unit strides.
@@ -1354,7 +1355,8 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
// Trim offsets for dimensions fully extracted.
- auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
+ auto sliceOffsets =
+ extractVector<int64_t>(extractStridedSliceOp.getOffsets());
while (!sliceOffsets.empty()) {
size_t lastOffset = sliceOffsets.size() - 1;
if (sliceOffsets.back() != 0 ||
@@ -1371,11 +1373,11 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
if (destinationRank >
extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
return Value();
- auto extractedPos = extractVector<int64_t>(extractOp.position());
+ auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
- extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
+ extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
@@ -1388,16 +1390,16 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
int64_t destinationRank = op.getType().isa<VectorType>()
? op.getType().cast<VectorType>().getRank()
: 0;
- auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+ auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
insertOp.getSourceVectorType().getRank();
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
- auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
- auto extractOffsets = extractVector<int64_t>(op.position());
+ auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+ auto extractOffsets = extractVector<int64_t>(op.getPosition());
- if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
+ if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
}))
return Value();
@@ -1432,7 +1434,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
insertRankDiff))
return Value();
}
- op.vectorMutable().assign(insertOp.source());
+ op.getVectorMutable().assign(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
op->setAttr(ExtractOp::getPositionAttrStrName(),
@@ -1441,14 +1443,14 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
}
// If the chunk extracted is disjoint from the chunk inserted, keep
// looking in the insert chain.
- insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+ insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
}
return Value();
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
- if (position().empty())
- return vector();
+ if (getPosition().empty())
+ return getVector();
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -1473,7 +1475,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.vector().getDefiningOp();
+ Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return failure();
Value source = defOp->getOperand(0);
@@ -1504,7 +1506,7 @@ class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
PatternRewriter &rewriter) const override {
// Return if the 'extractOp' operand is not defined by a
// ConstantOp.
- auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
+ auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
@@ -1566,18 +1568,18 @@ LogicalResult ExtractMapOp::verify() {
if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
numId++;
}
- if (numId != ids().size())
+ if (numId != getIds().size())
return emitOpError("expected number of ids must match the number of "
"dimensions distributed");
return success();
}
OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
- auto insert = vector().getDefiningOp<vector::InsertMapOp>();
- if (insert == nullptr || getType() != insert.vector().getType() ||
- ids() != insert.ids())
+ auto insert = getVector().getDefiningOp<vector::InsertMapOp>();
+ if (insert == nullptr || getType() != insert.getVector().getType() ||
+ getIds() != insert.getIds())
return {};
- return insert.vector();
+ return insert.getVector();
}
void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
@@ -1670,7 +1672,7 @@ LogicalResult BroadcastOp::verify() {
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (getSourceType() == getVectorType())
- return source();
+ return getSource();
if (!operands[0])
return {};
auto vectorType = getVectorType();
@@ -1689,11 +1691,11 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
- auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
+ auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
if (!srcBroadcast)
return failure();
rewriter.replaceOpWithNewOp<BroadcastOp>(
- broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
+ broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
return success();
}
};
@@ -1734,7 +1736,7 @@ LogicalResult ShuffleOp::verify() {
return emitOpError("dimension mismatch");
}
// Verify mask length.
- auto maskAttr = mask().getValue();
+ auto maskAttr = getMask().getValue();
int64_t maskLength = maskAttr.size();
if (maskLength <= 0)
return emitOpError("invalid mask length");
@@ -1756,12 +1758,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes);
- auto v1Type = op.v1().getType().cast<VectorType>();
+ auto v1Type = op.getV1().getType().cast<VectorType>();
// Construct resulting type: leading dimension matches mask length,
// all trailing dimensions match the operands.
SmallVector<int64_t, 4> shape;
shape.reserve(v1Type.getRank());
- shape.push_back(std::max<size_t>(1, op.mask().size()));
+ shape.push_back(std::max<size_t>(1, op.getMask().size()));
llvm::append_range(shape, v1Type.getShape().drop_front());
inferredReturnTypes.push_back(
VectorType::get(shape, v1Type.getElementType()));
@@ -1783,7 +1785,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
SmallVector<Attribute> results;
auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
- for (const auto &index : this->mask().getAsValueRange<IntegerAttr>()) {
+ for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
int64_t i = index.getZExtValue();
if (i >= lhsSize) {
results.push_back(rhsElements[i - lhsSize]);
@@ -1807,13 +1809,13 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
LogicalResult InsertElementOp::verify() {
auto dstVectorType = getDestVectorType();
if (dstVectorType.getRank() == 0) {
- if (position())
+ if (getPosition())
return emitOpError("expected position to be empty with 0-D vector");
return success();
}
if (dstVectorType.getRank() != 1)
return emitOpError("unexpected >1 vector rank");
- if (!position())
+ if (!getPosition())
return emitOpError("expected position for 1-D vector");
return success();
}
@@ -1841,7 +1843,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
}
LogicalResult InsertOp::verify() {
- auto positionAttr = position().getValue();
+ auto positionAttr = getPosition().getValue();
auto destVectorType = getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
@@ -1883,7 +1885,7 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
srcVecType.getNumElements())
return failure();
rewriter.replaceOpWithNewOp<BroadcastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.source());
+ insertOp, insertOp.getDestVectorType(), insertOp.getSource());
return success();
}
};
@@ -1899,8 +1901,8 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
- if (position().empty())
- return source();
+ if (getPosition().empty())
+ return getSource();
return {};
}
@@ -1920,7 +1922,7 @@ LogicalResult InsertMapOp::verify() {
if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
numId++;
}
- if (numId != ids().size())
+ if (numId != getIds().size())
return emitOpError("expected number of ids must match the number of "
"dimensions distributed");
return success();
@@ -2037,8 +2039,8 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
LogicalResult InsertStridedSliceOp::verify() {
auto sourceVectorType = getSourceVectorType();
auto destVectorType = getDestVectorType();
- auto offsets = offsetsAttr();
- auto strides = stridesAttr();
+ auto offsets = getOffsetsAttr();
+ auto strides = getStridesAttr();
if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected offsets of same size as destination vector rank");
@@ -2072,7 +2074,7 @@ LogicalResult InsertStridedSliceOp::verify() {
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getSourceVectorType() == getDestVectorType())
- return source();
+ return getSource();
return {};
}
@@ -2088,12 +2090,12 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result,
}
void OuterProductOp::print(OpAsmPrinter &p) {
- p << " " << lhs() << ", " << rhs();
- if (!acc().empty()) {
- p << ", " << acc();
+ p << " " << getLhs() << ", " << getRhs();
+ if (!getAcc().empty()) {
+ p << ", " << getAcc();
p.printOptionalAttrDict((*this)->getAttrs());
}
- p << " : " << lhs().getType() << ", " << rhs().getType();
+ p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2163,7 +2165,7 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected operand #3 of same type as result type");
// Verify supported combining kind.
- if (!isSupportedCombiningKind(kind(), vRES.getElementType()))
+ if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
return emitOpError("unsupported outerproduct type");
return success();
@@ -2214,14 +2216,14 @@ LogicalResult ReshapeOp::verify() {
auto isDefByConstant = [](Value operand) {
return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
};
- if (llvm::all_of(input_shape(), isDefByConstant) &&
- llvm::all_of(output_shape(), isDefByConstant)) {
+ if (llvm::all_of(getInputShape(), isDefByConstant) &&
+ llvm::all_of(getOutputShape(), isDefByConstant)) {
int64_t numInputElements = 1;
- for (auto operand : input_shape())
+ for (auto operand : getInputShape())
numInputElements *=
cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
int64_t numOutputElements = 1;
- for (auto operand : output_shape())
+ for (auto operand : getOutputShape())
numOutputElements *=
cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
if (numInputElements != numOutputElements)
@@ -2231,7 +2233,7 @@ LogicalResult ReshapeOp::verify() {
}
void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(fixed_vector_sizes(), results);
+ populateFromInt64AttrArray(getFixedVectorSizes(), results);
}
//===----------------------------------------------------------------------===//
@@ -2274,9 +2276,9 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
LogicalResult ExtractStridedSliceOp::verify() {
auto type = getVectorType();
- auto offsets = offsetsAttr();
- auto sizes = sizesAttr();
- auto strides = stridesAttr();
+ auto offsets = getOffsetsAttr();
+ auto sizes = getSizesAttr();
+ auto strides = getStridesAttr();
if (offsets.size() != sizes.size() || offsets.size() != strides.size())
return emitOpError("expected offsets, sizes and strides attributes of same size");
@@ -2316,16 +2318,16 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
auto getElement = [](ArrayAttr array, int idx) {
return array[idx].cast<IntegerAttr>().getInt();
};
- ArrayAttr extractOffsets = op.offsets();
- ArrayAttr extractStrides = op.strides();
- ArrayAttr extractSizes = op.sizes();
- auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+ ArrayAttr extractOffsets = op.getOffsets();
+ ArrayAttr extractStrides = op.getStrides();
+ ArrayAttr extractSizes = op.getSizes();
+ auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
if (op.getVectorType().getRank() !=
insertOp.getSourceVectorType().getRank())
return failure();
- ArrayAttr insertOffsets = insertOp.offsets();
- ArrayAttr insertStrides = insertOp.strides();
+ ArrayAttr insertOffsets = insertOp.getOffsets();
+ ArrayAttr insertStrides = insertOp.getStrides();
// If the rank of extract is greater than the rank of insert, we are likely
// extracting a partial chunk of the vector inserted.
if (extractOffsets.size() > insertOffsets.size())
@@ -2354,7 +2356,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
}
// The extract element chunk is a subset of the insert element.
if (!disjoint && !patialoverlap) {
- op.setOperand(insertOp.source());
+ op.setOperand(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
@@ -2364,7 +2366,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// If the chunk extracted is disjoint from the chunk inserted, keep looking
// in the insert chain.
if (disjoint)
- insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+ insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
else {
// The extracted vector partially overlaps the inserted vector; we cannot
// fold.
@@ -2376,14 +2378,14 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getVectorType() == getResult().getType())
- return vector();
+ return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
return {};
}
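A small IR sketch of this fold (illustrative types): a slice whose result type equals the operand type extracts everything, so the op folds to its operand:

  // Illustrative fold; value names hypothetical.
  %1 = vector.extract_strided_slice %0
         {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]}
         : vector<2x4xf32> to vector<2x4xf32>
  // %1 folds to %0: the slice covers the whole vector.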
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(offsets(), results);
+ populateFromInt64AttrArray(getOffsets(), results);
}
namespace {
@@ -2399,7 +2401,7 @@ class StridedSliceConstantMaskFolder final
PatternRewriter &rewriter) const override {
// Return if 'extractStridedSliceOp' operand is not defined by a
// ConstantMaskOp.
- auto *defOp = extractStridedSliceOp.vector().getDefiningOp();
+ auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
if (!constantMaskOp)
return failure();
@@ -2408,12 +2410,13 @@ class StridedSliceConstantMaskFolder final
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
- populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
+ populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
- populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
+ populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+ sliceOffsets);
SmallVector<int64_t, 4> sliceSizes;
- populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
+ populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
// Compute slice of vector mask region.
SmallVector<int64_t, 4> sliceMaskDimSizes;
@@ -2452,7 +2455,7 @@ class StridedSliceConstantFolder final
// Return if 'extractStridedSliceOp' operand is not defined by a
// ConstantOp.
auto constantOp =
- extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>();
+ extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
@@ -2475,10 +2478,10 @@ class StridedSliceBroadcast final
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
+ auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
if (!broadcast)
return failure();
- auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
+ auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
auto dstVecType = op.getType().cast<VectorType>();
unsigned dstRank = dstVecType.getRank();
@@ -2493,15 +2496,15 @@ class StridedSliceBroadcast final
break;
}
}
- Value source = broadcast.source();
+ Value source = broadcast.getSource();
if (!lowerDimMatch) {
// The inner dimensions don't match, which means we need to extract from the
// source of the original broadcast and then broadcast the extracted value.
source = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), source,
- getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
- getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
- getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
+ getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
+ getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
+ getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
return success();
@@ -2515,10 +2518,10 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto splat = op.vector().getDefiningOp<SplatOp>();
+ auto splat = op.getVector().getDefiningOp<SplatOp>();
if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
+ rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
return success();
}
};
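A hypothetical before/after for this pattern: slicing a splat is itself a splat of the smaller type:

  // Illustrative rewrite; value names hypothetical.
  %s = vector.splat %f : vector<8xf32>
  %e = vector.extract_strided_slice %s
         {offsets = [2], sizes = [4], strides = [1]}
         : vector<8xf32> to vector<4xf32>
  // rewritten to:
  %e = vector.splat %f : vector<4xf32>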
@@ -2726,9 +2729,9 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
}
void TransferReadOp::print(OpAsmPrinter &p) {
- p << " " << source() << "[" << indices() << "], " << padding();
- if (mask())
- p << ", " << mask();
+ p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
+ if (getMask())
+ p << ", " << getMask();
printTransferAttrs(p, *this);
p << " : " << getShapedType() << ", " << getVectorType();
}
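For reference, the form printed here looks roughly like this (a sketch with hypothetical names and types; the optional mask would follow the padding value):

  // Illustrative printed form; names and types hypothetical.
  %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
         : memref<?x?xf32>, vector<8xf32>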
@@ -2798,16 +2801,16 @@ LogicalResult TransferReadOp::verify() {
ShapedType shapedType = getShapedType();
VectorType vectorType = getVectorType();
VectorType maskType = getMaskType();
- auto paddingType = padding().getType();
- auto permutationMap = permutation_map();
+ auto paddingType = getPadding().getType();
+ auto permutationMap = getPermutationMap();
auto sourceElementType = shapedType.getElementType();
- if (static_cast<int64_t>(indices().size()) != shapedType.getRank())
+ if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
return emitOpError("requires ") << shapedType.getRank() << " indices";
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
shapedType, vectorType, maskType, permutationMap,
- in_bounds() ? *in_bounds() : ArrayAttr())))
+ getInBounds() ? *getInBounds() : ArrayAttr())))
return failure();
if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
@@ -2867,7 +2870,7 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
if (op.getShapedType().isDynamicDim(indicesIdx))
return false;
- Value index = op.indices()[indicesIdx];
+ Value index = op.getIndices()[indicesIdx];
auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
if (!cstOp)
return false;
@@ -2884,7 +2887,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
// TODO: Be less conservative.
if (op.getTransferRank() == 0)
return failure();
- AffineMap permutationMap = op.permutation_map();
+ AffineMap permutationMap = op.getPermutationMap();
bool changed = false;
SmallVector<bool, 4> newInBounds;
newInBounds.reserve(op.getTransferRank());
@@ -2926,15 +2929,15 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
static Value foldRAW(TransferReadOp readOp) {
if (!readOp.getShapedType().isa<RankedTensorType>())
return {};
- auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
+ auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueRAW(defWrite, readOp))
- return defWrite.vector();
+ return defWrite.getVector();
if (!isDisjointTransferIndices(
cast<VectorTransferOpInterface>(defWrite.getOperation()),
cast<VectorTransferOpInterface>(readOp.getOperation())))
break;
- defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+ defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
}
return {};
}
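A minimal sketch of this read-after-write fold on tensors (value names hypothetical): a read that exactly mirrors an immediately preceding write yields the written vector:

  // Illustrative fold; value names hypothetical.
  %t1 = vector.transfer_write %v, %t0[%c0] {in_bounds = [true]}
          : vector<4xf32>, tensor<8xf32>
  %r  = vector.transfer_read %t1[%c0], %pad {in_bounds = [true]}
          : tensor<8xf32>, vector<4xf32>
  // %r folds to %v.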
@@ -2960,7 +2963,7 @@ void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (getShapedType().isa<MemRefType>())
- effects.emplace_back(MemoryEffects::Read::get(), source(),
+ effects.emplace_back(MemoryEffects::Read::get(), getSource(),
SideEffects::DefaultResource::get());
}
@@ -2992,11 +2995,11 @@ struct FoldExtractSliceIntoTransferRead
return failure();
if (xferOp.hasOutOfBoundsDim())
return failure();
- if (!xferOp.permutation_map().isIdentity())
+ if (!xferOp.getPermutationMap().isIdentity())
return failure();
- if (xferOp.mask())
+ if (xferOp.getMask())
return failure();
- auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+ auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
return failure();
if (!extractOp.hasUnitStride())
@@ -3039,7 +3042,7 @@ struct FoldExtractSliceIntoTransferRead
newIndices.push_back(getValueOrCreateConstantIndexOp(
rewriter, extractOp.getLoc(), offset));
}
- for (const auto &it : llvm::enumerate(xferOp.indices())) {
+ for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
OpFoldResult offset =
extractOp.getMixedOffsets()[it.index() + rankReduced];
newIndices.push_back(rewriter.create<arith::AddIOp>(
@@ -3050,7 +3053,7 @@ struct FoldExtractSliceIntoTransferRead
SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
rewriter.replaceOpWithNewOp<TransferReadOp>(
xferOp, xferOp.getVectorType(), extractOp.source(), newIndices,
- xferOp.padding(), ArrayRef<bool>{inBounds});
+ xferOp.getPadding(), ArrayRef<bool>{inBounds});
return success();
}
@@ -3165,9 +3168,9 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
}
void TransferWriteOp::print(OpAsmPrinter &p) {
- p << " " << vector() << ", " << source() << "[" << indices() << "]";
- if (mask())
- p << ", " << mask();
+ p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
+ if (getMask())
+ p << ", " << getMask();
printTransferAttrs(p, *this);
p << " : " << getVectorType() << ", " << getShapedType();
}
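The written form produced here, sketched with hypothetical names and types:

  // Illustrative printed form; names and types hypothetical.
  vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true]}
    : vector<8xf32>, memref<?x?xf32>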
@@ -3177,9 +3180,9 @@ LogicalResult TransferWriteOp::verify() {
ShapedType shapedType = getShapedType();
VectorType vectorType = getVectorType();
VectorType maskType = getMaskType();
- auto permutationMap = permutation_map();
+ auto permutationMap = getPermutationMap();
- if (llvm::size(indices()) != shapedType.getRank())
+ if (llvm::size(getIndices()) != shapedType.getRank())
return emitOpError("requires ") << shapedType.getRank() << " indices";
// We do not allow broadcast dimensions on TransferWriteOps for the moment,
@@ -3189,7 +3192,7 @@ LogicalResult TransferWriteOp::verify() {
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
shapedType, vectorType, maskType, permutationMap,
- in_bounds() ? *in_bounds() : ArrayAttr())))
+ getInBounds() ? *getInBounds() : ArrayAttr())))
return failure();
return verifyPermutationMap(permutationMap,
@@ -3219,20 +3222,21 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
// TODO: support 0-d corner case.
if (write.getTransferRank() == 0)
return failure();
- auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
+ auto rankedTensorType =
+ write.getSource().getType().dyn_cast<RankedTensorType>();
// If not operating on tensors, bail.
if (!rankedTensorType)
return failure();
// If no read, bail.
- auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
+ auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
if (!read)
return failure();
// TODO: support 0-d corner case.
if (read.getTransferRank() == 0)
return failure();
// For now, only accept minor identity. Future: composition is minor identity.
- if (!read.permutation_map().isMinorIdentity() ||
- !write.permutation_map().isMinorIdentity())
+ if (!read.getPermutationMap().isMinorIdentity() ||
+ !write.getPermutationMap().isMinorIdentity())
return failure();
// Bail on mismatching ranks.
if (read.getTransferRank() != write.getTransferRank())
@@ -3241,7 +3245,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
return failure();
// Tensor types must be the same.
- if (read.source().getType() != rankedTensorType)
+ if (read.getSource().getType() != rankedTensorType)
return failure();
// Vector types must be the same.
if (read.getVectorType() != write.getVectorType())
@@ -3254,20 +3258,21 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
return !cstOp || cstOp.value() != 0;
};
- if (llvm::any_of(read.indices(), isNotConstantZero) ||
- llvm::any_of(write.indices(), isNotConstantZero))
+ if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
+ llvm::any_of(write.getIndices(), isNotConstantZero))
return failure();
// Success.
- results.push_back(read.source());
+ results.push_back(read.getSource());
return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
vector::TransferWriteOp write) {
- return read.source() == write.source() && read.indices() == write.indices() &&
- read.permutation_map() == write.permutation_map() &&
- read.getVectorType() == write.getVectorType() && !read.mask() &&
- !write.mask();
+ return read.getSource() == write.getSource() &&
+ read.getIndices() == write.getIndices() &&
+ read.getPermutationMap() == write.getPermutationMap() &&
+ read.getVectorType() == write.getVectorType() && !read.getMask() &&
+ !write.getMask();
}
/// Fold transfer_write write after read:
/// ```
@@ -3285,15 +3290,15 @@ static bool checkSameValueWAR(vector::TransferReadOp read,
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
SmallVectorImpl<OpFoldResult> &results) {
- if (!write.source().getType().isa<RankedTensorType>())
+ if (!write.getSource().getType().isa<RankedTensorType>())
return failure();
- auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
+ auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
if (!read)
return failure();
if (!checkSameValueWAR(read, write))
return failure();
- results.push_back(read.source());
+ results.push_back(read.getSource());
return success();
}
@@ -3316,7 +3321,7 @@ void TransferWriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (getShapedType().isa<MemRefType>())
- effects.emplace_back(MemoryEffects::Write::get(), source(),
+ effects.emplace_back(MemoryEffects::Write::get(), getSource(),
SideEffects::DefaultResource::get());
}
@@ -3354,10 +3359,11 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
return failure();
vector::TransferWriteOp writeToModify = writeOp;
- auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
+ auto defWrite =
+ writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueWAW(writeOp, defWrite)) {
- writeToModify.sourceMutable().assign(defWrite.source());
+ writeToModify.getSourceMutable().assign(defWrite.getSource());
return success();
}
if (!isDisjointTransferIndices(
@@ -3369,7 +3375,7 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
if (!defWrite->hasOneUse())
break;
writeToModify = defWrite;
- defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+ defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
}
return failure();
}
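A sketch of the write-after-write rewrite (value names hypothetical): when no conflicting read intervenes and both writes cover the same region, the later write is re-rooted onto the earlier write's source:

  // Illustrative rewrite; value names hypothetical.
  %t1 = vector.transfer_write %v0, %t0[%c0] {in_bounds = [true]}
          : vector<4xf32>, tensor<8xf32>
  %t2 = vector.transfer_write %v1, %t1[%c0] {in_bounds = [true]}
          : vector<4xf32>, tensor<8xf32>
  // %t2 is rewritten to write %v1 directly into %t0; the first write
  // becomes dead if it has no other uses.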
@@ -3410,7 +3416,7 @@ struct FoldInsertSliceIntoTransferWrite
return failure();
if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
return failure();
- if (xferOp.mask())
+ if (xferOp.getMask())
return failure();
// Fold only if the TransferWriteOp completely overwrites the `source` with
// a vector. I.e., the result of the TransferWriteOp is a new tensor whose
@@ -3418,7 +3424,7 @@ struct FoldInsertSliceIntoTransferWrite
if (!llvm::equal(xferOp.getVectorType().getShape(),
xferOp.getShapedType().getShape()))
return failure();
- if (!xferOp.permutation_map().isIdentity())
+ if (!xferOp.getPermutationMap().isIdentity())
return failure();
// Bail on illegal rank-reduction: we need to check that the rank-reduced
@@ -3453,7 +3459,7 @@ struct FoldInsertSliceIntoTransferWrite
SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
+ rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
insertOp.dest(), indices,
ArrayRef<bool>{inBounds});
return success();
@@ -3494,7 +3500,7 @@ LogicalResult vector::LoadOp::verify() {
if (resVecTy.getElementType() != memElemTy)
return emitOpError("base and result element types should match");
- if (llvm::size(indices()) != memRefTy.getRank())
+ if (llvm::size(getIndices()) != memRefTy.getRank())
return emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
@@ -3527,7 +3533,7 @@ LogicalResult vector::StoreOp::verify() {
if (valueVecTy.getElementType() != memElemTy)
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(indices()) != memRefTy.getRank())
+ if (llvm::size(getIndices()) != memRefTy.getRank())
return emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
@@ -3549,7 +3555,7 @@ LogicalResult MaskedLoadOp::verify() {
if (resVType.getElementType() != memType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected result dim to match mask dim");
@@ -3564,13 +3570,13 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(load.mask())) {
+ switch (get1DMaskFormat(load.getMask())) {
case MaskFormat::AllTrue:
- rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
- load.base(), load.indices());
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ load, load.getType(), load.getBase(), load.getIndices());
return success();
case MaskFormat::AllFalse:
- rewriter.replaceOp(load, load.pass_thru());
+ rewriter.replaceOp(load, load.getPassThru());
return success();
case MaskFormat::Unknown:
return failure();
@@ -3602,7 +3608,7 @@ LogicalResult MaskedStoreOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected valueToStore dim to match mask dim");
@@ -3615,10 +3621,10 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(store.mask())) {
+ switch (get1DMaskFormat(store.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- store, store.valueToStore(), store.base(), store.indices());
+ store, store.getValueToStore(), store.getBase(), store.getIndices());
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(store);
@@ -3653,7 +3659,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != memType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != indVType.getDimSize(0))
return emitOpError("expected result dim to match indices dim");
@@ -3670,11 +3676,11 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(gather.mask())) {
+ switch (get1DMaskFormat(gather.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
- rewriter.replaceOp(gather, gather.pass_thru());
+ rewriter.replaceOp(gather, gather.getPassThru());
return success();
case MaskFormat::Unknown:
return failure();
@@ -3701,7 +3707,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != indVType.getDimSize(0))
return emitOpError("expected valueToStore dim to match indices dim");
@@ -3716,7 +3722,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(scatter.mask())) {
+ switch (get1DMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
@@ -3747,7 +3753,7 @@ LogicalResult ExpandLoadOp::verify() {
if (resVType.getElementType() != memType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected result dim to match mask dim");
@@ -3762,13 +3768,13 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(expand.mask())) {
+ switch (get1DMaskFormat(expand.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::LoadOp>(
- expand, expand.getType(), expand.base(), expand.indices());
+ expand, expand.getType(), expand.getBase(), expand.getIndices());
return success();
case MaskFormat::AllFalse:
- rewriter.replaceOp(expand, expand.pass_thru());
+ rewriter.replaceOp(expand, expand.getPassThru());
return success();
case MaskFormat::Unknown:
return failure();
@@ -3794,7 +3800,7 @@ LogicalResult CompressStoreOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(indices()) != memType.getRank())
+ if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected valueToStore dim to match mask dim");
@@ -3807,11 +3813,11 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(compress.mask())) {
+ switch (get1DMaskFormat(compress.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- compress, compress.valueToStore(), compress.base(),
- compress.indices());
+ compress, compress.getValueToStore(), compress.getBase(),
+ compress.getIndices());
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(compress);
@@ -3894,8 +3900,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
}
LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType = source().getType().dyn_cast_or_null<VectorType>();
- auto resultVectorType = result().getType().dyn_cast_or_null<VectorType>();
+ auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
+ auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
// Check if source/result are of vector type.
if (sourceVectorType && resultVectorType)
@@ -3906,16 +3912,16 @@ LogicalResult ShapeCastOp::verify() {
OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
// Nop shape cast.
- if (source().getType() == result().getType())
- return source();
+ if (getSource().getType() == getResult().getType())
+ return getSource();
// Canceling shape casts.
- if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
- if (result().getType() == otherOp.source().getType())
- return otherOp.source();
+ if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+ if (getResult().getType() == otherOp.getSource().getType())
+ return otherOp.getSource();
// Only allows valid transitive folding.
- VectorType srcType = otherOp.source().getType().cast<VectorType>();
+ VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
VectorType resultType = getResult().getType().cast<VectorType>();
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
@@ -3927,7 +3933,7 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
return {};
}
- setOperand(otherOp.source());
+ setOperand(otherOp.getSource());
return getResult();
}
return {};
@@ -3941,7 +3947,8 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
- auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>();
+ auto constantOp =
+ shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
// Only handle splat for now.
@@ -3998,13 +4005,13 @@ LogicalResult BitCastOp::verify() {
OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
// Nop cast.
- if (source().getType() == result().getType())
- return source();
+ if (getSource().getType() == getResult().getType())
+ return getSource();
// Canceling bitcasts.
- if (auto otherOp = source().getDefiningOp<BitCastOp>())
- if (result().getType() == otherOp.source().getType())
- return otherOp.source();
+ if (auto otherOp = getSource().getDefiningOp<BitCastOp>())
+ if (getResult().getType() == otherOp.getSource().getType())
+ return otherOp.getSource();
Attribute sourceConstant = operands.front();
if (!sourceConstant)
@@ -4113,7 +4120,7 @@ OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
return {};
}
- return vector();
+ return getVector();
}
LogicalResult vector::TransposeOp::verify() {
@@ -4123,7 +4130,7 @@ LogicalResult vector::TransposeOp::verify() {
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = transp().getValue();
+ auto transpAttr = getTransp().getValue();
int64_t size = transpAttr.size();
if (rank != size)
return emitOpError("transposition length mismatch: ") << size;
@@ -4168,7 +4175,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
// Return if the input of 'transposeOp' is not defined by another transpose.
vector::TransposeOp parentTransposeOp =
- transposeOp.vector().getDefiningOp<vector::TransposeOp>();
+ transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
if (!parentTransposeOp)
return failure();
@@ -4177,7 +4184,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
// Replace 'transposeOp' with a new transpose operation.
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- parentTransposeOp.vector(),
+ parentTransposeOp.getVector(),
vector::getVectorSubscriptAttr(rewriter, permutation));
return success();
}
@@ -4191,7 +4198,7 @@ void vector::TransposeOp::getCanonicalizationPatterns(
}
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(transp(), results);
+ populateFromInt64AttrArray(getTransp(), results);
}
//===----------------------------------------------------------------------===//
@@ -4202,23 +4209,23 @@ LogicalResult ConstantMaskOp::verify() {
auto resultType = getResult().getType().cast<VectorType>();
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
- if (mask_dim_sizes().size() != 1)
+ if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
- auto dim = mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
+ auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
}
// Verify that array attr size matches the rank of the vector result.
- if (static_cast<int64_t>(mask_dim_sizes().size()) != resultType.getRank())
+ if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
return emitOpError(
"must specify array attr of size equal vector result rank");
// Verify that each array attr element is in bounds of corresponding vector
// result dimension size.
auto resultShape = resultType.getShape();
SmallVector<int64_t, 4> maskDimSizes;
- for (const auto &it : llvm::enumerate(mask_dim_sizes())) {
+ for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
if (attrValue < 0 || attrValue > resultShape[it.index()])
return emitOpError(
@@ -4238,7 +4245,7 @@ LogicalResult ConstantMaskOp::verify() {
// `vector.constant_mask`. In the future, a convention could be established
// to decide if a specific dimension value could be considered as "all set".
if (resultType.isScalable() &&
- mask_dim_sizes()[0].cast<IntegerAttr>().getInt() != 0)
+ getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}
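An illustrative example of what this verifier accepts: each mask dim size must lie between 0 and the corresponding result dimension size:

  // Illustrative example.
  %m = vector.constant_mask [3, 2] : vector<4x3xi1>
  // OK: 3 <= 4 and 2 <= 3; [5, 2] would be rejected for dimension 0.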
@@ -4329,7 +4336,7 @@ LogicalResult ScanOp::verify() {
VectorType initialType = getInitialValueType();
// Check reduction dimension < rank.
int64_t srcRank = srcType.getRank();
- int64_t reductionDim = reduction_dim();
+ int64_t reductionDim = getReductionDim();
if (reductionDim >= srcRank)
return emitOpError("reduction dimension ")
<< reductionDim << " has to be less than " << srcRank;
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index c823f34b695ac..dd834306c3e77 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -55,9 +55,9 @@ struct TransferReadOpInterface
Value buffer =
*state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
- rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
- readOp.permutation_map(), readOp.padding(), readOp.mask(),
- readOp.in_boundsAttr());
+ rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
+ readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
+ readOp.getInBoundsAttr());
return success();
}
};
@@ -107,8 +107,9 @@ struct TransferWriteOpInterface
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>(
- writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
- writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
+ writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
+ writeOp.getIndices(), writeOp.getPermutationMapAttr(),
+ writeOp.getInBoundsAttr());
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index e6afec8b2c5c6..d555c60439f71 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -63,16 +63,16 @@ struct CastAwayExtractStridedSliceLeadingOneDim
Location loc = extractOp.getLoc();
Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, extractOp.vector(), splatZero(dropCount));
+ loc, extractOp.getVector(), splatZero(dropCount));
// The offsets/sizes/strides attribute can have fewer elements than the
// input vector's rank: it is meant for the leading dimensions.
auto newOffsets = rewriter.getArrayAttr(
- extractOp.offsets().getValue().drop_front(dropCount));
+ extractOp.getOffsets().getValue().drop_front(dropCount));
auto newSizes = rewriter.getArrayAttr(
- extractOp.sizes().getValue().drop_front(dropCount));
+ extractOp.getSizes().getValue().drop_front(dropCount));
auto newStrides = rewriter.getArrayAttr(
- extractOp.strides().getValue().drop_front(dropCount));
+ extractOp.getStrides().getValue().drop_front(dropCount));
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
@@ -106,14 +106,14 @@ struct CastAwayInsertStridedSliceLeadingOneDim
Location loc = insertOp.getLoc();
Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.source(), splatZero(srcDropCount));
+ loc, insertOp.getSource(), splatZero(srcDropCount));
Value newDstVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.dest(), splatZero(dstDropCount));
+ loc, insertOp.getDest(), splatZero(dstDropCount));
auto newOffsets = rewriter.getArrayAttr(
- insertOp.offsets().getValue().take_back(newDstType.getRank()));
+ insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
auto newStrides = rewriter.getArrayAttr(
- insertOp.strides().getValue().take_back(newSrcType.getRank()));
+ insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
@@ -138,10 +138,10 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getTransferRank() == 0)
return failure();
- if (read.mask())
+ if (read.getMask())
return failure();
- auto shapedType = read.source().getType().cast<ShapedType>();
+ auto shapedType = read.getSource().getType().cast<ShapedType>();
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -151,7 +151,7 @@ struct CastAwayTransferReadLeadingOneDim
if (newType == oldType)
return failure();
- AffineMap oldMap = read.permutation_map();
+ AffineMap oldMap = read.getPermutationMap();
ArrayRef<AffineExpr> newResults =
oldMap.getResults().take_back(newType.getRank());
AffineMap newMap =
@@ -159,13 +159,13 @@ struct CastAwayTransferReadLeadingOneDim
rewriter.getContext());
ArrayAttr inBoundsAttr;
- if (read.in_bounds())
+ if (read.getInBounds())
inBoundsAttr = rewriter.getArrayAttr(
- read.in_boundsAttr().getValue().take_back(newType.getRank()));
+ read.getInBoundsAttr().getValue().take_back(newType.getRank()));
auto newRead = rewriter.create<vector::TransferReadOp>(
- read.getLoc(), newType, read.source(), read.indices(),
- AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
+ read.getLoc(), newType, read.getSource(), read.getIndices(),
+ AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
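A hypothetical before/after for this pattern: the unit leading dimension is cast away from the read and restored with a broadcast:

  // Illustrative rewrite; names hypothetical.
  %r = vector.transfer_read %m[%c0, %c0], %pad
         : memref<1x8xf32>, vector<1x8xf32>
  // becomes:
  %0 = vector.transfer_read %m[%c0, %c0], %pad
         : memref<1x8xf32>, vector<8xf32>
  %r = vector.broadcast %0 : vector<8xf32> to vector<1x8xf32>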
@@ -186,10 +186,10 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getTransferRank() == 0)
return failure();
- if (write.mask())
+ if (write.getMask())
return failure();
- auto shapedType = write.source().getType().dyn_cast<ShapedType>();
+ auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -199,7 +199,7 @@ struct CastAwayTransferWriteLeadingOneDim
return failure();
int64_t dropDim = oldType.getRank() - newType.getRank();
- AffineMap oldMap = write.permutation_map();
+ AffineMap oldMap = write.getPermutationMap();
ArrayRef<AffineExpr> newResults =
oldMap.getResults().take_back(newType.getRank());
AffineMap newMap =
@@ -207,14 +207,14 @@ struct CastAwayTransferWriteLeadingOneDim
rewriter.getContext());
ArrayAttr inBoundsAttr;
- if (write.in_bounds())
+ if (write.getInBounds())
inBoundsAttr = rewriter.getArrayAttr(
- write.in_boundsAttr().getValue().take_back(newType.getRank()));
+ write.getInBoundsAttr().getValue().take_back(newType.getRank()));
auto newVector = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.vector(), splatZero(dropDim));
+ write.getLoc(), write.getVector(), splatZero(dropDim));
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- write, newVector, write.source(), write.indices(),
+ write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), inBoundsAttr);
return success();
@@ -237,7 +237,7 @@ struct CastAwayContractionLeadingOneDim
if (oldAccType.getRank() < 2)
return failure();
// TODO: implement masks.
- if (llvm::size(contractOp.masks()) != 0)
+ if (llvm::size(contractOp.getMasks()) != 0)
return failure();
if (oldAccType.getShape()[0] != 1)
return failure();
@@ -248,7 +248,7 @@ struct CastAwayContractionLeadingOneDim
auto oldIndexingMaps = contractOp.getIndexingMaps();
SmallVector<AffineMap> newIndexingMaps;
- auto oldIteratorTypes = contractOp.iterator_types();
+ auto oldIteratorTypes = contractOp.getIteratorTypes();
SmallVector<Attribute> newIteratorTypes;
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
@@ -264,8 +264,8 @@ struct CastAwayContractionLeadingOneDim
newIteratorTypes.push_back(it.value());
}
- SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(),
- contractOp.acc()};
+ SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
+ contractOp.getAcc()};
SmallVector<Value> newOperands;
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
@@ -336,7 +336,7 @@ struct CastAwayContractionLeadingOneDim
auto newContractOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
- rewriter.getArrayAttr(newIteratorTypes), contractOp.kind());
+ rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
contractOp, contractOp->getResultTypes()[0], newContractOp);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 4308fa6a43be9..2a384c3bf7853 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -62,7 +62,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
- if (op.offsets().getValue().empty())
+ if (op.getOffsets().getValue().empty())
return failure();
auto loc = op.getLoc();
@@ -74,21 +74,21 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
// on it.
- Value extracted =
- rewriter.create<ExtractOp>(loc, op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
+ Value extracted = rewriter.create<ExtractOp>(
+ loc, op.getDest(),
+ getI64SubArray(op.getOffsets(), /*dropFront=*/0,
+ /*dropBack=*/rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.source(), extracted,
- getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
- getI64SubArray(op.strides(), /*dropFront=*/0));
+ loc, op.getSource(), extracted,
+ getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
+ getI64SubArray(op.getStrides(), /*dropFront=*/0));
rewriter.replaceOpWithNewOp<InsertOp>(
- op, stridedSliceInnerOp.getResult(), op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
+ op, stridedSliceInnerOp.getResult(), op.getDest(),
+ getI64SubArray(op.getOffsets(), /*dropFront=*/0,
/*dropBack=*/rankRest));
return success();
}
@@ -118,7 +118,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
- if (op.offsets().getValue().empty())
+ if (op.getOffsets().getValue().empty())
return failure();
int64_t srcRank = srcType.getRank();
@@ -128,18 +128,18 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
return failure();
if (srcType == dstType) {
- rewriter.replaceOp(op, op.source());
+ rewriter.replaceOp(op, op.getSource());
return success();
}
int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+ op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size = srcType.getShape().front();
int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
+ op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
- Value res = op.dest();
+ Value res = op.getDest();
if (srcRank == 1) {
int nSrc = srcType.getShape().front();
@@ -148,8 +148,8 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
SmallVector<int64_t> offsets(nDest, 0);
for (int64_t i = 0; i < nSrc; ++i)
offsets[i] = i;
- Value scaledSource =
- rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
+ Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
+ op.getSource(), offsets);
// 2. Create a mask where we take the value from scaledSource or dest,
// depending on the offset.
@@ -162,7 +162,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
}
// 3. Replace with a ShuffleOp.
- rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
+ rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
offsets);
return success();
@@ -172,17 +172,17 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
- Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
+ Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
if (extractedSource.getType().isa<VectorType>()) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
- Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
+ Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
extractedSource = rewriter.create<InsertStridedSliceOp>(
loc, extractedSource, extractedDest,
- getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
+ getI64SubArray(op.getOffsets(), /* dropFront=*/1),
+ getI64SubArray(op.getStrides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -212,27 +212,28 @@ class VectorExtractStridedSliceOpRewritePattern
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
+ assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+ op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size =
+ op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
+ op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
assert(elemType.isSignlessIntOrIndexOrFloat());
// Single offset can be more efficiently shuffled.
- if (op.offsets().getValue().size() == 1) {
+ if (op.getOffsets().getValue().size() == 1) {
SmallVector<int64_t, 4> offsets;
offsets.reserve(size);
for (int64_t off = offset, e = offset + size * stride; off < e;
off += stride)
offsets.push_back(off);
- rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
- op.vector(),
+ rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
+ op.getVector(),
rewriter.getI64ArrayAttr(offsets));
return success();
}
@@ -243,11 +244,11 @@ class VectorExtractStridedSliceOpRewritePattern
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
- Value one = extractOne(rewriter, loc, op.vector(), off);
+ Value one = extractOne(rewriter, loc, op.getVector(), off);
Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.sizes(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
+ loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
+ getI64SubArray(op.getSizes(), /* dropFront=*/1),
+ getI64SubArray(op.getStrides(), /* dropFront=*/1));
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
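For the single-offset case handled above, a 1-D strided slice lowers to a shuffle of the input with itself (a hedged sketch; value names hypothetical):

  // Illustrative lowering; value names hypothetical.
  %e = vector.extract_strided_slice %v
         {offsets = [2], sizes = [4], strides = [1]}
         : vector<8xf32> to vector<4xf32>
  // lowered to:
  %e = vector.shuffle %v, %v [2, 3, 4, 5] : vector<8xf32>, vector<8xf32>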
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index db5c667a49355..07e24deb810fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -38,13 +38,13 @@ class InnerOuterDimReductionConversion
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
- auto src = multiReductionOp.source();
+ auto src = multiReductionOp.getSource();
auto loc = multiReductionOp.getLoc();
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Separate reduction and parallel dims
auto reductionDimsRange =
- multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
+ multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
@@ -86,8 +86,8 @@ class InnerOuterDimReductionConversion
reductionMask[i] = true;
}
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
- multiReductionOp, transposeOp.result(), reductionMask,
- multiReductionOp.kind());
+ multiReductionOp, transposeOp.getResult(), reductionMask,
+ multiReductionOp.getKind());
return success();
}
@@ -186,17 +186,17 @@ class ReduceMultiDimReductionRank
auto castedType = VectorType::get(
vectorShape, multiReductionOp.getSourceVectorType().getElementType());
Value cast = rewriter.create<vector::ShapeCastOp>(
- loc, castedType, multiReductionOp.source());
+ loc, castedType, multiReductionOp.getSource());
// 5. Creates the flattened form of vector.multi_reduction with inner/outer
// most dim as reduction.
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.kind());
+ loc, cast, mask, multiReductionOp.getKind());
// 6. If there are no parallel shapes, the result is a scalar.
// TODO: support 0-d vectors when available.
if (parallelShapes.empty()) {
- rewriter.replaceOp(multiReductionOp, newOp.dest());
+ rewriter.replaceOp(multiReductionOp, newOp.getDest());
return success();
}
@@ -205,7 +205,7 @@ class ReduceMultiDimReductionRank
parallelShapes,
multiReductionOp.getSourceVectorType().getElementType());
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- multiReductionOp, outputCastedType, newOp.dest());
+ multiReductionOp, outputCastedType, newOp.getDest());
return success();
}
@@ -238,12 +238,12 @@ struct TwoDimMultiReductionToElementWise
return failure();
Value result =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
+ rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
.getResult();
for (int64_t i = 1; i < srcShape[0]; i++) {
- auto operand =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
- result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
+ auto operand = rewriter.create<vector::ExtractOp>(
+ loc, multiReductionOp.getSource(), i);
+ result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
operand, result);
}
@@ -275,9 +275,9 @@ struct TwoDimMultiReductionToReduction
for (int i = 0; i < outerDim; ++i) {
auto v = rewriter.create<vector::ExtractOp>(
- loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
- auto reducedValue =
- rewriter.create<vector::ReductionOp>(loc, multiReductionOp.kind(), v);
+ loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+ auto reducedValue = rewriter.create<vector::ReductionOp>(
+ loc, multiReductionOp.getKind(), v);
result = rewriter.create<vector::InsertElementOp>(
loc, reducedValue, result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
@@ -317,9 +317,9 @@ struct OneDimMultiReductionToTwoDim
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
Value cast = rewriter.create<vector::ShapeCastOp>(
- loc, castedType, multiReductionOp.source());
+ loc, castedType, multiReductionOp.getSource());
Value reduced = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.kind());
+ loc, cast, mask, multiReductionOp.getKind());
rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
ArrayRef<int64_t>{0});
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 931574641b752..364f09c1f5646 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -96,7 +96,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> reads;
Operation *firstOverwriteCandidate = nullptr;
- for (auto *user : write.source().getUsers()) {
+ for (auto *user : write.getSource().getUsers()) {
if (user == write.getOperation())
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
- for (Operation *user : read.source().getUsers()) {
+ for (Operation *user : read.getSource().getUsers()) {
if (isa<vector::TransferReadOp>(user))
continue;
if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -207,7 +207,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
<< " to: " << *read.getOperation() << "\n");
- read.replaceAllUsesWith(lastwrite.vector());
+ read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
@@ -259,9 +259,9 @@ class TransferReadDropUnitDimsPattern
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
- Value vector = transferReadOp.vector();
+ Value vector = transferReadOp.getVector();
VectorType vectorType = vector.getType().cast<VectorType>();
- Value source = transferReadOp.source();
+ Value source = transferReadOp.getSource();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// TODO: support tensor types.
if (!sourceType || !sourceType.hasStaticShape())
@@ -271,7 +271,7 @@ class TransferReadDropUnitDimsPattern
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
- if (!transferReadOp.permutation_map().isMinorIdentity())
+ if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
@@ -279,7 +279,7 @@ class TransferReadDropUnitDimsPattern
if (reducedRank != vectorType.getRank())
return failure(); // This pattern requires the vector shape to match the
// reduced source shape.
- if (llvm::any_of(transferReadOp.indices(),
+ if (llvm::any_of(transferReadOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
Value reducedShapeSource =
@@ -302,9 +302,9 @@ class TransferWriteDropUnitDimsPattern
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
- Value vector = transferWriteOp.vector();
+ Value vector = transferWriteOp.getVector();
VectorType vectorType = vector.getType().cast<VectorType>();
- Value source = transferWriteOp.source();
+ Value source = transferWriteOp.getSource();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// TODO: support tensor type.
if (!sourceType || !sourceType.hasStaticShape())
@@ -314,7 +314,7 @@ class TransferWriteDropUnitDimsPattern
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
- if (!transferWriteOp.permutation_map().isMinorIdentity())
+ if (!transferWriteOp.getPermutationMap().isMinorIdentity())
return failure();
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
@@ -322,7 +322,7 @@ class TransferWriteDropUnitDimsPattern
if (reducedRank != vectorType.getRank())
return failure(); // This pattern requires the vector shape to match the
// reduced source shape.
- if (llvm::any_of(transferWriteOp.indices(),
+ if (llvm::any_of(transferWriteOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
Value reducedShapeSource =
@@ -366,9 +366,9 @@ class FlattenContiguousRowMajorTransferReadPattern
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
- Value vector = transferReadOp.vector();
+ Value vector = transferReadOp.getVector();
VectorType vectorType = vector.getType().cast<VectorType>();
- Value source = transferReadOp.source();
+ Value source = transferReadOp.getSource();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// Contiguity check is valid on tensors only.
if (!sourceType)
@@ -386,11 +386,11 @@ class FlattenContiguousRowMajorTransferReadPattern
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
- if (!transferReadOp.permutation_map().isMinorIdentity())
+ if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
- if (transferReadOp.mask())
+ if (transferReadOp.getMask())
return failure();
- if (llvm::any_of(transferReadOp.indices(),
+ if (llvm::any_of(transferReadOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -418,9 +418,9 @@ class FlattenContiguousRowMajorTransferWritePattern
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
- Value vector = transferWriteOp.vector();
+ Value vector = transferWriteOp.getVector();
VectorType vectorType = vector.getType().cast<VectorType>();
- Value source = transferWriteOp.source();
+ Value source = transferWriteOp.getSource();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// Contiguity check is valid on tensors only.
if (!sourceType)
@@ -438,11 +438,11 @@ class FlattenContiguousRowMajorTransferWritePattern
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
- if (!transferWriteOp.permutation_map().isMinorIdentity())
+ if (!transferWriteOp.getPermutationMap().isMinorIdentity())
return failure();
- if (transferWriteOp.mask())
+ if (transferWriteOp.getMask())
return failure();
- if (llvm::any_of(transferWriteOp.indices(),
+ if (llvm::any_of(transferWriteOp.getIndices(),
[](Value v) { return !isZero(v); }))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
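
(For readers skimming the diff: the renames above are mechanical, raw accessor to prefixed accessor. A minimal standalone sketch of what the transitional _Both accessor mode implies at call sites; all names below are illustrative stand-ins, not the actual ODS-generated classes.)

// Illustrative only: a hypothetical stand-in for a TableGen-generated op
// class while the accessor flag is at _Both. Both spellings coexist so
// callers can migrate incrementally; the raw form goes away at _Prefixed.
#include <cassert>

struct FakeTransferReadOp {
  int src = 42;
  int getSource() const { return src; }       // prefixed accessor (kept after the flip)
  int source() const { return getSource(); }  // raw accessor (removed at _Prefixed)
};

int main() {
  FakeTransferReadOp op;
  assert(op.source() == op.getSource()); // both spellings resolve identically
  return 0;
}
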
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index baf6973be12e2..948814556b157 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -62,7 +62,7 @@ struct TransferReadPermutationLowering
return failure();
SmallVector<unsigned> permutation;
- AffineMap map = op.permutation_map();
+ AffineMap map = op.getPermutationMap();
if (map.getNumResults() == 0)
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
@@ -85,7 +85,7 @@ struct TransferReadPermutationLowering
// Transpose mask operand.
Value newMask;
- if (op.mask()) {
+ if (op.getMask()) {
 // Remove unused dims from the permutation map. E.g.:
 // (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
 // comp = (d0, d1, d2) -> (d2, 0, d1, 0, d0)
@@ -99,22 +99,23 @@ struct TransferReadPermutationLowering
maskTransposeIndices.push_back(expr.getPosition());
}
- newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
+ newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
maskTransposeIndices);
}
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
- op.in_bounds() ? transposeInBoundsAttr(
- rewriter, op.in_bounds().getValue(), permutation)
- : ArrayAttr();
+ op.getInBounds()
+ ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(),
+ permutation)
+ : ArrayAttr();
// Generate new transfer_read operation.
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(),
- AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr);
+ op.getLoc(), newReadType, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -151,7 +152,7 @@ struct TransferWritePermutationLowering
return failure();
SmallVector<unsigned> permutation;
- AffineMap map = op.permutation_map();
+ AffineMap map = op.getPermutationMap();
if (map.isMinorIdentity())
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
@@ -169,23 +170,24 @@ struct TransferWritePermutationLowering
});
// Transpose mask operand.
- Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
- op.getLoc(), op.mask(), indices)
- : Value();
+ Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
+ op.getLoc(), op.getMask(), indices)
+ : Value();
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
- op.in_bounds() ? transposeInBoundsAttr(
- rewriter, op.in_bounds().getValue(), permutation)
- : ArrayAttr();
+ op.getInBounds()
+ ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(),
+ permutation)
+ : ArrayAttr();
// Generate new transfer_write operation.
- Value newVec =
- rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
+ Value newVec = rewriter.create<vector::TransposeOp>(
+ op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
+ op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
newMask, newInBoundsAttr);
return success();
@@ -209,7 +211,7 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
if (op.getTransferRank() == 0)
return failure();
- AffineMap map = op.permutation_map();
+ AffineMap map = op.getPermutationMap();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
@@ -237,12 +239,12 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
if (reducedShapeRank == 0) {
Value newRead;
if (op.getShapedType().isa<TensorType>()) {
- newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(),
- op.indices());
+ newRead = rewriter.create<tensor::ExtractOp>(
+ op.getLoc(), op.getSource(), op.getIndices());
} else {
newRead = rewriter.create<memref::LoadOp>(
- op.getLoc(), originalVecType.getElementType(), op.source(),
- op.indices());
+ op.getLoc(), originalVecType.getElementType(), op.getSource(),
+ op.getIndices());
}
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
@@ -256,13 +258,14 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
VectorType newReadType =
VectorType::get(newShape, originalVecType.getElementType());
ArrayAttr newInBoundsAttr =
- op.in_bounds()
+ op.getInBounds()
? rewriter.getArrayAttr(
- op.in_boundsAttr().getValue().take_back(reducedShapeRank))
+ op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
: ArrayAttr();
Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(),
- AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr);
+ op.getLoc(), newReadType, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+ newInBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index c457621ec61bc..5e090a6ccc718 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -249,7 +249,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
MemRefType compatibleMemRefType, Value alloc) {
Location loc = xferOp.getLoc();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
+ Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
loc, returnTypes, inBoundsCond,
[&](OpBuilder &b, Location loc) {
@@ -257,12 +257,12 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
if (compatibleMemRefType != xferOp.getShapedType())
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
- xferOp.indices().end());
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
b.create<scf::YieldOp>(loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
- b.create<linalg::FillOp>(loc, ValueRange{xferOp.padding()},
+ b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
ValueRange{alloc});
// Take partial subview of memref which guarantees no dimension
// overflows.
@@ -304,7 +304,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
Location loc = xferOp.getLoc();
scf::IfOp fullPartialIfOp;
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
+ Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
loc, returnTypes, inBoundsCond,
[&](OpBuilder &b, Location loc) {
@@ -312,8 +312,8 @@ static scf::IfOp createFullPartialVectorTransferRead(
if (compatibleMemRefType != xferOp.getShapedType())
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
- xferOp.indices().end());
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
b.create<scf::YieldOp>(loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
@@ -354,7 +354,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
MemRefType compatibleMemRefType, Value alloc) {
Location loc = xferOp.getLoc();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
+ Value memref = xferOp.getSource();
return b
.create<scf::IfOp>(
loc, returnTypes, inBoundsCond,
@@ -364,8 +364,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(),
- xferOp.indices().begin(),
- xferOp.indices().end());
+ xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
b.create<scf::YieldOp>(loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
@@ -430,9 +430,10 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
BlockAndValueMapping mapping;
Value load = b.create<memref::LoadOp>(
- loc, b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
- mapping.map(xferOp.vector(), load);
+ loc,
+ b.create<vector::TypeCastOp>(
+ loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
+ mapping.map(xferOp.getVector(), load);
b.clone(*xferOp.getOperation(), mapping);
b.create<scf::YieldOp>(loc, ValueRange{});
});
@@ -530,9 +531,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
if (!(xferReadOp || xferWriteOp))
return failure();
- if (xferWriteOp && xferWriteOp.mask())
+ if (xferWriteOp && xferWriteOp.getMask())
return failure();
- if (xferReadOp && xferReadOp.mask())
+ if (xferReadOp && xferReadOp.getMask())
return failure();
}
@@ -601,8 +602,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
// The operation is cloned to prevent deleting information needed for the
// later IR creation.
BlockAndValueMapping mapping;
- mapping.map(xferWriteOp.source(), memrefAndIndices.front());
- mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
+ mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
+ mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
auto *clone = b.clone(*xferWriteOp, mapping);
clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 32e2fa78174f3..2ca6481b18998 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -168,19 +168,19 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
- shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
+ shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
auto resultVectorType =
- shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
+ shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
if (!sourceVectorType || !resultVectorType)
return failure();
// Check if shape cast op source operand is also a shape cast op.
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
- shapeCastOp.source().getDefiningOp());
+ shapeCastOp.getSource().getDefiningOp());
if (!sourceShapeCastOp)
return failure();
auto operandSourceVectorType =
- sourceShapeCastOp.source().getType().cast<VectorType>();
+ sourceShapeCastOp.getSource().getType().cast<VectorType>();
auto operandResultVectorType = sourceShapeCastOp.getType();
// Check if shape cast operations invert each other.
@@ -188,7 +188,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
operandResultVectorType != sourceVectorType)
return failure();
- rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
+ rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
return success();
}
};
@@ -207,7 +207,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Scalar to any vector can use splat.
if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
return success();
}
@@ -219,9 +219,9 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
+ ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
else
- ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+ ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
@@ -240,7 +240,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value bcst =
- rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
+ rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
@@ -260,7 +260,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// All trailing dimensions are the same. Simply pass through.
if (m == -1) {
- rewriter.replaceOp(op, op.source());
+ rewriter.replaceOp(op, op.getSource());
return success();
}
@@ -285,14 +285,14 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
loc, dstType, rewriter.getZeroAttr(dstType));
if (m == 0) {
 // Stretch at start.
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
} else {
 // Stretch not at start.
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
}
@@ -338,13 +338,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- Value input = op.vector();
+ Value input = op.getVector();
VectorType inputType = op.getVectorType();
VectorType resType = op.getResultType();
// Set up convenience transposition table.
SmallVector<int64_t, 4> transp;
- for (auto attr : op.transp())
+ for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (vectorTransformOptions.vectorTransposeLowering ==
@@ -433,7 +433,7 @@ class TransposeOp2DToShuffleLowering
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
SmallVector<int64_t, 4> transp;
- for (auto attr : op.transp())
+ for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 && transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
@@ -444,7 +444,8 @@ class TransposeOp2DToShuffleLowering
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
Value casted = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
+ loc, VectorType::get({m * n}, srcType.getElementType()),
+ op.getVector());
SmallVector<int64_t> mask;
mask.reserve(m * n);
for (int64_t j = 0; j < n; ++j)
@@ -490,15 +491,15 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
VectorType resType = op.getVectorType();
Type eltType = resType.getElementType();
bool isInt = eltType.isa<IntegerType, IndexType>();
- Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
- vector::CombiningKind kind = op.kind();
+ Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
+ vector::CombiningKind kind = op.getKind();
if (!rhsType) {
// Special case: AXPY operation.
- Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
+ Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
Optional<Value> mult =
- isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
- : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
+ isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
+ : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
if (!mult.hasValue())
return failure();
rewriter.replaceOp(op, mult.getValue());
@@ -509,13 +510,15 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
- Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
+ Value x =
+ rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
- Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
- : genMultF(loc, a, op.rhs(), r, kind, rewriter);
+ Optional<Value> m =
+ isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
+ : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
if (!m.hasValue())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
@@ -588,7 +591,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
auto loc = op.getLoc();
auto dstType = op.getType();
auto eltType = dstType.getElementType();
- auto dimSizes = op.mask_dim_sizes();
+ auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();
if (rank == 0) {
@@ -715,7 +718,7 @@ class ShapeCastOp2DDownCastRewritePattern
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
- Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
+ Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
desc = rewriter.create<vector::InsertStridedSliceOp>(
loc, vec, desc,
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
@@ -749,7 +752,7 @@ class ShapeCastOp2DUpCastRewritePattern
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
+ loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
/*sizes=*/mostMinorVectorSize,
/*strides=*/1);
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
@@ -804,7 +807,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
incIdx(srcIdx, sourceVectorType, srcRank - 1);
incIdx(resIdx, resultVectorType, resRank - 1);
}
- Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
+ Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
}
rewriter.replaceOp(op, result);
@@ -844,9 +847,9 @@ struct MultiReduceToContract
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
- if (reduceOp.kind() != vector::CombiningKind::ADD)
+ if (reduceOp.getKind() != vector::CombiningKind::ADD)
return failure();
- Operation *mulOp = reduceOp.source().getDefiningOp();
+ Operation *mulOp = reduceOp.getSource().getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return failure();
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
@@ -905,8 +908,8 @@ struct CombineContractTranspose
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> maps =
llvm::to_vector<4>(contractOp.getIndexingMaps());
- Value lhs = contractOp.lhs();
- Value rhs = contractOp.rhs();
+ Value lhs = contractOp.getLhs();
+ Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
@@ -917,17 +920,17 @@ struct CombineContractTranspose
SmallVector<int64_t> perm;
transposeOp.getTransp(perm);
AffineMap permutationMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(transposeOp.transp()),
+ extractVector<unsigned>(transposeOp.getTransp()),
contractOp.getContext());
map = inversePermutation(permutationMap).compose(map);
- *operand = transposeOp.vector();
+ *operand = transposeOp.getVector();
changed = true;
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
- contractOp, lhs, rhs, contractOp.acc(),
- rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+ contractOp, lhs, rhs, contractOp.getAcc(),
+ rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
return success();
}
};
@@ -962,8 +965,8 @@ struct CombineContractBroadcast
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> maps =
llvm::to_vector<4>(contractOp.getIndexingMaps());
- Value lhs = contractOp.lhs();
- Value rhs = contractOp.rhs();
+ Value lhs = contractOp.getLhs();
+ Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
@@ -996,14 +999,14 @@ struct CombineContractBroadcast
AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
contractOp.getContext());
map = broadcastMap.compose(map);
- *operand = broadcast.source();
+ *operand = broadcast.getSource();
changed = true;
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
- contractOp, lhs, rhs, contractOp.acc(),
- rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+ contractOp, lhs, rhs, contractOp.getAcc(),
+ rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
return success();
}
};
@@ -1036,8 +1039,9 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
castResTy = VectorType::get(vecTy.getShape(), castResTy);
- auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- bcastOp.source(), castResTy, op->getAttrs());
+ auto castOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+ bcastOp.getSource(), castResTy, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, op->getResult(0).getType(), castOp->getResult(0));
return success();
@@ -1075,8 +1079,9 @@ struct ReorderCastOpsOnTranspose
auto castResTy = transpOp.getVectorType();
castResTy = VectorType::get(castResTy.getShape(),
getElementTypeOrSelf(op->getResult(0)));
- auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- transpOp.vector(), castResTy, op->getAttrs());
+ auto castOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+ transpOp.getVector(), castResTy, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
op, op->getResult(0).getType(), castOp->getResult(0),
transpOp.getTransp());
@@ -1127,7 +1132,7 @@ LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rew) const {
// TODO: implement masks
- if (llvm::size(op.masks()) != 0)
+ if (llvm::size(op.getMasks()) != 0)
return failure();
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
@@ -1135,7 +1140,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
if (failed(filter(op)))
return failure();
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypes().getValue();
if (!isParallelIterator(iteratorTypes[0]) ||
!isParallelIterator(iteratorTypes[1]) ||
!isReductionIterator(iteratorTypes[2]))
@@ -1152,16 +1157,16 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
AffineExpr m, n, k;
bindDims(rew.getContext(), m, n, k);
// LHS must be A(m, k) or A(k, m).
- Value lhs = op.lhs();
- auto lhsMap = op.indexing_maps()[0];
+ Value lhs = op.getLhs();
+ auto lhsMap = op.getIndexingMaps()[0];
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
return failure();
// RHS must be B(k, n) or B(n, k).
- Value rhs = op.rhs();
- auto rhsMap = op.indexing_maps()[1];
+ Value rhs = op.getRhs();
+ auto rhsMap = op.getIndexingMaps()[1];
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
@@ -1187,11 +1192,11 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
mul = rew.create<vector::ShapeCastOp>(
loc,
VectorType::get({lhsRows, rhsColumns},
- getElementTypeOrSelf(op.acc().getType())),
+ getElementTypeOrSelf(op.getAcc().getType())),
mul);
// ACC must be C(m, n) or C(n, m).
- auto accMap = op.indexing_maps()[2];
+ auto accMap = op.getIndexingMaps()[2];
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
@@ -1199,8 +1204,9 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
Value res =
elementType.isa<IntegerType>()
- ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
- : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
+ ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+ : static_cast<Value>(
+ rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
rew.replaceOp(op, res);
return success();
@@ -1226,11 +1232,10 @@ struct Red : public IteratorType {
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
: public StructuredGenerator<vector::ContractionOp> {
-
UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
: StructuredGenerator<vector::ContractionOp>(builder, op),
- kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
- lhsType(op.getLhsType()) {}
+ kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
+ res(op.getAcc()), lhsType(op.getLhsType()) {}
Value t(Value v) {
static constexpr std::array<int64_t, 2> perm = {1, 0};
@@ -1356,7 +1361,7 @@ struct UnrolledOuterProductGenerator
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
vector::ContractionOp op, PatternRewriter &rewriter) const {
// TODO: implement masks
- if (llvm::size(op.masks()) != 0)
+ if (llvm::size(op.getMasks()) != 0)
return failure();
if (vectorTransformOptions.vectorContractLowering !=
@@ -1390,7 +1395,7 @@ LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
// TODO: implement masks
- if (llvm::size(op.masks()) != 0)
+ if (llvm::size(op.getMasks()) != 0)
return failure();
if (failed(filter(op)))
@@ -1400,10 +1405,10 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
vector::VectorContractLowering::Dot)
return failure();
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypes().getValue();
static constexpr std::array<int64_t, 2> perm = {1, 0};
Location loc = op.getLoc();
- Value lhs = op.lhs(), rhs = op.rhs();
+ Value lhs = op.getLhs(), rhs = op.getRhs();
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
@@ -1495,7 +1500,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
}
}
- if (auto acc = op.acc())
+ if (auto acc = op.getAcc())
res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
rewriter.replaceOp(op, res);
return success();
@@ -1522,7 +1527,7 @@ LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
// TODO: implement masks.
- if (llvm::size(op.masks()) != 0)
+ if (llvm::size(op.getMasks()) != 0)
return failure();
if (failed(filter(op)))
@@ -1627,15 +1632,15 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
- rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+ rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0; d < dimSize; ++d) {
- auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
- auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
- auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
+ auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
+ auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
Value lowContract = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, acc, lowAffine, lowIter);
result =
@@ -1667,10 +1672,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
- Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
+ Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
- if (auto acc = op.acc())
+ if (auto acc = op.getAcc())
res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
return res;
}
@@ -1681,15 +1686,15 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
- rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+ rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops.
// By feeding the initial accumulator into the first contraction,
// and the result of each contraction into the next, eventually
// the sum of all reductions is computed.
- Value result = op.acc();
+ Value result = op.getAcc();
for (int64_t d = 0; d < dimSize; ++d) {
- auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
- auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+ auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
lowAffine, lowIter);
}
@@ -1753,7 +1758,7 @@ struct TransferReadToVectorLoadLowering
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
// We let the 0-d corner case pass-through as it is supported.
- if (!read.permutation_map().isMinorIdentityWithBroadcasting(
+ if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
&broadcastedDims))
return failure();
@@ -1792,16 +1797,16 @@ struct TransferReadToVectorLoadLowering
// Create vector load op.
Operation *loadOp;
- if (read.mask()) {
+ if (read.getMask()) {
Value fill = rewriter.create<vector::SplatOp>(
- read.getLoc(), unbroadcastedVectorType, read.padding());
+ read.getLoc(), unbroadcastedVectorType, read.getPadding());
loadOp = rewriter.create<vector::MaskedLoadOp>(
- read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
- read.mask(), fill);
+ read.getLoc(), unbroadcastedVectorType, read.getSource(),
+ read.getIndices(), read.getMask(), fill);
} else {
- loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
- unbroadcastedVectorType,
- read.source(), read.indices());
+ loadOp = rewriter.create<vector::LoadOp>(
+ read.getLoc(), unbroadcastedVectorType, read.getSource(),
+ read.getIndices());
}
// Insert a broadcasting op if required.
@@ -1836,7 +1841,7 @@ struct VectorLoadToMemrefLoadLowering
if (vecType.getNumElements() != 1)
return failure();
auto memrefLoad = rewriter.create<memref::LoadOp>(
- loadOp.getLoc(), loadOp.base(), loadOp.indices());
+ loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
memrefLoad);
return success();
@@ -1857,15 +1862,15 @@ struct VectorStoreToMemrefStoreLowering
if (vecType.getRank() == 0) {
 // TODO: Unify once ExtractOp supports 0-d vectors.
extracted = rewriter.create<vector::ExtractElementOp>(
- storeOp.getLoc(), storeOp.valueToStore());
+ storeOp.getLoc(), storeOp.getValueToStore());
} else {
SmallVector<int64_t> indices(vecType.getRank(), 0);
extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.valueToStore(), indices);
+ storeOp.getLoc(), storeOp.getValueToStore(), indices);
}
rewriter.replaceOpWithNewOp<memref::StoreOp>(
- storeOp, extracted, storeOp.base(), storeOp.indices());
+ storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
return success();
}
};
@@ -1893,7 +1898,7 @@ struct TransferWriteToVectorStoreLowering
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
if ( // pass-through for the 0-d corner case.
- !write.permutation_map().isMinorIdentity())
+ !write.getPermutationMap().isMinorIdentity())
return failure();
auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
@@ -1918,12 +1923,13 @@ struct TransferWriteToVectorStoreLowering
// Out-of-bounds dims are handled by MaterializeTransferMask.
if (write.hasOutOfBoundsDim())
return failure();
- if (write.mask()) {
+ if (write.getMask()) {
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- write, write.source(), write.indices(), write.mask(), write.vector());
+ write, write.getSource(), write.getIndices(), write.getMask(),
+ write.getVector());
} else {
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- write, write.vector(), write.source(), write.indices());
+ write, write.getVector(), write.getSource(), write.getIndices());
}
return success();
}
@@ -1957,7 +1963,7 @@ struct BubbleDownVectorBitCastForExtract
if (extractOp.getVectorType().getRank() != 1)
return failure();
- auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+ auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
if (!castOp)
return failure();
@@ -1983,14 +1989,14 @@ struct BubbleDownVectorBitCastForExtract
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
};
- uint64_t index = getFirstIntValue(extractOp.position());
+ uint64_t index = getFirstIntValue(extractOp.getPosition());
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
VectorType oneScalarType =
VectorType::get({1}, castSrcType.getElementType());
Value packedValue = rewriter.create<vector::ExtractOp>(
- extractOp.getLoc(), oneScalarType, castOp.source(),
+ extractOp.getLoc(), oneScalarType, castOp.getSource(),
rewriter.getI64ArrayAttr(index / expandRatio));
// Cast it to a vector with the desired scalar's type.
@@ -2027,7 +2033,7 @@ struct BubbleDownBitCastForStridedSliceExtract
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
- auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+ auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
if (!castOp)
return failure();
@@ -2042,7 +2048,7 @@ struct BubbleDownBitCastForStridedSliceExtract
return failure();
// Only accept all one strides for now.
- if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
+ if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
[](const APInt &val) { return !val.isOneValue(); }))
return failure();
@@ -2054,7 +2060,7 @@ struct BubbleDownBitCastForStridedSliceExtract
// are selecting the full range for the last bitcasted dimension; other
// dimensions aren't affected. Otherwise, we need to scale down the last
 // dimension's offset given we are extracting from fewer elements now.
- ArrayAttr newOffsets = extractOp.offsets();
+ ArrayAttr newOffsets = extractOp.getOffsets();
if (newOffsets.size() == rank) {
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
if (offsets.back() % expandRatio != 0)
@@ -2064,7 +2070,7 @@ struct BubbleDownBitCastForStridedSliceExtract
}
// Similarly for sizes.
- ArrayAttr newSizes = extractOp.sizes();
+ ArrayAttr newSizes = extractOp.getSizes();
if (newSizes.size() == rank) {
SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
if (sizes.back() % expandRatio != 0)
@@ -2080,8 +2086,8 @@ struct BubbleDownBitCastForStridedSliceExtract
VectorType::get(dims, castSrcType.getElementType());
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
- newSizes, extractOp.strides());
+ extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
+ newSizes, extractOp.getStrides());
rewriter.replaceOpWithNewOp<vector::BitCastOp>(
extractOp, extractOp.getType(), newExtractOp);
@@ -2120,12 +2126,12 @@ struct BubbleUpBitCastForStridedSliceInsert
int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
auto insertOp =
- bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
+ bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
if (!insertOp)
return failure();
// Only accept all one strides for now.
- if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
+ if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
[](const APInt &val) { return !val.isOneValue(); }))
return failure();
@@ -2135,7 +2141,7 @@ struct BubbleUpBitCastForStridedSliceInsert
if (rank != insertOp.getDestVectorType().getRank())
return failure();
- ArrayAttr newOffsets = insertOp.offsets();
+ ArrayAttr newOffsets = insertOp.getOffsets();
assert(newOffsets.size() == rank);
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
if (offsets.back() % shrinkRatio != 0)
@@ -2150,7 +2156,7 @@ struct BubbleUpBitCastForStridedSliceInsert
VectorType::get(srcDims, castDstType.getElementType());
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastSrcType, insertOp.source());
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
SmallVector<int64_t, 4> dstDims =
llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
@@ -2159,11 +2165,11 @@ struct BubbleUpBitCastForStridedSliceInsert
VectorType::get(dstDims, castDstType.getElementType());
auto newCastDstOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastDstType, insertOp.dest());
+ bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
- insertOp.strides());
+ insertOp.getStrides());
return success();
}
@@ -2229,7 +2235,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
return failure();
if (xferOp.getVectorType().getRank() > 1 ||
- llvm::size(xferOp.indices()) == 0)
+ llvm::size(xferOp.getIndices()) == 0)
return failure();
Location loc = xferOp->getLoc();
@@ -2240,24 +2246,24 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
- unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
- Value off = xferOp.indices()[lastIndex];
+ unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
+ Value off = xferOp.getIndices()[lastIndex];
Value dim =
- vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
+ vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
Value mask = rewriter.create<vector::CreateMaskOp>(
loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
vtp.getNumScalableDims()),
b);
- if (xferOp.mask()) {
+ if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
- mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
+ mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
}
rewriter.updateRootInPlace(xferOp, [&]() {
- xferOp.maskMutable().assign(mask);
- xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
+ xferOp.getMaskMutable().assign(mask);
+ xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
});
return success();
@@ -2306,14 +2312,14 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
return failure();
// TODO: support mask.
- if (readOp.mask())
+ if (readOp.getMask())
return failure();
- auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
+ auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
if (!srcType || !srcType.hasStaticShape())
return failure();
- if (!readOp.permutation_map().isMinorIdentity())
+ if (!readOp.getPermutationMap().isMinorIdentity())
return failure();
auto targetType = readOp.getVectorType();
@@ -2366,19 +2372,19 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
SmallVector<int64_t> strides(srcType.getRank(), 1);
ArrayAttr inBoundsAttr =
- readOp.in_bounds()
+ readOp.getInBounds()
? rewriter.getArrayAttr(
- readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
+ readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
- loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
+ loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
strides);
auto permMap = getTransferMinorIdentityMap(
rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
Value result = rewriter.create<vector::TransferReadOp>(
loc, resultTargetVecType, rankedReducedView,
- readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
- readOp.padding(),
+ readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+ readOp.getPadding(),
// TODO: support mask.
/*mask=*/Value(), inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
@@ -2514,14 +2520,14 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
ArrayRef<int64_t> destShape = destType.getShape();
auto elType = destType.getElementType();
bool isInt = elType.isIntOrIndex();
- if (!isValidKind(isInt, scanOp.kind()))
+ if (!isValidKind(isInt, scanOp.getKind()))
return failure();
VectorType resType = VectorType::get(destShape, elType);
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
- int64_t reductionDim = scanOp.reduction_dim();
- bool inclusive = scanOp.inclusive();
+ int64_t reductionDim = scanOp.getReductionDim();
+ bool inclusive = scanOp.getInclusive();
int64_t destRank = destType.getRank();
VectorType initialValueType = scanOp.getInitialValueType();
int64_t initialValueRank = initialValueType.getRank();
@@ -2541,7 +2547,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
offsets[reductionDim] = i;
ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
Value input = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
+ loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
scanStrides);
Value output;
if (i == 0) {
@@ -2551,15 +2557,15 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
if (initialValueRank == 0) {
// ShapeCastOp cannot handle 0-D vectors
output = rewriter.create<vector::BroadcastOp>(
- loc, input.getType(), scanOp.initial_value());
+ loc, input.getType(), scanOp.getInitialValue());
} else {
output = rewriter.create<vector::ShapeCastOp>(
- loc, input.getType(), scanOp.initial_value());
+ loc, input.getType(), scanOp.getInitialValue());
}
}
} else {
Value y = inclusive ? input : lastInput;
- output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
+ output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
assert(output != nullptr);
}
result = rewriter.create<vector::InsertStridedSliceOp>(
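
(One reason the spelling must flip uniformly shows up in patterns templated over the op type, such as MaterializeTransferMask above, which call accessors on a deduced ConcreteOp. A self-contained sketch, with mock types and assumed names, of why every instantiation must expose the same accessor spelling.)

#include <vector>

// Mock op types standing in for TableGen-generated ops.
struct MockReadOp {
  std::vector<int> idx;
  const std::vector<int> &getIndices() const { return idx; }
};
struct MockWriteOp {
  std::vector<int> idx;
  const std::vector<int> &getIndices() const { return idx; }
};

// A pattern templated over the op type compiles only if every op it is
// instantiated with exposes the same accessor name.
template <typename ConcreteOp>
bool hasOnlyZeroIndices(const ConcreteOp &op) {
  for (int v : op.getIndices())
    if (v != 0)
      return false;
  return true;
}

int main() {
  return hasOnlyZeroIndices(MockReadOp{{0, 0}}) &&
                 hasOnlyZeroIndices(MockWriteOp{{0}})
             ? 0
             : 1;
}
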
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 2d1e7c1eac89a..2b730182d2088 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -112,7 +112,7 @@ struct UnrollTransferReadPattern
// TODO: support 0-d corner case.
if (readOp.getTransferRank() == 0)
return failure();
- if (readOp.mask())
+ if (readOp.getMask())
return failure();
auto targetShape = getTargetShape(options, readOp);
if (!targetShape)
@@ -129,16 +129,16 @@ struct UnrollTransferReadPattern
loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
auto targetType =
VectorType::get(*targetShape, sourceVectorType.getElementType());
- SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
- readOp.indices().end());
+ SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
+ readOp.getIndices().end());
for (int64_t i = 0; i < sliceCount; i++) {
SmallVector<Value, 4> indices =
sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
- readOp.permutation_map(), loc, rewriter);
+ readOp.getPermutationMap(), loc, rewriter);
auto slicedRead = rewriter.create<vector::TransferReadOp>(
- loc, targetType, readOp.source(), indices,
- readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
- readOp.in_boundsAttr());
+ loc, targetType, readOp.getSource(), indices,
+ readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
+ readOp.getInBoundsAttr());
SmallVector<int64_t, 4> elementOffsets =
getVectorOffset(originalSize, *targetShape, i);
@@ -165,7 +165,7 @@ struct UnrollTransferWritePattern
if (writeOp.getTransferRank() == 0)
return failure();
- if (writeOp.mask())
+ if (writeOp.getMask())
return failure();
auto targetShape = getTargetShape(options, writeOp);
if (!targetShape)
@@ -177,21 +177,21 @@ struct UnrollTransferWritePattern
SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
// Compute shape ratio of 'shape' and 'sizes'.
int64_t sliceCount = computeMaxLinearIndex(ratio);
- SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
- writeOp.indices().end());
+ SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
+ writeOp.getIndices().end());
Value resultTensor;
for (int64_t i = 0; i < sliceCount; i++) {
SmallVector<int64_t, 4> elementOffsets =
getVectorOffset(originalSize, *targetShape, i);
Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, writeOp.vector(), elementOffsets, *targetShape, strides);
+ loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value, 4> indices =
sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
- writeOp.permutation_map(), loc, rewriter);
+ writeOp.getPermutationMap(), loc, rewriter);
Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
- loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
- indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
+ loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
+ indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
// For the tensor case update the destination for the next transfer write.
if (!slicedWrite->getResults().empty())
resultTensor = slicedWrite->getResult(0);
@@ -267,19 +267,21 @@ struct UnrollContractionPattern
AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
SmallVector<int64_t> lhsOffets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
- extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
+ extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
// If there is a mask associated to lhs, extract it as well.
if (slicesOperands.size() > 3)
- extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets);
+ extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
+ lhsOffets);
// Extract the new rhs operand.
AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
SmallVector<int64_t> rhsOffets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
- extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
+ extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
// If there is a mask associated to rhs, extract it as well.
if (slicesOperands.size() > 4)
- extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets);
+ extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
+ rhsOffets);
AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
SmallVector<int64_t> accOffets =
@@ -290,7 +292,7 @@ struct UnrollContractionPattern
if (accIt != accCache.end())
slicesOperands[2] = accIt->second;
else
- extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
+ extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
SmallVector<int64_t> dstShape =
applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
@@ -367,8 +369,8 @@ struct UnrollMultiReductionPattern
// reduction loop keeps updating the accumulator.
auto accIt = accCache.find(destOffset);
if (accIt != accCache.end())
- result = makeArithReduction(rewriter, loc, reductionOp.kind(), result,
- accIt->second);
+ result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
+ result, accIt->second);
accCache[destOffset] = result;
}
// Assemble back the accumulator into a single vector.
@@ -451,7 +453,7 @@ struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
PatternRewriter &rewriter) const override {
- Operation *definedOp = extract.vector().getDefiningOp();
+ Operation *definedOp = extract.getVector().getDefiningOp();
if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
definedOp->getNumResults() != 1)
return failure();
@@ -467,7 +469,7 @@ struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
loc,
VectorType::get(extract.getResultType().getShape(),
vecType.getElementType()),
- operand.get(), extract.ids()));
+ operand.get(), extract.getIds()));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, definedOp, extractOperands, extract.getResultType());
@@ -482,7 +484,7 @@ struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
PatternRewriter &rewriter) const override {
- Operation *definedOp = extract.vector().getDefiningOp();
+ Operation *definedOp = extract.getVector().getDefiningOp();
auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
if (!contract)
return failure();
@@ -514,7 +516,7 @@ struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
VectorType newVecType =
VectorType::get(operandShape, vecType.getElementType());
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
- loc, newVecType, operand, extract.ids()));
+ loc, newVecType, operand, extract.getIds()));
}
Operation *newOp =
cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
@@ -554,11 +556,12 @@ struct TransferReadExtractPattern
dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
if (!extract)
return failure();
- if (read.mask())
+ if (read.getMask())
return failure();
- SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
- AffineMap indexMap = extract.map().compose(read.permutation_map());
+ SmallVector<Value, 4> indices(read.getIndices().begin(),
+ read.getIndices().end());
+ AffineMap indexMap = extract.map().compose(read.getPermutationMap());
unsigned idCount = 0;
ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
for (auto it :
@@ -574,14 +577,15 @@ struct TransferReadExtractPattern
extract.getResultType().getDimSize(vectorPos), read.getContext());
indices[indexPos] = makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
- {indices[indexPos], extract.ids()[idCount++]});
+ {indices[indexPos], extract.getIds()[idCount++]});
}
Value newRead = lb.create<vector::TransferReadOp>(
- extract.getType(), read.source(), indices, read.permutation_mapAttr(),
- read.padding(), read.mask(), read.in_boundsAttr());
+ extract.getType(), read.getSource(), indices,
+ read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+ read.getInBoundsAttr());
Value dest = lb.create<arith::ConstantOp>(
read.getType(), rewriter.getZeroAttr(read.getType()));
- newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
+ newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
rewriter.replaceOp(read, newRead);
return success();
}
@@ -597,14 +601,14 @@ struct TransferWriteInsertPattern
if (write.getTransferRank() == 0)
return failure();
- auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+ auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
if (!insert)
return failure();
- if (write.mask())
+ if (write.getMask())
return failure();
- SmallVector<Value, 4> indices(write.indices().begin(),
- write.indices().end());
- AffineMap indexMap = insert.map().compose(write.permutation_map());
+ SmallVector<Value, 4> indices(write.getIndices().begin(),
+ write.getIndices().end());
+ AffineMap indexMap = insert.map().compose(write.getPermutationMap());
unsigned idCount = 0;
Location loc = write.getLoc();
for (auto it :
@@ -619,13 +623,13 @@ struct TransferWriteInsertPattern
auto scale = getAffineConstantExpr(
insert.getSourceVectorType().getDimSize(vectorPos),
write.getContext());
- indices[indexPos] =
- makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
- {indices[indexPos], insert.ids()[idCount++]});
+ indices[indexPos] = makeComposedAffineApply(
+ rewriter, loc, d0 + scale * d1,
+ {indices[indexPos], insert.getIds()[idCount++]});
}
rewriter.create<vector::TransferWriteOp>(
- loc, insert.vector(), write.source(), indices,
- write.permutation_mapAttr(), write.in_boundsAttr());
+ loc, insert.getVector(), write.getSource(), indices,
+ write.getPermutationMapAttr(), write.getInBoundsAttr());
rewriter.eraseOp(write);
return success();
}
@@ -654,7 +658,7 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
getVectorOffset(originalSize, *targetShape, i);
SmallVector<int64_t> strides(offsets.size(), 1);
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.vector(), offsets, *targetShape, strides);
+ loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -664,7 +668,7 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
accumulator = result;
} else {
// On subsequent reduction, combine with the accumulator.
- accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+ accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
accumulator, result);
}
}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 065848d003d9c..1346256bb0e11 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -264,7 +264,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(op, "Unsupported vector type");
SmallVector<int64_t, 4> transp;
- for (auto attr : op.transp())
+ for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
// Check whether the two source vector dimensions that are greater than one
@@ -289,7 +289,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType::get({n * m}, op.getVectorType().getElementType());
auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
auto reshInput =
- ib.create<vector::ShapeCastOp>(flattenedType, op.vector());
+ ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);
// Extract 1-D vectors from the higher-order dimension of the input
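
(The transp-extraction loop above recurs in several of these patterns. A minimal sketch of the idiom with a mocked attribute type; the real IntegerAttr/ArrayAttr API is MLIR's and is not reproduced here.)

#include <cstdint>
#include <vector>

// Mock stand-in for IntegerAttr.
struct MockIntegerAttr {
  int64_t value;
  int64_t getInt() const { return value; }
};

int main() {
  std::vector<MockIntegerAttr> transpAttr = {{1}, {0}}; // e.g. a 2-D transpose
  // Same shape as the diff's loop: collect attribute payloads into int64_t's.
  std::vector<int64_t> transp;
  transp.reserve(transpAttr.size());
  for (const auto &attr : transpAttr)
    transp.push_back(attr.getInt());
  return (transp[0] == 1 && transp[1] == 0) ? 0 : 1;
}
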
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 59b5891263f50..03fc3a88c1d91 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -86,7 +86,7 @@ struct TestVectorToVectorLowering
dstVec.getShape().end());
}
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
- auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
+ auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
if (!insert)
return llvm::None;
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();