[Mlir-commits] [mlir] c3098e4 - [MLIR] Add TensorFromElementsOp to Standard ops.
Alexander Belyaev
llvmlistbot at llvm.org
Thu May 28 06:53:21 PDT 2020
Author: Alexander Belyaev
Date: 2020-05-28T15:48:10+02:00
New Revision: c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4
URL: https://github.com/llvm/llvm-project/commit/c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4
DIFF: https://github.com/llvm/llvm-project/commit/c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4.diff
LOG: [MLIR] Add TensorFromElementsOp to Standard ops.
Differential Revision: https://reviews.llvm.org/D80705
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 030470317236..eae71b0263c1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1524,6 +1524,39 @@ def ExtractElementOp : Std_Op<"extract_element",
}];
}
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+// ODS definition of `std.tensor_from_elements`: takes a variadic list of
+// operands (`$elements`) and produces a single tensor result (`$result`).
+def TensorFromElementsOp : Std_Op<"tensor_from_elements",
+ [NoSideEffect, SameOperandsAndResultElementType]> {
+ string summary = "tensor from elements operation.";
+ string description = [{
+ Create a 1D tensor from a range of same-type arguments.
+
+ Example:
+
+ ```mlir
+ tensor_from_elements(i_1, ..., i_N) : tensor<Nxindex>
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$elements);
+ let results = (outs AnyTensor:$result);
+
+ // The single custom builder takes only the element values and derives the
+ // result type itself: a ranked tensor of shape {elements.size()} whose
+ // element type is the type of the first element. It asserts that `elements`
+ // is non-empty.
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<
+ "OpBuilder &builder, OperationState &result, ValueRange elements", [{
+ assert(!elements.empty() && "expected at least one element");
+ result.addOperands(elements);
+ result.addTypes(
+ RankedTensorType::get({static_cast<int64_t>(elements.size())},
+ *elements.getTypes().begin()));
+ }]>];
+
+ // The patterns are supplied by
+ // TensorFromElementsOp::getCanonicalizationPatterns in Ops.cpp.
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3d493a8a57a5..118a1119833c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1640,6 +1640,86 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form
+///
+///   tensor_from_elements(%e0, ..., %eN) attr-dict? : result-type
+///
+/// Every operand is resolved against the element type of `result-type`, which
+/// becomes the single result type of the operation. Returns failure (with a
+/// diagnostic) if the syntax is malformed or `result-type` is not a shaped
+/// type.
+static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
+                                             OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
+  Type resultType;
+  llvm::SMLoc resultTypeLoc;
+  if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
+      parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColon() || parser.getCurrentLocation(&resultTypeLoc) ||
+      parser.parseType(resultType))
+    return failure();
+
+  // The result type is written by the user and can be any type. A checked
+  // `cast` would assert on input such as `: i32`, so test the type and report
+  // a regular parse error instead.
+  auto shapedType = resultType.dyn_cast<ShapedType>();
+  if (!shapedType)
+    return parser.emitError(resultTypeLoc)
+           << "expected result type to be a shaped type, but got "
+           << resultType;
+
+  if (parser.resolveOperands(elementsOperands, shapedType.getElementType(),
+                             result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+/// Prints the custom assembly form of the op: the operand list in
+/// parentheses, the optional attribute dictionary, then the result type.
+static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
+  // Operand list.
+  p << "tensor_from_elements(";
+  p << op.elements();
+  p << ')';
+
+  // Attributes, if any.
+  p.printOptionalAttrDict(op.getAttrs());
+
+  // Trailing result type.
+  p << " : ";
+  p << op.result().getType();
+}
+
+/// Verifies that the result is a ranked, rank-1 tensor whose single dimension
+/// equals the number of operands.
+static LogicalResult verify(TensorFromElementsOp op) {
+  Type resultType = op.result().getType();
+  auto rankedType = resultType.dyn_cast<RankedTensorType>();
+  if (!rankedType)
+    return op.emitOpError("expected result type to be a ranked tensor");
+
+  const int64_t count = static_cast<int64_t>(op.elements().size());
+
+  // The rank is tested first so that `front()` is only evaluated on a
+  // non-empty shape.
+  const bool isMatching1D =
+      rankedType.getRank() == 1 && rankedType.getShape().front() == count;
+  if (isMatching1D)
+    return success();
+
+  return op.emitOpError() << "expected result type to be a 1D tensor with "
+                          << count
+                          << (count == 1 ? " element" : " elements");
+}
+
+namespace {
+
+// Canonicalizes the pattern of the form
+//
+//   %tensor = tensor_from_elements(%e0, ..., %eN) : tensor<Nxi32>
+//   %extracted_element = extract_element %tensor[%ck] : tensor<Nxi32>
+//
+// where %ck is a constant index k that is in bounds, to just %ek.
+struct ExtractElementFromTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  // Replaces `extract` with the matching operand of the producing
+  // `tensor_from_elements`. Fails (leaving the IR untouched) when the
+  // aggregate is not produced by that op, when the index is not a single
+  // constant, or when the constant index is out of bounds.
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    if (extract.indices().size() != 1)
+      return failure();
+
+    // The aggregate may be a block argument, in which case there is no
+    // defining op. Plain `dyn_cast` asserts on a null pointer, so use the
+    // null-tolerant variant.
+    auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
+        extract.aggregate().getDefiningOp());
+    if (!tensorFromElements)
+      return failure();
+
+    APInt index;
+    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
+      return failure();
+
+    // This pattern has no guarantee that the constant index is in range, so
+    // leave out-of-range extracts alone instead of reading past the operand
+    // list. The unsigned comparison also rejects negative indices.
+    if (index.uge(tensorFromElements.elements().size()))
+      return failure();
+
+    rewriter.replaceOp(extract,
+                       tensorFromElements.getOperand(index.getZExtValue()));
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the canonicalization patterns associated with this op to `results`.
+/// Currently that is only ExtractElementFromTensorFromElements, which is
+/// rooted on ExtractElementOp rather than on TensorFromElementsOp itself.
+void TensorFromElementsOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ExtractElementFromTensorFromElements>(context);
+}
+
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 41172aa22527..7727fa5e0363 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -644,6 +644,24 @@ func @extract_element(%arg0: tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
return %0 : i32
}
+// Round-trip test for the custom form of tensor_from_elements: one index
+// operand, two index operands, and one f32 operand.
+// CHECK-LABEL: func @tensor_from_elements() {
+func @tensor_from_elements() {
+ %c0 = "std.constant"() {value = 0: index} : () -> index
+ // CHECK: %0 = tensor_from_elements(%c0) : tensor<1xindex>
+ %0 = tensor_from_elements(%c0) : tensor<1xindex>
+
+ %c1 = "std.constant"() {value = 1: index} : () -> index
+ // CHECK: %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+ %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+
+ %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32
+ // The printed name of the f32 constant is captured rather than hard-coded.
+ // CHECK: [[C0_F32:%.*]] = constant
+ // CHECK: %2 = tensor_from_elements([[C0_F32]]) : tensor<1xf32>
+ %2 = tensor_from_elements(%c0_f32) : tensor<1xf32>
+
+ return
+}
+
// CHECK-LABEL: func @tensor_cast(%arg0
func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b0535047874f..c8e40c520139 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -605,7 +605,24 @@ func @extract_element_tensor_too_many_indices(%t : tensor<2x3xf32>, %i : index)
func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
- %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32
+ %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 return
+}
+
+// -----
+
+func @tensor_from_elements_wrong_result_type() {
+ // expected-error@+2 {{expected result type to be a ranked tensor}}
+ %c0 = constant 0 : i32
+ %0 = tensor_from_elements(%c0) : tensor<*xi32>
+ return
+}
+
+// -----
+
+func @tensor_from_elements_wrong_elements_count() {
+ // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}}
+ %c0 = constant 0 : index
+ %0 = tensor_from_elements(%c0) : tensor<2xindex>
return
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index b17cade291a5..6e24bb3b2d83 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -971,3 +971,15 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
// CHECK: memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
}
+
+// -----
+
+// Extracting element 0 from a one-element tensor_from_elements should
+// canonicalize to the original operand, so the function returns its argument.
+// CHECK-LABEL: func @extract_element_from_tensor_from_elements
+func @extract_element_from_tensor_from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c0 = constant 0 : index
+ %tensor = tensor_from_elements(%element) : tensor<1xindex>
+ %extracted_element = extract_element %tensor[%c0] : tensor<1xindex>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
More information about the Mlir-commits
mailing list