[Mlir-commits] [mlir] b47be47 - [mlir][Vector] Switch ExtractOp to the declarative assembly format

Benjamin Kramer llvmlistbot at llvm.org
Fri Feb 18 03:00:51 PST 2022


Author: Benjamin Kramer
Date: 2022-02-18T11:45:59+01:00
New Revision: b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5

URL: https://github.com/llvm/llvm-project/commit/b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5
DIFF: https://github.com/llvm/llvm-project/commit/b47be47ac2871b6f63f4acbaea8fa5e311f1ecc5.diff

LOG: [mlir][Vector] Switch ExtractOp to the declarative assembly format

This is a bit awkward since ExtractOp allows both `f32` and
`vector<1xf32>` results for a scalar extraction. Allow both, but make
type inference return the scalar so that this change is as close to NFC
(no functional change) as possible.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1e16dbbb97295..4a20ea0dc4d10 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -551,7 +551,8 @@ def Vector_ExtractElementOp :
 def Vector_ExtractOp :
   Vector_Op<"extract", [NoSideEffect,
      PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
     Results<(outs AnyType)> {
   let summary = "extract operation";
@@ -577,9 +578,10 @@ def Vector_ExtractOp :
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
     }
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+  let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
   let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 560746453a079..4ffb2b8c75696 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -940,21 +940,9 @@ LogicalResult vector::ExtractElementOp::verify() {
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
-static Type inferExtractOpResultType(VectorType vectorType,
-                                     ArrayAttr position) {
-  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
-    return vectorType.getElementType();
-  return VectorType::get(vectorType.getShape().drop_front(position.size()),
-                         vectorType.getElementType());
-}
-
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, ArrayRef<int64_t> position) {
-  result.addOperands(source);
-  auto positionAttr = getVectorSubscriptAttr(builder, position);
-  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
-                                           positionAttr));
-  result.addAttribute(getPositionAttrStrName(), positionAttr);
+  build(builder, result, source, getVectorSubscriptAttr(builder, position));
 }
 
 // Convenience builder which assumes the values are constant indices.
@@ -967,40 +955,34 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, source, positionConstants);
 }
 
-void vector::ExtractOp::print(OpAsmPrinter &p) {
-  p << " " << vector() << position();
-  p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
-  p << " : " << vector().getType();
+LogicalResult
+ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+                            ValueRange operands, DictionaryAttr attributes,
+                            RegionRange,
+                            SmallVectorImpl<Type> &inferredReturnTypes) {
+  ExtractOp::Adaptor op(operands, attributes);
+  auto vectorType = op.vector().getType().cast<VectorType>();
+  if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
+    inferredReturnTypes.push_back(vectorType.getElementType());
+  } else {
+    auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
+    inferredReturnTypes.push_back(VectorType::get(
+        vectorType.getShape().drop_front(n), vectorType.getElementType()));
+  }
+  return success();
 }
 
-ParseResult vector::ExtractOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  SMLoc attributeLoc, typeLoc;
-  NamedAttrList attrs;
-  OpAsmParser::OperandType vector;
-  Type type;
-  Attribute attr;
-  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
-      parser.parseAttribute(attr, "position", attrs) ||
-      parser.parseOptionalAttrDict(attrs) ||
-      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
-    return failure();
-
-  auto vectorType = type.dyn_cast<VectorType>();
-  if (!vectorType)
-    return parser.emitError(typeLoc, "expected vector type");
-
-  auto positionAttr = attr.dyn_cast<ArrayAttr>();
-  if (!positionAttr ||
-      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
-    return parser.emitError(
-        attributeLoc,
-        "expected position attribute of rank smaller than vector rank");
-
-  Type resType = inferExtractOpResultType(vectorType, positionAttr);
-  result.attributes = attrs;
-  return failure(parser.resolveOperand(vector, type, result.operands) ||
-                 parser.addTypeToList(resType, result.types));
+bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // Allow extracting 1-element vectors instead of scalars.
+  auto isCompatible = [](TypeRange l, TypeRange r) {
+    auto vectorType = l.front().dyn_cast<VectorType>();
+    return vectorType && vectorType.getShape().equals({1}) &&
+           vectorType.getElementType() == r.front();
+  };
+  if (l.size() == 1 && r.size() == 1 &&
+      (isCompatible(l, r) || isCompatible(r, l)))
+    return true;
+  return l == r;
 }
 
 LogicalResult vector::ExtractOp::verify() {

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index bc75e0bbe8b2e..2e224f7f58ebe 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -104,7 +104,7 @@ func @extract_element(%arg0: vector<4x4xf32>) {
 // -----
 
 func @extract_vector_type(%arg0: index) {
-  // expected-error@+1 {{expected vector type}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %1 = vector.extract %arg0[] : index
 }
 


        


More information about the Mlir-commits mailing list