[Mlir-commits] [mlir] 018f693 - [MLIR][Standard] Simplify `tensor_from_elements`

Frederik Gossen llvmlistbot at llvm.org
Thu Sep 10 07:43:14 PDT 2020


Author: Frederik Gossen
Date: 2020-09-10T14:42:51Z
New Revision: 018f6936dbcee63e0a1ffd3777e854150b8cf957

URL: https://github.com/llvm/llvm-project/commit/018f6936dbcee63e0a1ffd3777e854150b8cf957
DIFF: https://github.com/llvm/llvm-project/commit/018f6936dbcee63e0a1ffd3777e854150b8cf957.diff

LOG: [MLIR][Standard] Simplify `tensor_from_elements`

Define assembly format and add required traits.

Differential Revision: https://reviews.llvm.org/D87366

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 44bbb423b2d95..ec7ecf9b92d40 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1611,8 +1611,14 @@ def ExtractElementOp : Std_Op<"extract_element",
 // TensorFromElementsOp
 //===----------------------------------------------------------------------===//
 
-def TensorFromElementsOp : Std_Op<"tensor_from_elements",
-    [NoSideEffect, SameOperandsAndResultElementType]> {
+def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
+    NoSideEffect,
+    SameOperandsAndResultElementType,
+    TypesMatchWith<"operand types match result element type",
+                   "result", "elements", "SmallVector<Type, 2>("
+                   "$_self.cast<ShapedType>().getDimSize(0), "
+                   "$_self.cast<ShapedType>().getElementType())">
+  ]> {
   string summary = "tensor from elements operation.";
   string description = [{
     Create a 1D tensor from a range of same-type arguments.
@@ -1625,9 +1631,13 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements",
   }];
 
   let arguments = (ins Variadic<AnyType>:$elements);
-  let results = (outs AnyTensor:$result);
+  let results = (outs 1DTensorOf<[AnyType]>:$result);
+
+  let assemblyFormat = "$elements attr-dict `:` type($result)";
+
+  // This op is fully verified by its traits.
+  let verifier = ?;
 
-  let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
   ];

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a0ad05852e230..dc45d5175277c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1756,50 +1756,12 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 // TensorFromElementsOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
-                                             OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
-  Type resultType;
-  if (parser.parseOperandList(elementsOperands) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(resultType))
-    return failure();
-
-  if (parser.resolveOperands(elementsOperands,
-                             resultType.cast<ShapedType>().getElementType(),
-                             result.operands))
-    return failure();
-
-  result.addTypes(resultType);
-  return success();
-}
-
-static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
-  p << "tensor_from_elements " << op.elements();
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getType();
-}
-
-static LogicalResult verify(TensorFromElementsOp op) {
-  auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
-  if (!resultTensorType)
-    return op.emitOpError("expected result type to be a ranked tensor");
-
-  int64_t elementsCount = static_cast<int64_t>(op.elements().size());
-  if (resultTensorType.getRank() != 1 ||
-      resultTensorType.getShape().front() != elementsCount)
-    return op.emitOpError()
-           << "expected result type to be a 1D tensor with " << elementsCount
-           << (elementsCount == 1 ? " element" : " elements");
-  return success();
-}
-
 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
                                  ValueRange elements) {
   assert(!elements.empty() && "expected at least one element");
-  result.addOperands(elements);
-  result.addTypes(RankedTensorType::get({static_cast<int64_t>(elements.size())},
-                                        *elements.getTypes().begin()));
+  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
+                                        elements.front().getType());
+  build(builder, result, resultTy, elements);
 }
 
 namespace {

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 71b007ef6e39f..e02dbca494df6 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -595,7 +595,7 @@ func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
 // -----
 
 func @tensor_from_elements_wrong_result_type() {
-  // expected-error@+2 {{expected result type to be a ranked tensor}}
+  // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
   %c0 = constant 0 : i32
   %0 = tensor_from_elements %c0 : tensor<*xi32>
   return
@@ -604,7 +604,7 @@ func @tensor_from_elements_wrong_result_type() {
 // -----
 
 func @tensor_from_elements_wrong_elements_count() {
-  // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}}
+  // expected-error@+2 {{1 operands present, but expected 2}}
   %c0 = constant 0 : index
   %0 = tensor_from_elements %c0 : tensor<2xindex>
   return


        


More information about the Mlir-commits mailing list