[llvm-branch-commits] [mlir] 118a715 - [mlir][Linalg] Define a linalg.init_tensor operation.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 17 14:54:23 PST 2020


Author: MaheshRavishankar
Date: 2020-12-17T14:45:51-08:00
New Revision: 118a71565462db41cab1dbb0349200627d6e8524

URL: https://github.com/llvm/llvm-project/commit/118a71565462db41cab1dbb0349200627d6e8524
DIFF: https://github.com/llvm/llvm-project/commit/118a71565462db41cab1dbb0349200627d6e8524.diff

LOG: [mlir][Linalg] Define a linalg.init_tensor operation.

This operation is used to materialize a tensor of a particular
shape. The shape could be specified as a mix of static and dynamic
values.

This operation is meant to provide the `init` tensor for Linalg
structured operations on tensors where the bounds of the computation
depend on the shape of the output of the Linalg operation. The result
of this operation will be used as the `init` tensor of such Linalg
operations. Note that:

1) The values in the materialized tensor are not used. Any operation
   to which this is an init tensor is expected to overwrite the entire
   tensor.
2) The tensor is materialized only for the shape of the output and to
   make the loop bounds depend only on operands of the structured
   operation.

Based on (1) and (2), it is assumed that these operations eventually go
away, since they are only used in `dim` operations that can be
canonicalized to make this operation dead. Such canonicalizations are
added here too.

Differential Revision: https://reviews.llvm.org/D93374

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 454dde1bff93..a2e7d436eeb8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -32,6 +32,91 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
+  let summary = "operation to define a tensor of particular shape";
+
+  let description = [{
+    `linalg.init_tensor` materializes a tensor of a given static or dynamic
+    shape. Only the shape is meaningful; users must overwrite all the values.
+  }];
+
+  let arguments =
+    (ins Variadic<Index>:$sizes, I64ArrayAttr:$static_sizes);
+
+  let results = (outs AnyTensor:$result);
+
+  let verifier = [{ return ::verify(*this); }];
+
+  let extraClassDeclaration = [{
+    static StringRef getStaticSizesAttrName() {
+      return "static_sizes";
+    }
+
+    RankedTensorType getType() {
+      return getResult().getType().cast<RankedTensorType>(); }
+
+    // Infer the type of the result tensor given the static sizes
+    // and element type of the result tensor.
+    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
+
+    // Return true if the size of the tensor is dynamic at `idx`.
+    bool isDynamicSize(unsigned idx) {
+      Attribute size = static_sizes().getValue()[idx];
+      return ShapedType::isDynamic(size.cast<IntegerAttr>().getInt());
+    }
+
+    // Assert that the size of the result tensor is static at `idx`
+    // and return that size.
+    int64_t getStaticSize(unsigned idx) {
+      assert(!isDynamicSize(idx) && "expected static size");
+      // `ArrayRef::operator[]` asserts that `idx` is in bounds.
+      Attribute size = static_sizes().getValue()[idx];
+      return size.cast<IntegerAttr>().getInt();
+    }
+
+    // Return the argument position that contains the dynamic size of
+    // the tensor at dimension `idx`. Asserts that the shape is
+    // dynamic at that `idx`.
+    unsigned getIndexOfDynamicSize(unsigned idx) {
+      assert(isDynamicSize(idx) && "expected dynamic size");
+      return std::count_if(
+          static_sizes().getValue().begin(),
+          static_sizes().getValue().begin() + idx,
+          [&](Attribute attr) {
+            return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
+          });
+    }
+
+    // Return the Value of the dynamic size of the tensor at dimension
+    // `idx`. Asserts that the shape is dynamic at that `idx`.
+    Value getDynamicSize(unsigned idx) {
+      return getOperand(getIndexOfDynamicSize(idx));
+    }
+  }];
+
+  let builders = [
+    OpBuilderDAG<(ins "ValueRange":$shape,
+                  "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
+    [{
+      build($_builder, $_state,
+            InitTensorOp::inferResultType(staticShape, elementType),
+            shape, $_builder.getI64ArrayAttr(staticShape));
+    }]>,
+    OpBuilderDAG<(ins "ValueRange":$shape, "Type":$elementType),
+    [{
+      SmallVector<int64_t, 4> staticShape(
+        shape.size(), ShapedType::kDynamicSize);
+      build($_builder, $_state, shape, staticShape, elementType);
+    }]>,
+    OpBuilderDAG<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
+    [{
+      build($_builder, $_state, ValueRange{}, staticShape, elementType);
+    }]>
+  ];
+
+  let hasCanonicalizer = 1;
+}
+
 def Linalg_RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d62b25cb0e27..d5f44c3e63da 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -35,6 +35,14 @@ LogicalResult verify(OffsetSizeAndStrideOpInterface op);
 #include "mlir/Interfaces/ViewLikeInterface.h.inc"
 
 namespace mlir {
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+                                   ArrayAttr arrayAttr,
+                                   llvm::function_ref<bool(int64_t)> isDynamic);
+
 /// Print part of an op of the form:
 /// ```
 ///   <optional-offset-prefix>`[` offset-list `]`
@@ -48,6 +56,19 @@ void printOffsetsSizesAndStrides(
     ArrayRef<StringRef> elidedAttrs =
         OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
 
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+///   2. `ssa` is filled with "[%arg0, %arg42]".
+ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa);
+
 /// Parse trailing part of an op of the form:
 /// ```
 ///   <optional-offset-prefix>`[` offset-list `]`
@@ -87,6 +108,12 @@ ParseResult parseOffsetsSizesAndStrides(
     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
         nullptr);
 
+/// Verify that `attr` has `expectedNumElements` entries and that `values` has
+/// one element per entry in `attr` for which `isDynamic` evaluates to true.
+LogicalResult verifyListOfOperandsOrIntegers(
+    Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
+    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
+
 } // namespace mlir
 
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7148fe68dd..f2b05448dbd0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -550,6 +550,167 @@ static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
 static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 
+//===----------------------------------------------------------------------===//
+// InitTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the `[<size-list>] attr-dict : <result-type>` part of the custom
+/// form, where each size is an integer or an SSA value of index type.
+static ParseResult parseInitTensorOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  Type dstType;
+  SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
+  IndexType indexType = parser.getBuilder().getIndexType();
+  if (failed(parseListOfOperandsOrIntegers(
+          parser, result, InitTensorOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizeInfo)) ||
+      failed(parser.parseOptionalAttrDict(result.attributes)) ||
+      failed(parser.parseColonType(dstType)) ||
+      failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
+    return failure();
+  return parser.addTypeToList(dstType, result.types);
+}
+
+/// Print the form accepted by `parseInitTensorOp`. `static_sizes` is elided
+/// from the attribute dictionary because the size list already encodes it.
+static void print(OpAsmPrinter &p, InitTensorOp op) {
+  p << op.getOperation()->getName() << ' ';
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p.printOptionalAttrDict(op.getAttrs(),
+                          InitTensorOp::getStaticSizesAttrName());
+  p << " : " << op.getType();
+}
+
+/// Verify that the result is a ranked tensor, that the size list has one
+/// entry per result dimension with one SSA operand per dynamic entry, and that
+/// the result type matches the type inferred from the static sizes.
+static LogicalResult verify(InitTensorOp op) {
+  // `InitTensorOp::getType()` casts to RankedTensorType unconditionally, while
+  // the result constraint also admits unranked tensors; diagnose that here.
+  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+  if (!resultType)
+    return op.emitError("expected result to be a ranked tensor");
+
+  if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
+                                            op.static_sizes(), op.sizes(),
+                                            ShapedType::isDynamic)))
+    return failure();
+
+  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
+      op.static_sizes(),
+      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
+  Type expectedType =
+      InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+  return success();
+}
+
+Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
+                                   Type elementType) {
+  return RankedTensorType::get(staticSizes, elementType);
+}
+
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the result
+/// type statically sized along dimensions that in the original operation were
+/// defined as dynamic, but whose size was defined using a non-negative
+/// `constant` op. A `tensor_cast` restores the original result type. For
+/// example
+///
+///  %c5 = constant 5: index
+///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+///  to
+///
+///  %1 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+///  %0 = tensor_cast %1 : tensor<?x5xf32> to tensor<?x?xf32>
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 4> dynamicSizes;
+    SmallVector<int64_t, 4> staticSizes;
+    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+      // If the size is already static, nothing to do.
+      if (!op.isDynamicSize(i)) {
+        staticSizes.push_back(op.getStaticSize(i));
+        continue;
+      }
+
+      // If the size is dynamic but defined using a `constant` op, use the
+      // constant as the static size. A negative constant is not a valid
+      // static size (and -1 would be re-read as dynamic after its operand has
+      // been dropped), so such sizes stay dynamic.
+      Value sizeOperand = op.getDynamicSize(i);
+      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
+        int64_t constantSize = constantIndexOp.getValue();
+        if (constantSize >= 0) {
+          staticSizes.push_back(constantSize);
+          continue;
+        }
+      }
+
+      // Fallback case. Keep the size dynamic.
+      dynamicSizes.push_back(sizeOperand);
+      staticSizes.push_back(ShapedType::kDynamicSize);
+    }
+    RankedTensorType newType =
+        RankedTensorType::get(staticSizes, op.getType().getElementType());
+    if (newType == op.getType())
+      return failure();
+    auto newOp =
+        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+                                      rewriter.getI64ArrayAttr(staticSizes));
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
+/// with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+///   `linalg.init_tensor` op.
+/// Only `dim` ops whose index is a constant in `[0, rank)` are rewritten.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp)
+      return failure();
+    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+    if (!dimIndex)
+      return failure();
+    int64_t index = dimIndex.getValue();
+    // The size accessors below expect an in-range dimension; leave an
+    // out-of-range `dim` alone rather than reading past `static_sizes`.
+    if (index < 0 || index >= initTensorOp.getType().getRank())
+      return failure();
+    if (!initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+          dimOp, initTensorOp.getStaticSize(index));
+    } else {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+    }
+    return success();
+  }
+};
+} // namespace
+
+/// Register the `linalg.init_tensor` canonicalization patterns defined above.
+void InitTensorOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 6127d08a8fc5..16e44f2d227f 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -17,54 +17,43 @@ using namespace mlir;
 /// Include the definitions of the loop-like interfaces.
 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
 
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
-    OffsetSizeAndStrideOpInterface op, StringRef name,
-    unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
-    llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
+LogicalResult mlir::verifyListOfOperandsOrIntegers(
+    Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
+    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
   /// Check static and dynamic offsets/sizes/strides breakdown.
   if (attr.size() != expectedNumElements)
-    return op.emitError("expected ")
+    return op->emitError("expected ")
            << expectedNumElements << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
         return isDynamic(attr.cast<IntegerAttr>().getInt());
       });
   if (values.size() != expectedNumDynamicEntries)
-    return op.emitError("expected ")
+    return op->emitError("expected ")
            << expectedNumDynamicEntries << " dynamic " << name << " values";
   return success();
 }
 
 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
   std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "offset", ranks[0],
-          OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
-          op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
-          op.offsets())))
+  if (failed(verifyListOfOperandsOrIntegers(
+          op, "offset", ranks[0], op.static_offsets(), op.offsets(),
+          ShapedType::isDynamicStrideOrOffset)))
     return failure();
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "size", ranks[1],
-          OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
-          op.static_sizes(), ShapedType::isDynamic, op.sizes())))
+  if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1],
+                                            op.static_sizes(), op.sizes(),
+                                            ShapedType::isDynamic)))
     return failure();
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "stride", ranks[2],
-          OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
-          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
-          op.strides())))
+  if (failed(verifyListOfOperandsOrIntegers(
+          op, "stride", ranks[2], op.static_strides(), op.strides(),
+          ShapedType::isDynamicStrideOrOffset)))
     return failure();
   return success();
 }
 
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void
-printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
-                              ArrayAttr arrayAttr,
-                              llvm::function_ref<bool(int64_t)> isDynamic) {
+void mlir::printListOfOperandsOrIntegers(
+    OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
+    llvm::function_ref<bool(int64_t)> isDynamic) {
   p << '[';
   unsigned idx = 0;
   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
@@ -95,18 +84,9 @@ void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
 }
 
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-///   2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
-                              StringRef attrName, int64_t dynVal,
-                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+ParseResult mlir::parseListOfOperandsOrIntegers(
+    OpAsmParser &parser, OperationState &result, StringRef attrName,
+    int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
   if (failed(parser.parseLSquare()))
     return failure();
   // 0-D.

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a3d0db64c5e4..96ab1aa93355 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -351,3 +351,42 @@ func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?
                outs(%b : memref<?x?xf32>)
   return
 }
+// -----
+
+func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
+  %c6 = constant 6 : index
+  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+  return %0 : tensor<4x5x?xf32>
+}
+// CHECK-LABEL: func @init_tensor_canonicalize
+//       CHECK:   %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
+//       CHECK:   %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
+//       CHECK:   return %[[T1]]
+
+// -----
+
+func @init_tensor_static_dim() -> (index, index) {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c6 = constant 6 : index
+  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+  %1 = dim %0, %c2 : tensor<4x5x?xf32>
+  %2 = dim %0, %c0 : tensor<4x5x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK-LABEL: func @init_tensor_static_dim
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//   CHECK-DAG:   %[[C6:.+]] = constant 6 : index
+//       CHECK:   return %[[C6]], %[[C4]]
+
+// -----
+
+func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
+  %c2 = constant 2 : index
+  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
+  %1 = dim %0, %c2 : tensor<4x5x?xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @init_tensor_dynamic_dim
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG0]]

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 3ae339b51fdf..be785ceb70d6 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
-// | mlir-opt | FileCheck %s
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
@@ -698,3 +697,42 @@ func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (
 // CHECK-LABEL: func @memref_reshape_zero_dim
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+// -----
+
+func @init_tensor(%arg0 : index, %arg1 : index)
+{
+  %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
+  %1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
+  return
+}
+// CHECK-LABEL: func @init_tensor
+//       CHECK:   linalg.init_tensor [3, 42] : tensor<3x42xf32>
+//       CHECK:   linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
+
+// -----
+
+func @init_tensor_err(%arg0 : index, %arg1 : index)
+{
+  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
+  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 4 sizes values}}
+  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 2 dynamic sizes values}}
+  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
+  return
+}


        


More information about the llvm-branch-commits mailing list