[Mlir-commits] [mlir] c3098e4 - [MLIR] Add TensorFromElementsOp to Standard ops.

Alexander Belyaev llvmlistbot at llvm.org
Thu May 28 06:53:21 PDT 2020


Author: Alexander Belyaev
Date: 2020-05-28T15:48:10+02:00
New Revision: c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4

URL: https://github.com/llvm/llvm-project/commit/c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4
DIFF: https://github.com/llvm/llvm-project/commit/c3098e4f4036e96dbd3de0e61c5e114b0eb7bbb4.diff

LOG: [MLIR] Add TensorFromElementsOp to Standard ops.

Differential Revision: https://reviews.llvm.org/D80705

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 030470317236..eae71b0263c1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1524,6 +1524,42 @@ def ExtractElementOp : Std_Op<"extract_element",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+def TensorFromElementsOp : Std_Op<"tensor_from_elements",
+    [NoSideEffect, SameOperandsAndResultElementType]> {
+  let summary = "tensor from elements operation.";
+  let description = [{
+    Create a 1D tensor from a range of same-type arguments. The result type
+    must be a statically shaped 1D tensor whose only dimension equals the
+    number of arguments and whose element type matches the argument type.
+
+    Example:
+
+    ```mlir
+    tensor_from_elements(%i_1, ..., %i_N) : tensor<Nxindex>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$elements);
+  let results = (outs AnyTensor:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, ValueRange elements", [{
+      assert(!elements.empty() && "expected at least one element");
+      result.addOperands(elements);
+      // Infer a static 1D tensor type from the number and type of elements.
+      result.addTypes(
+          RankedTensorType::get({static_cast<int64_t>(elements.size())},
+                                *elements.getTypes().begin()));
+    }]>];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // FPExtOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3d493a8a57a5..118a1119833c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1640,6 +1640,123 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// TensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form
+///
+///   tensor_from_elements(%e0, ..., %eN) attr-dict : tensor-type
+///
+/// The operand types are not spelled out; they are all taken to be the element
+/// type of the trailing result type.
+static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
+                                             OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
+  Type resultType;
+  if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
+      parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColon())
+    return failure();
+
+  llvm::SMLoc resultTypeLoc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+
+  // The operand types are derived from the result's element type, so a
+  // non-shaped result type must be diagnosed here; `cast<ShapedType>()` would
+  // assert on it instead.
+  auto shapedResultType = resultType.dyn_cast<ShapedType>();
+  if (!shapedResultType)
+    return parser.emitError(resultTypeLoc,
+                            "expected shaped result type, but got ")
+           << resultType;
+
+  if (parser.resolveOperands(elementsOperands,
+                             shapedResultType.getElementType(),
+                             result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+/// Prints the custom form accepted by parseTensorFromElementsOp.
+static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
+  p << "tensor_from_elements(" << op.elements() << ')';
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.result().getType();
+}
+
+/// Checks that the result is a ranked 1-D tensor whose only dimension equals
+/// the number of element operands. Element-type agreement between operands
+/// and result is enforced by the SameOperandsAndResultElementType trait.
+static LogicalResult verify(TensorFromElementsOp op) {
+  auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
+  if (!resultTensorType)
+    return op.emitOpError("expected result type to be a ranked tensor");
+
+  int64_t elementsCount = static_cast<int64_t>(op.elements().size());
+  // `||` short-circuits, so getShape().front() is never evaluated on a
+  // rank-0 tensor (whose shape is empty).
+  if (resultTensorType.getRank() != 1 ||
+      resultTensorType.getShape().front() != elementsCount)
+    return op.emitOpError()
+           << "expected result type to be a 1D tensor with " << elementsCount
+           << (elementsCount == 1 ? " element" : " elements");
+  return success();
+}
+
+namespace {
+
+/// Canonicalizes the pattern of the form
+///
+///   %tensor = tensor_from_elements(%element) : tensor<1xi32>
+///   %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
+///
+/// to just %element. Only constant, in-bounds indices are folded.
+struct ExtractElementFromTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    // tensor_from_elements always yields a 1-D tensor (see verify above), so
+    // a well-formed extract from it carries exactly one index.
+    if (extract.indices().size() != 1)
+      return failure();
+
+    // The aggregate may be a block argument, which has no defining op;
+    // dyn_cast_or_null tolerates the null, whereas dyn_cast asserts on it.
+    auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
+        extract.aggregate().getDefiningOp());
+    if (!tensorFromElements)
+      return failure();
+
+    APInt index;
+    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
+      return failure();
+
+    // Out-of-bounds constant indices are left alone rather than asserting in
+    // getOperand(). The unsigned comparison also rejects negative constants.
+    if (index.uge(tensorFromElements.elements().size()))
+      return failure();
+
+    rewriter.replaceOp(extract,
+                       tensorFromElements.getOperand(index.getZExtValue()));
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalization patterns associated with
+/// TensorFromElementsOp. The pattern matches the consuming ExtractElementOp.
+void TensorFromElementsOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExtractElementFromTensorFromElements>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FPExtOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 41172aa22527..7727fa5e0363 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -644,6 +644,24 @@ func @extract_element(%arg0: tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: func @tensor_from_elements() {
+func @tensor_from_elements() {
+  %zero = constant 0 : index
+  // CHECK: %0 = tensor_from_elements(%c0) : tensor<1xindex>
+  %single = tensor_from_elements(%zero) : tensor<1xindex>
+
+  %one = constant 1 : index
+  // CHECK: %1 = tensor_from_elements(%c0, %c1) : tensor<2xindex>
+  %pair = tensor_from_elements(%zero, %one) : tensor<2xindex>
+
+  %zero_f32 = constant 0.0 : f32
+  // CHECK: [[C0_F32:%.*]] = constant
+  // CHECK: %2 = tensor_from_elements([[C0_F32]]) : tensor<1xf32>
+  %float = tensor_from_elements(%zero_f32) : tensor<1xf32>
+
+  return
+}
+
 // CHECK-LABEL: func @tensor_cast(%arg0
 func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
   // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b0535047874f..c8e40c520139 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -605,7 +605,25 @@ func @extract_element_tensor_too_many_indices(%t : tensor<2x3xf32>, %i : index)
 
 func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
   // expected-error@+1 {{incorrect number of indices for extract_element}}
   %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32
+  return
+}
+
+// -----
+
+func @tensor_from_elements_wrong_result_type() {
+  // expected-error@+2 {{expected result type to be a ranked tensor}}
+  %c0 = constant 0 : i32
+  %0 = tensor_from_elements(%c0) : tensor<*xi32>
+  return
+}
+
+// -----
+
+func @tensor_from_elements_wrong_elements_count() {
+  // expected-error@+2 {{expected result type to be a 1D tensor with 1 element}}
+  %c0 = constant 0 : index
+  %0 = tensor_from_elements(%c0) : tensor<2xindex>
   return
 }
 

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index b17cade291a5..6e24bb3b2d83 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -971,3 +971,15 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
   // CHECK:  memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
   return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_tensor_from_elements
+func @extract_element_from_tensor_from_elements(%scalar : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %zero = constant 0 : index
+  %packed = tensor_from_elements(%scalar) : tensor<1xindex>
+  %unpacked = extract_element %packed[%zero] : tensor<1xindex>
+  // CHECK: [[ARG]] : index
+  return %unpacked : index
+}


        


More information about the Mlir-commits mailing list