[Mlir-commits] [mlir] 136eb79 - [MLIR][Standard] Add `dynamic_tensor_from_elements` operation

Frederik Gossen llvmlistbot at llvm.org
Mon Sep 7 04:45:33 PDT 2020


Author: Frederik Gossen
Date: 2020-09-07T11:44:43Z
New Revision: 136eb79a8846c4e8ff6ba5ccfc0c470ab351fb13

URL: https://github.com/llvm/llvm-project/commit/136eb79a8846c4e8ff6ba5ccfc0c470ab351fb13
DIFF: https://github.com/llvm/llvm-project/commit/136eb79a8846c4e8ff6ba5ccfc0c470ab351fb13.diff

LOG: [MLIR][Standard] Add `dynamic_tensor_from_elements` operation

With `dynamic_tensor_from_elements`, tensor values of dynamic size can be
created. The body of the operation maps each point of the index space to the
tensor element at that position.

Declare SCF operations in the `scf` namespace to avoid a name clash with the new
`std.yield` operation, and resolve the resulting ambiguities between the
`linalg.yield`, `shape.yield`, `std.yield`, and `scf.yield` operations.

Differential Revision: https://reviews.llvm.org/D86276

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 78aefec00bf7..59ba50fbe232 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -19,7 +19,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def SCF_Dialect : Dialect {
   let name = "scf";
-  let cppNamespace = "";
+  let cppNamespace = "scf";
 }
 
 // Base class for SCF dialect ops.
@@ -39,7 +39,7 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">,
+       SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveSideEffects]> {
   let summary = "for operation";
   let description = [{
@@ -183,7 +183,7 @@ def ForOp : SCF_Op<"for",
 
 def IfOp : SCF_Op<"if",
       [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects,
+       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
        NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
@@ -271,7 +271,7 @@ def ParallelOp : SCF_Op<"parallel",
     [AttrSizedOperandSegments,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>,
      RecursiveSideEffects,
-     SingleBlockImplicitTerminator<"YieldOp">]> {
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
   let summary = "parallel for operation";
   let description = [{
     The "scf.parallel" operation represents a loop nest taking 4 groups of SSA

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index ae951e824e00..f326ae557865 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1475,6 +1475,37 @@ def DivFOp : FloatArithmeticOp<"divf"> {
   let summary = "floating point division operation";
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
+    [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Creates a dynamically sized tensor from elements";
+  let description = [{
+    This operation creates a dynamically sized tensor with elements of any type.
+    It expects one index operand per dynamic extent of the result tensor.
+
+    The body region defines the tensor's elements. It takes one index argument
+    per dimension of the result tensor; together they span the index space. The
+    element at that position is yielded with the `yield` operation (see
+    `YieldOp`), whose operand must have the tensor's element type.
+
+    Example:
+    ```mlir
+      %tnsr = dynamic_tensor_from_elements %m, %n {
+      ^bb0(%i : index, %j : index, %k : index):
+        ...
+        yield %elem : f32
+      } : tensor<?x3x?xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$dynamicExtents);
+  let results = (outs AnyRankedTensor:$result);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpOp
 //===----------------------------------------------------------------------===//
@@ -3252,6 +3283,24 @@ def ViewOp : Std_Op<"view", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
+                               HasParent<"DynamicTensorFromElementsOp">]> {
+  let summary = "Yield a value from a region";
+  let description = [{
+    This operation yields a single value from within a region. It currently
+    only terminates the body of `DynamicTensorFromElementsOp`, where the
+    yielded value is the tensor element at the position given by the body
+    arguments.
+  }];
+
+  let arguments = (ins AnyType:$value);
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+  // No custom verifier: the parent op's verifier checks the yielded type.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 0460d98b44a4..f38eabb9465d 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -339,7 +339,8 @@ class TransposeOpConversion : public ConvertToLLVMPattern {
 class YieldOpConversion : public ConvertToLLVMPattern {
 public:
   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
+      : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
+                             lowering_) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,

diff  --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 34ee48758e9e..14f365f95ee5 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -356,7 +356,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
       // A loop is constructed with an empty "yield" terminator if there are
       // no results.
       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-      rewriter.create<YieldOp>(loc, forOp.getResults());
+      rewriter.create<scf::YieldOp>(loc, forOp.getResults());
     }
 
     rewriter.setInsertionPointToStart(forOp.getBody());
@@ -391,7 +391,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
 
   if (!yieldOperands.empty()) {
     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-    rewriter.create<YieldOp>(loc, yieldOperands);
+    rewriter.create<scf::YieldOp>(loc, yieldOperands);
   }
 
   rewriter.replaceOp(parallelOp, loopResults);

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fa45997ae801..c9b05f89f30b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -905,7 +905,7 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, linalg::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() > 0)
     p << ' ' << op.getOperands();
@@ -926,7 +926,8 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
 
 // Check the operand number and types must match the element types of the
 // LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
+static LogicalResult verifyYield(linalg::YieldOp op,
+                                 LinalgOp linalgOpInterface) {
   auto nOutputs = linalgOpInterface.getNumOutputs();
   if (op.getNumOperands() != nOutputs)
     return op.emitOpError("expected number of yield values (")
@@ -946,7 +947,7 @@ static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
   return success();
 }
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(linalg::YieldOp op) {
   auto *parentOp = op.getParentOp();
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
     return op.emitOpError("expected single non-empty parent region");

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 6c0c841451da..adbf4a7b8045 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -659,7 +659,7 @@ struct FuseGenericOpsOnTensors {
     // Add operations from producer (except the yield operation) to the fused
     // op.
     for (auto &op : producerBlock.getOperations()) {
-      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+      if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
         // Lookup the value the yield operation is mapped to.
         Value yieldVal = yieldOp.getOperand(0);
         if (Value clonedVal = mapper.lookupOrNull(yieldVal))

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 281edd9a91f6..d4d1d108be71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -147,7 +147,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   }
 
   Operation &terminator = block.back();
-  assert(isa<YieldOp>(terminator) &&
+  assert(isa<linalg::YieldOp>(terminator) &&
          "expected a yield op in the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
     IndexedValueType O(outputBuffers[i]);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c8e20ce57842..ada89f1c82b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -48,14 +48,14 @@ static bool hasMultiplyAddBody(Region &r) {
   auto c = m_Val(r.getArgument(2));
   // TODO: Update this detection once we have  matcher support for specifying
   // that any permutation of operands matches.
-  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
-  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
-  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
-  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
-  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
-  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
-  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
-  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
+  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
+  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
+  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
+  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
   return pattern1.match(&r.front().back()) ||
          pattern2.match(&r.front().back()) ||
          pattern3.match(&r.front().back()) ||

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 6f3f1e4dc0d1..498246315d64 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -38,7 +38,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
   // as necessary. Required when the region has only one block.
   void handleTerminator(Operation *op,
                         ArrayRef<Value> valuesToRepl) const final {
-    auto retValOp = dyn_cast<YieldOp>(op);
+    auto retValOp = dyn_cast<scf::YieldOp>(op);
     if (!retValOp)
       return;
 
@@ -889,7 +889,7 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, scf::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() != 0)
     p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
@@ -899,5 +899,9 @@ static void print(OpAsmPrinter &p, YieldOp op) {
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace scf {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
+} // namespace scf
+} // namespace mlir

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 511ec9bf2b4e..bcfaa896f63d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -779,7 +779,7 @@ void SizeToIndexOp::getCanonicalizationPatterns(
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(shape::YieldOp op) {
   auto *parentOp = op.getParentOp();
   auto results = parentOp->getResults();
   auto operands = op.getOperands();

diff  --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index a84fad1f9460..ff74ce069e40 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -45,7 +45,7 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
   OpBuilder b = OpBuilder::atBlockEnd(body);
   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
                                   body->getArgument(2));
-  b.create<YieldOp>(loc, product);
+  b.create<shape::YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
   return success();

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b34257791d78..65f8b83d9a71 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1312,7 +1312,6 @@ Optional<int64_t> DimOp::getConstantIndex() {
 }
 
 static LogicalResult verify(DimOp op) {
-
   // Assume unknown index to be in range.
   Optional<int64_t> index = op.getConstantIndex();
   if (!index.hasValue())
@@ -1634,6 +1633,67 @@ LogicalResult DmaWaitOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
+                                                    OperationState &result) {
+  // One index-typed SSA operand per dynamic extent of the result tensor.
+  SmallVector<OpAsmParser::OperandType, 4> extentOperands;
+  if (parser.parseOperandList(extentOperands) ||
+      parser.resolveOperands(extentOperands,
+                             parser.getBuilder().getIndexType(),
+                             result.operands))
+    return failure();
+
+  // The body region; its entry block arguments are spelled out explicitly in
+  // the custom form, so no arguments are passed in here.
+  Region &bodyRegion = *result.addRegion();
+  if (parser.parseRegion(bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // An optional attribute dictionary, then `:` and the result tensor type.
+  Type tensorType;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(tensorType))
+    return failure();
+  result.addTypes(tensorType);
+  return success();
+}
+
+static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
+  // Emits the layout that parseDynamicTensorFromElementsOp accepts:
+  //   dynamic_tensor_from_elements <extents> <region> <attr-dict> : <type>
+  p << "dynamic_tensor_from_elements";
+  // printRegion starts with its own space before `{`, so only add a separator
+  // when there are extents (avoids a double space for fully static tensors).
+  if (!op.dynamicExtents().empty())
+    p << ' ' << op.dynamicExtents();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getType();
+}
+
+static LogicalResult verify(DynamicTensorFromElementsOp op) {
+  // Ensure that the tensor type has as many dynamic dimensions as are specified
+  // by the operands. Counts are widened to int64_t to avoid signed/unsigned
+  // comparisons against the tensor type's dimension counts.
+  auto resultTy = op.getType().cast<RankedTensorType>();
+  int64_t numExtentOperands = op.getNumOperands();
+  if (numExtentOperands != resultTy.getNumDynamicDims())
+    return op.emitOpError(
+        "must have as many index operands as dynamic extents "
+        "in the result type");
+
+  // Ensure that region arguments span the index space: every argument is an
+  // `index` and there is exactly one per dimension of the result tensor.
+  Region &body = op.body();
+  if (!llvm::all_of(body.getArgumentTypes(),
+                    [](Type ty) { return ty.isIndex(); }))
+    return op.emitOpError("all body arguments must be index");
+  int64_t numBodyArgs = body.getNumArguments();
+  if (numBodyArgs != resultTy.getRank())
+    return op.emitOpError("must have one body argument per input dimension");
+
+  // Ensure that the region yields an element of the right type. The cast
+  // relies on the SingleBlockImplicitTerminator trait, which is verified
+  // before this hook and guarantees a single block ending in a `YieldOp`.
+  auto yieldOp = llvm::cast<YieldOp>(body.front().getTerminator());
+  if (yieldOp.value().getType() != resultTy.getElementType())
+    return op.emitOpError(
+        "body must be terminated with a `yield` operation of the tensor "
+        "element type");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index f2b71f634cd3..7f9c564e74f3 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -15,3 +15,69 @@ func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }
+
+// -----
+
+func @dynamic_tensor_from_elements(%extent : index) -> tensor<?x3x?xf32> {
+  // A single extent operand for a result type with two dynamic dimensions.
+  // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}}
+  %result = dynamic_tensor_from_elements %extent {
+  ^bb0(%i0 : index, %i1 : index, %i2 : index):
+    %val = constant 8.0 : f32
+    yield %val : f32
+  } : tensor<?x3x?xf32>
+  return %result : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%extent0 : index, %extent1 : index)
+    -> tensor<?x3x?xf32> {
+  // The result has rank 3, but the body only takes two index arguments.
+  // expected-error @+1 {{must have one body argument per input dimension}}
+  %result = dynamic_tensor_from_elements %extent0, %extent1 {
+  ^bb0(%i0 : index, %i1 : index):
+    %val = constant 8.0 : f32
+    yield %val : f32
+  } : tensor<?x3x?xf32>
+  return %result : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%extent0 : index, %extent1 : index)
+    -> tensor<?x3x?xf32> {
+  // The last body argument is an i64 rather than an index.
+  // expected-error @+1 {{all body arguments must be index}}
+  %result = dynamic_tensor_from_elements %extent0, %extent1 {
+  ^bb0(%i0 : index, %i1 : index, %i2 : i64):
+    %val = constant 8.0 : f32
+    yield %val : f32
+  } : tensor<?x3x?xf32>
+  return %result : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%extent0 : index, %extent1 : index)
+    -> tensor<?x3x?xf32> {
+  // The body ends in `return` instead of the required `yield` terminator.
+  // expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}}
+  // expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}}
+  %result = dynamic_tensor_from_elements %extent0, %extent1 {
+  ^bb0(%i0 : index, %i1 : index, %i2 : index):
+    %val = constant 8.0 : f32
+    return %val : f32
+  } : tensor<?x3x?xf32>
+  return %result : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%extent0 : index, %extent1 : index)
+    -> tensor<?x3x?xf32> {
+  // The yielded i32 does not match the f32 element type of the result.
+  // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
+  %result = dynamic_tensor_from_elements %extent0, %extent1 {
+  ^bb0(%i0 : index, %i1 : index, %i2 : index):
+    %val = constant 8 : i32
+    yield %val : i32
+  } : tensor<?x3x?xf32>
+  return %result : tensor<?x3x?xf32>
+}

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 24da04eebaaa..a765acb9657b 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: test_index_cast
 func @test_index_cast(%arg0 : index) -> i64 {
@@ -22,3 +23,14 @@ func @assert(%arg : i1) {
   assert %arg, "Some message in case this assertion fails."
   return
 }
+
+// Round-trips the op through the custom and generic forms (see RUN lines).
+// CHECK-LABEL: func @dynamic_tensor_from_elements
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // CHECK: dynamic_tensor_from_elements %{{.*}}, %{{.*}} {
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+    ^bb0(%i : index, %j : index, %k : index):
+      %elem = constant 8.0 : f32
+      // CHECK: yield %{{.*}} : f32
+      yield %elem : f32
+  // CHECK: } : tensor<?x3x?xf32>
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+


        


More information about the Mlir-commits mailing list