[Mlir-commits] [mlir] b47be47 - [mlir][Vector] Switch ExtractOp to the declarative assembly format
Benjamin Kramer
llvmlistbot at llvm.org
Fri Feb 18 03:00:51 PST 2022
Author: Benjamin Kramer
Date: 2022-02-18T11:45:59+01:00
New Revision: b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5
URL: https://github.com/llvm/llvm-project/commit/b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5
DIFF: https://github.com/llvm/llvm-project/commit/b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5.diff
LOG: [mlir][Vector] Switch ExtractOp to the declarative assembly format
This is a bit awkward since ExtractOp allows both `f32` and
`vector<1xf32>` results for a scalar extraction. Allow both, but make
inference return the scalar to make this as NFC as possible.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1e16dbbb97295..4a20ea0dc4d10 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -551,7 +551,8 @@ def Vector_ExtractElementOp :
def Vector_ExtractOp :
Vector_Op<"extract", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
@@ -577,9 +578,10 @@ def Vector_ExtractOp :
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
}
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
+ let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
let hasCanonicalizer = 1;
- let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 560746453a079..4ffb2b8c75696 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -940,21 +940,9 @@ LogicalResult vector::ExtractElementOp::verify() {
// ExtractOp
//===----------------------------------------------------------------------===//
-static Type inferExtractOpResultType(VectorType vectorType,
- ArrayAttr position) {
- if (static_cast<int64_t>(position.size()) == vectorType.getRank())
- return vectorType.getElementType();
- return VectorType::get(vectorType.getShape().drop_front(position.size()),
- vectorType.getElementType());
-}
-
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<int64_t> position) {
- result.addOperands(source);
- auto positionAttr = getVectorSubscriptAttr(builder, position);
- result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
- positionAttr));
- result.addAttribute(getPositionAttrStrName(), positionAttr);
+ build(builder, result, source, getVectorSubscriptAttr(builder, position));
}
// Convenience builder which assumes the values are constant indices.
@@ -967,40 +955,34 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, source, positionConstants);
}
-void vector::ExtractOp::print(OpAsmPrinter &p) {
- p << " " << vector() << position();
- p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
- p << " : " << vector().getType();
+LogicalResult
+ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ ExtractOp::Adaptor op(operands, attributes);
+ auto vectorType = op.vector().getType().cast<VectorType>();
+ if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
+ inferredReturnTypes.push_back(vectorType.getElementType());
+ } else {
+ auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
+ inferredReturnTypes.push_back(VectorType::get(
+ vectorType.getShape().drop_front(n), vectorType.getElementType()));
+ }
+ return success();
}
-ParseResult vector::ExtractOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SMLoc attributeLoc, typeLoc;
- NamedAttrList attrs;
- OpAsmParser::OperandType vector;
- Type type;
- Attribute attr;
- if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
- parser.parseAttribute(attr, "position", attrs) ||
- parser.parseOptionalAttrDict(attrs) ||
- parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
- return failure();
-
- auto vectorType = type.dyn_cast<VectorType>();
- if (!vectorType)
- return parser.emitError(typeLoc, "expected vector type");
-
- auto positionAttr = attr.dyn_cast<ArrayAttr>();
- if (!positionAttr ||
- static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
- return parser.emitError(
- attributeLoc,
- "expected position attribute of rank smaller than vector rank");
-
- Type resType = inferExtractOpResultType(vectorType, positionAttr);
- result.attributes = attrs;
- return failure(parser.resolveOperand(vector, type, result.operands) ||
- parser.addTypeToList(resType, result.types));
+bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // Allow extracting 1-element vectors instead of scalars.
+ auto isCompatible = [](TypeRange l, TypeRange r) {
+ auto vectorType = l.front().dyn_cast<VectorType>();
+ return vectorType && vectorType.getShape().equals({1}) &&
+ vectorType.getElementType() == r.front();
+ };
+ if (l.size() == 1 && r.size() == 1 &&
+ (isCompatible(l, r) || isCompatible(r, l)))
+ return true;
+ return l == r;
}
LogicalResult vector::ExtractOp::verify() {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index bc75e0bbe8b2e..2e224f7f58ebe 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -104,7 +104,7 @@ func @extract_element(%arg0: vector<4x4xf32>) {
// -----
func @extract_vector_type(%arg0: index) {
- // expected-error@+1 {{expected vector type}}
+ // expected-error@+1 {{invalid kind of type specified}}
%1 = vector.extract %arg0[] : index
}
More information about the Mlir-commits
mailing list