[Mlir-commits] [mlir] 018f693 - [MLIR][Standard] Simplify `tensor_from_elements`
Frederik Gossen
llvmlistbot at llvm.org
Thu Sep 10 07:43:14 PDT 2020
Author: Frederik Gossen
Date: 2020-09-10T14:42:51Z
New Revision: 018f6936dbcee63e0a1ffd3777e854150b8cf957
URL: https://github.com/llvm/llvm-project/commit/018f6936dbcee63e0a1ffd3777e854150b8cf957
DIFF: https://github.com/llvm/llvm-project/commit/018f6936dbcee63e0a1ffd3777e854150b8cf957.diff
LOG: [MLIR][Standard] Simplify `tensor_from_elements`
Define assembly format and add required traits.
Differential Revision: https://reviews.llvm.org/D87366
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 44bbb423b2d95..ec7ecf9b92d40 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1611,8 +1611,14 @@ def ExtractElementOp : Std_Op<"extract_element",
// TensorFromElementsOp
//===----------------------------------------------------------------------===//
-def TensorFromElementsOp : Std_Op<"tensor_from_elements",
- [NoSideEffect, SameOperandsAndResultElementType]> {
+def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
+ NoSideEffect,
+ SameOperandsAndResultElementType,
+ TypesMatchWith<"operand types match result element type",
+ "result", "elements", "SmallVector<Type, 2>("
+ "$_self.cast<ShapedType>().getDimSize(0), "
+ "$_self.cast<ShapedType>().getElementType())">
+ ]> {
string summary = "tensor from elements operation.";
string description = [{
Create a 1D tensor from a range of same-type arguments.
@@ -1625,9 +1631,13 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements",
}];
let arguments = (ins Variadic<AnyType>:$elements);
- let results = (outs AnyTensor:$result);
+ let results = (outs 1DTensorOf<[AnyType]>:$result);
+
+ let assemblyFormat = "$elements attr-dict `:` type($result)";
+
+ // This op is fully verified by its traits.
+ let verifier = ?;
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
];
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a0ad05852e230..dc45d5175277c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1756,50 +1756,12 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// TensorFromElementsOp
//===----------------------------------------------------------------------===//
-static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
- Type resultType;
- if (parser.parseOperandList(elementsOperands) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(resultType))
- return failure();
-
- if (parser.resolveOperands(elementsOperands,
- resultType.cast<ShapedType>().getElementType(),
- result.operands))
- return failure();
-
- result.addTypes(resultType);
- return success();
-}
-
-static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
- p << "tensor_from_elements " << op.elements();
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getType();
-}
-
-static LogicalResult verify(TensorFromElementsOp op) {
- auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
- if (!resultTensorType)
- return op.emitOpError("expected result type to be a ranked tensor");
-
- int64_t elementsCount = static_cast<int64_t>(op.elements().size());
- if (resultTensorType.getRank() != 1 ||
- resultTensorType.getShape().front() != elementsCount)
- return op.emitOpError()
- << "expected result type to be a 1D tensor with " << elementsCount
- << (elementsCount == 1 ? " element" : " elements");
- return success();
-}
-
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
ValueRange elements) {
assert(!elements.empty() && "expected at least one element");
- result.addOperands(elements);
- result.addTypes(RankedTensorType::get({static_cast<int64_t>(elements.size())},
- *elements.getTypes().begin()));
+ Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
+ elements.front().getType());
+ build(builder, result, resultTy, elements);
}
namespace {
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 71b007ef6e39f..e02dbca494df6 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -595,7 +595,7 @@ func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
// -----
func @tensor_from_elements_wrong_result_type() {
- // expected-error@+2 {{expected result type to be a ranked tensor}}
+ // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = constant 0 : i32
%0 = tensor_from_elements %c0 : tensor<*xi32>
return
@@ -604,7 +604,7 @@ func @tensor_from_elements_wrong_result_type() {
// -----
func @tensor_from_elements_wrong_elements_count() {
- // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}}
+ // expected-error@+2 {{1 operands present, but expected 2}}
%c0 = constant 0 : index
%0 = tensor_from_elements %c0 : tensor<2xindex>
return
More information about the Mlir-commits
mailing list