[Mlir-commits] [mlir] 93fd30b - [mlir][Linalg] Evolve named ops to use assembly form and support linalg on tensors.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 18 03:17:37 PDT 2020
Author: Nicolas Vasilache
Date: 2020-09-18T06:14:30-04:00
New Revision: 93fd30bac3345fea4f5beba3241f1ef4f2f5f419
URL: https://github.com/llvm/llvm-project/commit/93fd30bac3345fea4f5beba3241f1ef4f2f5f419
DIFF: https://github.com/llvm/llvm-project/commit/93fd30bac3345fea4f5beba3241f1ef4f2f5f419.diff
LOG: [mlir][Linalg] Evolve named ops to use assembly form and support linalg on tensors.
This revision allows representing a reduction at the level of linalg on tensors for named ops. It adds and documents new conventions for structured ops that have a reduction and return tensor(s).
As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:
```
linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
outs(%c : memref<?x?xf32>)
```
By contrast, the syntax for a `linalg.matmul` returning a new tensor is:
```
%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
init(%c : tensor<?x?xf32>)
-> tensor<?x?xf32>
```
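The difference between the two forms can be modeled in plain C++. This is a minimal sketch, not MLIR code. The `Matrix` type and the `matmulOuts`/`matmulInit` functions are hypothetical, and the sketch assumes `linalg.matmul` accumulates into its output (C += A * B). The buffer form mutates `%c` in place. The tensor form takes the value that is reduced into as an explicit `init` argument and returns a new value of the same type, leaving `init` untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major matrix standing in for both a buffer (memref)
// and a value (tensor); this is not an MLIR type.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float v = 0.0f)
      : rows(r), cols(c), data(r * c, v) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Buffer form, like `outs(%c : memref<...>)`: accumulates C += A * B in place.
void matmulOuts(const Matrix &a, const Matrix &b, Matrix &c) {
  assert(a.cols == b.rows && c.rows == a.rows && c.cols == b.cols);
  for (std::size_t i = 0; i < a.rows; ++i)      // parallel dim
    for (std::size_t j = 0; j < b.cols; ++j)    // parallel dim
      for (std::size_t k = 0; k < a.cols; ++k)  // reduction dim
        c.at(i, j) += a.at(i, k) * b.at(k, j);
}

// Tensor form, like `init(%c : tensor<...>) -> tensor<...>`: the value that
// is reduced into is an explicit argument, `init` is never modified, and a
// new result of the same shape is returned.
Matrix matmulInit(const Matrix &a, const Matrix &b, const Matrix &init) {
  Matrix result = init;      // the result is "tied" to init: same shape
  matmulOuts(a, b, result);  // accumulate into the fresh copy only
  return result;
}
```

Copying `init` into the result is exactly what buffer allocation is expected to fold away. Once the init tensor and the result share a single output buffer, the tensor form reduces to the in-place buffer form.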
Other parts of linalg will be extended accordingly to allow mixed buffer/tensor semantics in the presence of reductions.
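The restrictions on reductions with tensor results are enforced by the new `NamedStructuredOpTrait::verifyTrait` in `LinalgTraits.h`. The reduction dims it checks are computed by the new `getDimsOfType` helper. The sketch below mirrors those checks in standalone C++ on a simplified op description. It is not the MLIR API: `OpDesc`, `dimsOfType` and `verifyNamedOp` are hypothetical names, and types are compared as strings. Each output indexing map is also reduced to the set of loop dims it uses, whereas the real verifier calls `AffineExpr::isFunctionOfDim`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified description of a named structured op (not MLIR).
struct OpDesc {
  std::vector<std::string> iteratorTypes;  // e.g. "parallel", "reduction"
  std::vector<std::string> initTypes;      // types of the `init` operands
  std::vector<std::string> resultTypes;    // types of the op results
  // For each output, the loop dims its indexing map depends on.
  std::vector<std::vector<unsigned>> outputMapDims;
};

// Like getDimsOfType: positions of the loops with the given iterator type.
std::vector<unsigned> dimsOfType(const OpDesc &op, const std::string &name) {
  std::vector<unsigned> res;
  for (unsigned d = 0; d < op.iteratorTypes.size(); ++d)
    if (op.iteratorTypes[d] == name)
      res.push_back(d);
  return res;
}

// Returns the first violated rule as a message, or std::nullopt on success.
std::optional<std::string> verifyNamedOp(const OpDesc &op) {
  std::vector<unsigned> redDims = dimsOfType(op, "reduction");
  // No results or no reduction dims: only `init` must be empty.
  if (redDims.empty() || op.resultTypes.empty()) {
    if (!op.initTypes.empty())
      return "expected empty `init` when op has no results or no reduction dims";
    return std::nullopt;
  }
  if (op.resultTypes.size() != 1)
    return "expected single tensor result when reduction present";
  if (op.initTypes.size() != op.resultTypes.size())
    return "expected #init tensors to match #results when reduction present";
  for (std::size_t i = 0; i < op.resultTypes.size(); ++i)
    if (op.initTypes[i] != op.resultTypes[i])
      return "expected init tensor #" + std::to_string(i) +
             " of the same type as result #" + std::to_string(i);
  // An output indexing map may not depend on any reduction dim.
  for (const std::vector<unsigned> &dims : op.outputMapDims)
    for (unsigned d : dims)
      for (unsigned r : redDims)
        if (d == r)
          return "output indexing map is function of reduction dim @" +
                 std::to_string(r);
  return std::nullopt;
}
```

The checks run in the same order as in `verifyTrait`, so the sketch can be compared against the real verifier rule by rule.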
Added:
Modified:
mlir/docs/Dialects/Linalg.md
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/include/mlir/IR/OpBase.td
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
mlir/test/Dialect/Linalg/affine.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
mlir/test/Dialect/Linalg/fusion-2-level.mlir
mlir/test/Dialect/Linalg/fusion.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/promote.mlir
mlir/test/Dialect/Linalg/promotion_options.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/standard.mlir
mlir/test/Dialect/Linalg/tile-and-distribute.mlir
mlir/test/Dialect/Linalg/tile.mlir
mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/IR/slice.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 7ae1b73f48a7..140197b16815 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -40,7 +40,8 @@ generic properties that enable [key transformations](#key_transformations),
including lowering to scalar load/store and other operations or to external
library calls and intrinsics.
-These ops can have ***either tensor or buffer operands***.
+These ops can have ***either tensor or buffer operands***, subject to
+[conventions and limitations](#tensors_and_buffers).
### Payload-Carrying Ops<a name="payload_ops"></a>
Linalg defines two payload carrying operations that implement the [structured ops](
@@ -463,6 +464,76 @@ because of empirical evidence building and working on multiple high-level
compilers. As we lay those down and engage more with the community, we expect
multiple rounds of discussions and design changes to the original architecture.
+### Tensors and Buffers: Conventions and Limitations <a name="tensors_and_buffers"></a>
+
+Tensors are immutable SSA values; buffers are mutable regions of memory subject
+to side-effects and aliasing. As a consequence, output buffers are passed as
+operands whereas output tensors are new SSA values corresponding to op results.
+Inputs can be arbitrary tensors or buffers and are always passed as operands.
+
+The following convention is currently in-flight and is in the process of
+replacing other existing conventions. It currently applies to "named"
+structured ops, which are auto-generated by the linalg-ods tool.
+
+The convention adopted is as follows:
+
+1. A first block of `ins` op operands holds read-only inputs of ShapedType.
+2. An optional second block of `outs` op operands holds read-write output
+ buffers of MemRefType.
+3. An optional third block of `init` operands holds initialization tensors of
+ RankedTensorType. Such tensors can appear when the op performs a reduction
+ and returns a tensor.
+
+Structured ops with fully parallel semantics have empty `init`. They may either
+write in-place into `outs` buffers or return new tensors.
+
+Structured ops with reduction semantics and output tensor(s), however, have
+additional restrictions:
+
+1. They can only return a single tensor for now.
+2. They cannot have any output buffer operand (i.e. `outs` is empty).
+3. They have exactly one `init` tensor of the same type as the unique output
+ tensor. Such an `init` tensor does not have an explicit associated indexing
+ map. Instead the map of the result tensor is used to signify that the `init`
+ and the `result` are "tied".
+
+Points 1. and 2. keep the complexity of the representation in check by allowing
+only a single result tensor when reductions are present.
+
+Point 3. is related to the fact that SSA values cannot represent in-place
+updates. Instead, linalg adopts a convention similar to the one in e.g.
+`vector.outerproduct`: the value that is reduced into is passed as an explicit
+argument and a new result of the same shape is produced.
+
+It is expected that buffer allocation will fold this last input onto the result in a
+single output buffer argument, which is why the same indexing map is required:
+the last input operand is said to be "tied" to the result.
+
+Alternative, more complex representations would allow for:
+
+1. Multiple results and `init` tensors in arbitrary orders, which could be
+ captured by an extra ArrayAttr of position pairs.
+2. Relaxing the conditions on the indexing map equalities on each pair and
+ e.g. allowing implicit broadcasts of the input.
+
+These representations are deemed unnecessarily complex for now and are left for
+future discussion.
+
+As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:
+
+```
+linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
+ outs(%c : memref<?x?xf32>)
+```
+
+By contrast, the syntax for a `linalg.matmul` returning a new tensor is:
+
+```
+%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+ init(%c : tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+```
+
### Data Representation: Views<a name="views"></a>
The current implementation uses the [Strided MemRef (a.k.a View)](
https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio)
@@ -570,10 +641,10 @@ When `mlir-linalg-ods-gen -gen-ods-decl=1` is called, the following ODS is
produced:
```
- def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
- NInputs<2>,
- NOutputs<1>,
- NamedStructuredOpTraits]> { ... }
+def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
+ NInputs<2>,
+ NOutputs<1>,
+ NamedStructuredOpTrait]> { ... }
```
When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced:
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 1b1a2125e95d..6e4a35035110 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -699,6 +699,14 @@ The available directives are as follows:
- `input` must be either an operand or result [variable](#variables), the
`operands` directive, or the `results` directive.
+* `type_ref` ( input )
+
+ - Represents a reference to the type of the given input that must have
+ already been resolved.
+ - `input` must be either an operand or result [variable](#variables), the
+ `operands` directive, or the `results` directive.
+ - Used to pass previously parsed types to custom directives.
+
#### Literals
A literal is either a keyword or punctuation surrounded by \`\`.
@@ -762,6 +770,10 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `Type &`
- Optional: `Type &`
- Variadic: `SmallVectorImpl<Type> &`
+* TypeRef Directives
+ - Single: `Type`
+ - Optional: `Type`
+ - Variadic: `const SmallVectorImpl<Type> &`
When a variable is optional, the value should only be specified if the variable
is present. Otherwise, the value should remain `None` or null.
@@ -788,6 +800,10 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Type`
- Optional: `Type`
- Variadic: `TypeRange`
+* TypeRef Directives
+ - Single: `Type`
+ - Optional: `Type`
+ - Variadic: `TypeRange`
When a variable is optional, the provided value may be null.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 09fc11bc4917..a35964c5eab4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -85,6 +85,11 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
+/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
+/// Assumes `op` is a LinalgOp.
+void getDimsOfType(Operation *op, StringRef iteratorTypeName,
+ SmallVectorImpl<AffineExpr> &res);
+
} // namespace linalg
} // namespace mlir
@@ -96,5 +101,4 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
-
#endif // MLIR_DIALECT_LINALG_LINALGOPS_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 41beab059008..a6c0d16a9ee2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -32,6 +32,7 @@ class NOutputs<int args_out> :
NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
+def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on linalg::View as their
@@ -798,24 +799,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
-class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
- : LinalgStructuredBase_Op<mnemonic, props> {
- string spec = ?;
- // We cannot use an assemblyFormat atm because we need to hook in a custom-
- // built implicit region from a static OpClass method.
- // TODO: Revisit in the future if/when appropriate.
- // let assemblyFormat = "`(` operands `)` attr-dict `:` "
- // "functional-type(operands, results)";
-
- // The parser needs to specialize on the OpType so it has to be auto-generated
- // in the linalg-ods tool.
- let printer = [{ return ::printNamedStructuredOp(p, *this); }];
- let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
- let hasFolder = 1;
- let hasCanonicalizer = 1;
-}
-
-// This file is auto-generated from a tc specification.
+// This file is auto-generated from a TC def specification.
include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
#endif // LINALG_STRUCTURED_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 0e8216cc4268..1e0e85f82c7f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -25,7 +25,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return the number of parallel loops within the current operation.
+ Return the number of parallel loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumParallelLoops",
@@ -38,7 +38,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of reduction loops within the current operation.
+ Return the dims that are parallel loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getParallelDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType($_op, getParallelIteratorTypeName(), res);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of reduction loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumReductionLoops",
@@ -51,7 +63,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of window loops within the current operation.
+ Return the dims that are reduction loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getReductionDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType($_op, getReductionIteratorTypeName(), res);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of window loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumWindowLoops",
@@ -62,6 +86,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
$_op.iterator_types());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that are window loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getWindowDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the total number of loops within the current operation.
@@ -99,14 +135,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// linalg.indexed_generic ops).
InterfaceMethod<
/*desc=*/[{
- Return the number of inputs from the current operation.
+ Return the number of inputs.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumInputs"
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of outputs from the current operation.
+ Return the number of outputs.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumOutputs"
@@ -160,7 +196,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the input operands from the current operation.
+ Return the input operands.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getInputs",
@@ -187,7 +223,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return res;
}]
>,
-
//===------------------------------------------------------------------===//
// Output arguments handling.
//===------------------------------------------------------------------===//
@@ -267,7 +302,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}]>,
InterfaceMethod<
/*desc=*/[{
- Return the output buffers (operands) from the current operation.
+ Return the output buffers (operands).
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getOutputBuffers",
@@ -354,7 +389,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return getInputShapedType(i);
if (i < getNumInputsAndOutputBuffers())
return getOutputBufferType(i - $_op.getNumInputs());
- return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+ return this->getOperation()->getResult(
+ i - getNumInputsAndOutputBuffers()).
+ getType().template cast<ShapedType>();
}]>,
InterfaceMethod<
/*desc=*/[{
@@ -408,11 +445,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::to_vector<4>(
- llvm::map_range($_op.indexing_maps(),
- [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
+ return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
}]
>,
InterfaceMethod<
@@ -425,10 +458,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < getNumInputsAndOutputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -441,10 +471,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumInputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -457,10 +484,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumOutputs());
- return $_op.indexing_maps()
- .getValue()[i + $_op.getNumInputs()]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i + $_op.getNumInputs()];
}]
>,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index c4790ca617f1..ae56dd66f57a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -71,6 +71,80 @@ class StructuredOpTraits
}
};
+/// This class provides a verifier for structured ops that are known to operate
+/// on buffers or tensors and that support `ins`, `outs` and `init` arguments.
+/// This trait must be used in conjunction with an op definition or a trait that
+/// provides the methods `getNumInputs` and `getNumOutputs`.
+///
+/// Use as a trait as follows:
+///
+/// class MatmulOp : public Op<MatmulOp, OpTrait::linalg::NamedStructuredOpTrait> {
+///
+template <typename ConcreteType>
+class NamedStructuredOpTrait
+ : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTrait> {
+public:
+ unsigned getNumInputs() {
+ return cast<ConcreteType>(this->getOperation()).inputs().size();
+ }
+ unsigned getNumOutputs() {
+ ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
+ return concreteOp.output_buffers().size() +
+ concreteOp.output_tensors().size();
+ }
+ static LogicalResult verifyTrait(Operation *op) {
+ ConcreteType concreteOp = cast<ConcreteType>(op);
+ unsigned nInputAndBufferOperands =
+ concreteOp.getNumInputsAndOutputBuffers();
+ if (failed(
+ OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
+ return failure();
+
+ SmallVector<AffineExpr, 4> redDims;
+ concreteOp.getReductionDims(redDims);
+ // If there are no results or no reduction dims, only check that there is
+ // no init tensor and we are done.
+ if (redDims.empty() || op->getNumResults() == 0) {
+ if (!concreteOp.init_tensors().empty())
+ return op->emitError("expected empty `init` when op has no "
+ "results or no reduction dims");
+ return success();
+ }
+
+ // Only a single tensor result supported atm.
+ if (op->getNumResults() != 1)
+ return op->emitError(
+ "expected single tensor result when reduction present");
+
+ if (concreteOp.init_tensors().size() != op->getNumResults())
+ return op->emitError(
+ "expected #init tensors to match #results when reduction present");
+
+ for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx)
+ if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx])
+ return op->emitError("expected init tensor #")
+ << idx << " of the same type as result #" << idx;
+
+ // Output tensor indexing map may not depend on reduction index.
+ // TODO: this is not yet tested. Add a test when linalg.generic switches to
+ // this representation.
+ for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) {
+ AffineMap outputMap = concreteOp.getOutputIndexingMap(idx);
+ for (auto expr : outputMap.getResults()) {
+ for (auto dim : redDims) {
+ unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+ if (expr.isFunctionOfDim(pos))
+ return op->emitError(
+ "unexpected single tensor output indexing map ")
+ << "is function of reduction dim @" << pos;
+ }
+ }
+ }
+
+ return success();
+ }
+};
+
} // namespace linalg
} // namespace OpTrait
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index b038819bca3d..c9103a2b8b63 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -15,8 +15,6 @@
include "mlir/IR/OpBase.td"
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
//===----------------------------------------------------------------------===//
// Shape Inference dialect definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ec0e229ae627..d314393caae2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -571,6 +571,10 @@ class VectorOfLengthAndType<list<int> allowedLengths,
def AnyVector : VectorOf<[AnyType]>;
+// Shaped types.
+
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
index 8f3c6df79f90..97ea95c8bcd1 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_1d_filled_f32(%s1 : index, %f : f32) -> memref<?xf32> {
}
func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- linalg.conv_1d %arg0, %arg1, %arg2 : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+ linalg.conv_1d ins (%arg0, %arg1: memref<?xf32>, memref<?xf32>)
+ outs (%arg2: memref<?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
index 46634a7e5921..dcfcc9b62bbc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
}
func @conv_1d_ncw(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_1d_ncw ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
index a6aeb30fc153..2e79b46801bc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
}
func @conv_1d_nwc(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_1d_nwc ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
index 819d95ef5da0..e271b0a009b6 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_2d_filled_f32(%s1 : index, %s2 : index, %f : f32) -> memref<?x?xf32>
}
func @conv_2d(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
- linalg.conv_2d %arg0, %arg1, %arg2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.conv_2d ins (%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+ outs (%arg2: memref<?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
index fb0e70861864..e27c40524fcc 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
}
func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
index 5888eec7d67a..b5b4a5c82c09 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
}
func @conv_2d_nhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ linalg.conv_2d_nhwc ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
index f0ca37f86fcd..12ea94696660 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
@@ -30,7 +30,8 @@ func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> me
}
func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_3d %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_3d ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
index a56a260b9cd8..e36abc83b700 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
@@ -30,7 +30,8 @@ func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s
}
func @conv_3d_ncdhw(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
- linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ linalg.conv_3d_ncdhw ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
index 37fc6453e5dd..b302b3e0d8bd 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
@@ -30,7 +30,8 @@ func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s
}
func @conv_3d_ndhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
- linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ linalg.conv_3d_ndhwc ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?x?xf32>)
return
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index efe2e45f78ea..7b9ba74f5492 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -35,13 +36,29 @@ using namespace mlir::linalg;
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(
- Builder &builder, OperationState &result, TypeRange operandTypes,
- TypeRange tensorResultTypes);
+ OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
+ TypeRange outputBufferTypes, TypeRange initTensorTypes,
+ TypeRange resultTypes);
+
template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+ TypeRange inputTypes, TypeRange outputBufferTypes,
+ TypeRange initTensorTypes, TypeRange resultTypes);
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes);
+
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
+
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+ TypeRange resultTypes);
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+
template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
@@ -248,11 +265,6 @@ template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
- if (nInputsAndOutputBuffers != llvm::size(op.views()))
- return op.emitOpError("expected exactly ")
- << nInputsAndOutputBuffers
- << " inputs (tensor or buffer) and output buffer operands";
auto &region = op.region();
if (!llvm::hasSingleElement(region))
@@ -302,8 +314,27 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
return success();
}
-static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
-static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
+static LogicalResult verify(GenericOp op) {
+ // Temporarily hoisted here to avoid duplicating more code.
+ // TODO: uniformize with named structured ops.
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
+ return verifyGenericOp(op);
+}
+
+static LogicalResult verify(IndexedGenericOp op) {
+ // Temporarily hoisted here to avoid duplicating more code.
+ // TODO: uniformize with named structured ops.
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
+ return verifyGenericOp(op);
+}
//===----------------------------------------------------------------------===//
// ReshapeOp
@@ -1098,12 +1129,28 @@ static LogicalResult verify(PoolingSumOp op) {
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
+/// Assumes `op` is a LinalgOp.
+void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
+ SmallVectorImpl<AffineExpr> &res) {
+ unsigned dim = 0;
+ MLIRContext *ctx = op->getContext();
+ for (auto tn :
+ cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
+ if (tn == iteratorTypeName)
+ res.push_back(getAffineDimExpr(dim, ctx));
+ ++dim;
+ }
+}
+
AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
unsigned rank,
MLIRContext *context) {
@@ -1196,8 +1243,8 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
}
// TODO: Consider making all this boilerplate easy to autogenerate
-// with Tablegen. This seems a desirable property in the context of OpInterfaces
-// where a Linalg "named" op **isa** LinalgOp.
+// with Tablegen. This seems a desirable property in the context of
+// OpInterfaces where a Linalg "named" op **isa** LinalgOp.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
@@ -1222,23 +1269,28 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
//===----------------------------------------------------------------------===//
template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
- OperationState &result,
- TypeRange operandTypes,
- TypeRange tensorResultTypes) {
- Region &region = *result.addRegion();
- Block *body = new Block();
+static void buildNamedStructuredOpRegionAndAttributesImpl(
+ OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
+ TypeRange outputBufferTypes, TypeRange initTensorTypes,
+ TypeRange resultTypes,
+ std::function<void(unsigned, unsigned)> errorHandler) {
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
- for (auto t : operandTypes)
- body->addArgument(getElementTypeOrSelf(t));
- for (auto t : tensorResultTypes)
- body->addArgument(getElementTypeOrSelf(t));
- region.push_back(body);
-
- OpBuilder opBuilder(builder.getContext());
- opBuilder.setInsertionPointToStart(&region.front());
- mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
+ SmallVector<Type, 8> argTypes;
+ for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
+ for (auto t : containers)
+ argTypes.push_back(getElementTypeOrSelf(t));
+
+ // RAII.
+ OpBuilder::InsertionGuard guard(opBuilder);
+ Block *body = opBuilder.createBlock(&region, {}, argTypes);
+ unsigned actual = body->getNumArguments();
+ unsigned expected = NamedStructuredOpType::getNumRegionArgs();
+ if (expected != actual)
+ return errorHandler(expected, actual);
+
+ opBuilder.setInsertionPointToStart(body);
+ mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body);
// indexing_maps is an auto-generated method.
@@ -1247,59 +1299,133 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
}
template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
- std::array<StringRef, 2> silentAttrNames{getIndexingMapsAttrName(),
- getIteratorTypesAttrName()};
- p << op.getOperationName() << ' ';
- p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
- p << ' ' << op.getOperands();
- p << " : (" << op.getOperandTypes() << ")";
- auto outputTensorTypes = op.getResultTypes();
- if (!outputTensorTypes.empty())
- p << " -> (" << outputTensorTypes << ")";
+void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputBufferTypes,
+ TypeRange initTensorTypes,
+ TypeRange resultTypes) {
+ Region &region = *result.addRegion();
+ buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+ resultTypes, [&](unsigned expected, unsigned actual) {
+ llvm::errs() << "region expects " << expected << " args, got "
+ << actual;
+ assert(expected != actual && "incorrect number of arguments");
+ });
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+ TypeRange inputTypes, TypeRange outputBufferTypes,
+ TypeRange initTensorTypes, TypeRange resultTypes) {
+ ParseResult res = success();
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+ opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+ resultTypes, [&](unsigned expected, unsigned actual) {
+ res = parser.emitError(parser.getCurrentLocation(),
+ llvm::formatv("region expects {0} args, got {1}",
+ expected, actual));
+ });
+ return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes) {
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(resultTypes))
+ return failure();
+ return success();
}
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
- result.getContext()->getOrLoadDialect<StandardOpsDialect>();
+ llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
+ initTensorsOperandsLoc;
+ SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
+ outputBuffersOperands, initTensorsOperands;
+ SmallVector<Type, 1> inputsTypes, outputBuffersTypes, initTensorsTypes,
+ outputTensorsTypes;
+ std::unique_ptr<Region> regionRegion = std::make_unique<Region>();
- // Optional attributes may be added.
- if (parser.parseOperandList(operandsInfo) ||
- parser.parseOptionalAttrDict(result.attributes))
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseKeyword("ins") || parser.parseLParen())
return failure();
- SmallVector<Type, 8> operandTypes;
- if (parser.parseColon() || parser.parseLParen() ||
- parser.parseTypeList(operandTypes) || parser.parseRParen())
+ inputsOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(inputsOperands) || parser.parseColon() ||
+ parser.parseTypeList(inputsTypes) || parser.parseRParen())
return failure();
- // Generic ops may specify that a subset of its outputs are tensors. Such
- // outputs are specified in the result type.
- SmallVector<Type, 8> tensorResultTypes;
- if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+ if (succeeded(parser.parseOptionalKeyword("outs"))) {
+ outputBuffersOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseLParen() ||
+ parser.parseOperandList(outputBuffersOperands) || parser.parseColon() ||
+ parser.parseTypeList(outputBuffersTypes) || parser.parseRParen())
+ return failure();
+ }
+ if (succeeded(parser.parseOptionalKeyword("init"))) {
+ initTensorsOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
+ parser.parseColon() || parser.parseTypeList(initTensorsTypes) ||
+ parser.parseRParen())
+ return failure();
+ }
+
+ if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
return failure();
- if (!tensorResultTypes.empty())
- result.addTypes(tensorResultTypes);
+ if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
+ parser, *regionRegion, inputsTypes, outputBuffersTypes,
+ initTensorsTypes, outputTensorsTypes))
+ return failure();
- // The number of parsed arguments must equal
- // the number of expected arguments for the current operation.
- auto parsedArgs = operandsInfo.size();
- auto expectedArgs = NamedStructuredOpType::getNumInputs() +
- NamedStructuredOpType::getNumOutputs();
- if (parsedArgs != expectedArgs)
- return parser.emitError(parser.getNameLoc(),
- "expects " + std::to_string(expectedArgs) +
- " operands, but found " +
- std::to_string(parsedArgs));
+ if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
+ result.operands) ||
+ parser.resolveOperands(outputBuffersOperands, outputBuffersTypes,
+ outputBuffersOperandsLoc, result.operands) ||
+ parser.resolveOperands(initTensorsOperands, initTensorsTypes,
+ initTensorsOperandsLoc, result.operands))
+ return failure();
- buildNamedStructuredOpRegionAndAttributes<NamedStructuredOpType>(
- parser.getBuilder(), result, operandTypes, tensorResultTypes);
+ result.addTypes(outputTensorsTypes);
+ result.addRegion(std::move(regionRegion));
+ result.addAttribute("operand_segment_sizes",
+ parser.getBuilder().getI32VectorAttr(
+ {static_cast<int32_t>(inputsOperands.size()),
+ static_cast<int32_t>(outputBuffersOperands.size()),
+ static_cast<int32_t>(initTensorsOperands.size())}));
+ return success();
+}
- return parser.resolveOperands(operandsInfo, operandTypes,
- parser.getCurrentLocation(), result.operands);
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+ TypeRange resultTypes) {
+ if (resultTypes.empty())
+ return;
+ p << "-> " << resultTypes;
+}
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+ p << op.getOperationName();
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{"operand_segment_sizes"});
+ p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+ if (!op.output_buffers().empty())
+ p << " outs(" << op.output_buffers() << " : "
+ << op.output_buffers().getTypes() << ")";
+ if (!op.init_tensors().empty())
+ p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
+ << ")";
+ p << " ";
+ printNamedStructuredOpResults(p, op.output_tensors().getTypes());
+ p << " ";
+
+ // Region is elided.
}
template <typename NamedStructuredOpType>
@@ -1354,8 +1480,6 @@ CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
-
// TODO: Determine whether we can generate the folders and verifiers.
CANONICALIZERS_AND_FOLDERS(BatchMatmulOp)
CANONICALIZERS_AND_FOLDERS(DotOp)
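For readers following the new `getDimsOfType` helper: it walks `iterator_types` in loop order and collects the position of every loop whose iterator name matches, e.g. the reduction dims of a matmul. A minimal standalone sketch of that filtering, assuming plain strings in place of `StringAttr` and unsigned indices in place of `AffineDimExpr` (`dimsOfType` is a hypothetical name, not MLIR API):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Sketch of the filtering done by mlir::linalg::getDimsOfType: walk the
// iterator types in loop order and record the position of every loop whose
// iterator type name matches `iteratorTypeName`.
// Simplification (assumption): strings stand in for StringAttr and unsigned
// positions stand in for AffineDimExpr.
std::vector<unsigned> dimsOfType(const std::vector<std::string> &iteratorTypes,
                                 const std::string &iteratorTypeName) {
  std::vector<unsigned> res;
  unsigned dim = 0;
  for (const std::string &tn : iteratorTypes) {
    if (tn == iteratorTypeName)
      res.push_back(dim);
    ++dim; // Every loop advances the dimension counter, matched or not.
  }
  return res;
}
```

For matmul-style iterator types `{"parallel", "parallel", "reduction"}`, asking for `"reduction"` yields `{2}` and asking for `"parallel"` yields `{0, 1}`.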
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index abc82f300f63..4dfc3d605570 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,6 +58,8 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void mlir::linalg::LinalgDialect::initialize() {
+ getContext()->getOrLoadDialect("std");
+
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
@@ -67,6 +69,7 @@ void mlir::linalg::LinalgDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
+
addInterfaces<LinalgInlinerInterface>();
}
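The region-building code in LinalgOps.cpp is now shared between the builder and the parser: the shared helper receives an `errorHandler` callback that is given the expected and actual region-argument counts. A sketch of that pattern in plain C++, with hypothetical names (`buildRegionImpl`, `parseRegion`, `buildRegion`) and integer counts standing in for the real `Region`/`OpBuilder` plumbing:

```cpp
#include <cassert>
#include <cstdio>
#include <functional>
#include <string>

// Hypothetical stand-in for the shared region-building helper: it compares
// the number of region arguments it would create against the count the op
// expects and, on mismatch, defers to the caller-supplied handler instead of
// deciding itself how to report the problem. Returns true if the region
// would be built.
bool buildRegionImpl(unsigned actual, unsigned expected,
                     const std::function<void(unsigned, unsigned)> &errorHandler) {
  if (expected != actual) {
    errorHandler(expected, actual);
    return false;
  }
  // ... the real helper would populate the region body here ...
  return true;
}

// Parser-style caller: turn the mismatch into a recoverable diagnostic
// (the diff uses parser.emitError with the same message format).
bool parseRegion(unsigned actual, unsigned expected, std::string &diag) {
  bool ok = true;
  buildRegionImpl(actual, expected, [&](unsigned e, unsigned a) {
    diag = "region expects " + std::to_string(e) + " args, got " +
           std::to_string(a);
    ok = false;
  });
  return ok;
}

// Builder-style caller: a mismatch is a programmer error, so print it and
// abort in debug builds.
void buildRegion(unsigned actual, unsigned expected) {
  buildRegionImpl(actual, expected, [](unsigned e, unsigned a) {
    std::fprintf(stderr, "region expects %u args, got %u\n", e, a);
    assert(false && "incorrect number of arguments");
  });
}
```

The shared helper only detects the mismatch; the parser reports it as a recoverable error while the builder treats it as a bug in the calling code.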
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
index c2e8a31eb443..eeb2ca31fd2a 100644
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -8,7 +8,8 @@
// CHECK-DAG: #[[$map5:.*]] = affine_map<(d0) -> (d0)>
func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- linalg.conv_1d %arg0, %arg1, %arg2 : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+ linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+ outs(%arg2 : memref<?xf32>)
return
}
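The parser now resolves `ins`, `outs` and `init` into one flat operand list and records the group sizes in `operand_segment_sizes`, so accessors such as `inputs()`, `output_buffers()` and `init_tensors()` can recover each group. A sketch of that bookkeeping, assuming strings for SSA values; `segmentSizes` and `sliceSegment` are illustrative helpers, not the generated accessors:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Strings stand in for SSA values (assumption for illustration).
using Operands = std::vector<std::string>;

// Sizes in the same order the parser records them: {ins, outs, init}.
std::array<int32_t, 3> segmentSizes(const Operands &ins, const Operands &outs,
                                    const Operands &init) {
  return {static_cast<int32_t>(ins.size()), static_cast<int32_t>(outs.size()),
          static_cast<int32_t>(init.size())};
}

// Return the operands of segment `idx` by skipping all preceding segments
// in the flat operand list.
Operands sliceSegment(const Operands &flat, const std::array<int32_t, 3> &sizes,
                      std::size_t idx) {
  assert(idx < sizes.size() && "segment index out of range");
  std::size_t first = 0;
  for (std::size_t i = 0; i < idx; ++i)
    first += static_cast<std::size_t>(sizes[i]);
  std::size_t last = first + static_cast<std::size_t>(sizes[idx]);
  assert(last <= flat.size() && "segment sizes exceed operand count");
  return Operands(flat.begin() + first, flat.begin() + last);
}
```

For `ins(%a, %b : ...) outs(%c : ...)` the recorded sizes are `{2, 1, 0}`, and slicing segment 1 of the flat list `{%a, %b, %c}` gives back `{%c}`.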
diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index 13f37d844b8a..0df7db06e4c4 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -15,7 +15,8 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
%B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
%C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
- linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
@@ -102,7 +103,8 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
// Named ops to loops.
//----------------------------------------------------------------------------//
func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
- linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C : memref<?x?x?xf32>)
return
}
// CHECK-LABEL: @named_batch_matmul
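The new printer always emits `ins(...)`, emits `outs(...)` and `init(...)` only when they are non-empty, and appends `-> types` only when the op returns tensors. A plain-string model of that layout (the `Group` struct and `printNamedOp` are hypothetical; attributes, the elided region and exact whitespace are not modeled):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// One operand group: SSA names and their matching types.
struct Group {
  std::vector<std::string> values; // e.g. "%a"
  std::vector<std::string> types;  // e.g. "memref<?x?xf32>"
};

static std::string join(const std::vector<std::string> &xs) {
  std::string out;
  for (std::size_t i = 0; i < xs.size(); ++i)
    out += (i ? ", " : "") + xs[i];
  return out;
}

// Prints " kw(v0, v1 : t0, t1)".
static std::string printGroup(const std::string &kw, const Group &g) {
  return " " + kw + "(" + join(g.values) + " : " + join(g.types) + ")";
}

// `ins` is mandatory; `outs`, `init` and the result arrow are optional.
std::string printNamedOp(const std::string &name, const Group &ins,
                         const Group &outs, const Group &init,
                         const std::vector<std::string> &resultTypes) {
  std::string s = name + printGroup("ins", ins);
  if (!outs.values.empty())
    s += printGroup("outs", outs);
  if (!init.values.empty())
    s += printGroup("init", init);
  if (!resultTypes.empty())
    s += " -> " + join(resultTypes);
  return s;
}
```

Printing a buffer matmul yields the `ins(...) outs(...)` form used throughout the updated tests; printing a tensor matmul with an `init` operand and one result yields `ins(...) init(...) -> tensor<?x?xf32>`.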
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 85321084cd0c..23e9f1784541 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -14,8 +14,9 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
// CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
%4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
- // CHECK: linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>)
- linalg.matmul %3, %3, %3 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
+ linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%3: memref<?x?xf32>)
return %4: memref<?x?xf32>
}
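When the region is built, its block arguments are the element types of the inputs, the output buffers and the results, in that order; `initTensorTypes` is accepted but contributes no block arguments of its own. A sketch of that flattening, assuming a simplified string-based `elementTypeOrSelf` (hypothetical, handles only scalar element types) in place of `getElementTypeOrSelf`:

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <vector>

using Types = std::vector<std::string>;

// A shape dimension is either "?" or a non-empty run of digits.
static bool isDimToken(const std::string &s) {
  if (s == "?")
    return true;
  if (s.empty())
    return false;
  for (char c : s)
    if (c < '0' || c > '9')
      return false;
  return true;
}

// Hypothetical stand-in for getElementTypeOrSelf on type strings: for
// "memref<?x?xf32, offset: ...>" or "tensor<4x8xi8>", strip the leading
// "<dim>x" tokens and return what remains; non-shaped types are returned
// unchanged. Assumes scalar element types (no nested vector<...>).
std::string elementTypeOrSelf(const std::string &t) {
  std::size_t open = t.find('<');
  if (open == std::string::npos)
    return t;
  // Stop at the first ',' (start of a layout) or the closing '>'.
  std::size_t stop = t.find_first_of(",>", open);
  std::string shape = t.substr(open + 1, stop - open - 1);
  std::size_t pos = 0;
  for (;;) {
    std::size_t x = shape.find('x', pos);
    if (x == std::string::npos || !isDimToken(shape.substr(pos, x - pos)))
      break;
    pos = x + 1;
  }
  return shape.substr(pos);
}

// Region block argument types: inputs, then output buffers, then results.
// Init tensors are deliberately skipped, mirroring the loop over
// {inputTypes, outputBufferTypes, resultTypes} in the diff.
Types regionArgTypes(const Types &inputs, const Types &outputBuffers,
                     const Types &initTensors, const Types &results) {
  (void)initTensors; // Accepted but unused, as in the diff.
  Types args;
  for (const Types *group : {&inputs, &outputBuffers, &results})
    for (const std::string &t : *group)
      args.push_back(elementTypeOrSelf(t));
  return args;
}
```

For a tensor matmul with one `init` operand and one result this yields three arguments, the same count as the buffer form with one output buffer, which is what lets both forms pass the `getNumRegionArgs()` check.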
diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
index 84c56ee3d840..72d76b3d1869 100644
--- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns
-//| FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
// CHECK-LABEL: scf_for
func @scf_for(%A : memref<i64>, %step : index) {
diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 44dd268998d2..0c9b9ca0dca7 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -12,7 +12,8 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
%0 = dim %C, %c0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%1 = dim %C, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%2 = dim %D, %c1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
- linalg.matmul %A, %B, %C : (memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
scf.for %arg5 = %c0 to %0 step %c20 {
scf.for %arg6 = %c0 to %2 step %c30 {
scf.for %arg7 = %c0 to %1 step %c40 {
@@ -28,7 +29,8 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
%14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %14, %16, %17 : (memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%14, %16: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%17: memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
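The test updates in this patch are a mechanical rewrite of `OP %a, %b, %c : (T1, T2, T3)` into `OP ins(%a, %b : T1, T2) outs(%c : T3)`. A hypothetical single-line migration helper, assuming the last operand is the only output buffer, no leading indentation and no trailing `-> (...)`; the key detail is splitting on top-level commas only, since strided memref types contain commas of their own:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Split on commas that are not nested inside <...>, [...] or (...), so
// "memref<?x?xf32, offset: ?, strides: [?, 1]>" stays in one piece.
// Each piece is trimmed of surrounding spaces.
std::vector<std::string> splitTopLevel(const std::string &s) {
  std::vector<std::string> parts;
  std::string cur;
  int depth = 0;
  auto flush = [&] {
    std::size_t b = cur.find_first_not_of(' ');
    std::size_t e = cur.find_last_not_of(' ');
    parts.push_back(b == std::string::npos ? std::string()
                                           : cur.substr(b, e - b + 1));
    cur.clear();
  };
  for (char c : s) {
    if (c == '<' || c == '[' || c == '(')
      ++depth;
    if (c == '>' || c == ']' || c == ')')
      --depth;
    if (c == ',' && depth == 0) {
      flush();
      continue;
    }
    cur += c;
  }
  flush();
  return parts;
}

// Rewrite "OP %a, %b, %c : (T1, T2, T3)" into
// "OP ins(%a, %b : T1, T2) outs(%c : T3)". Lines that do not match the
// assumed shape are returned unchanged.
std::string migrate(const std::string &line) {
  std::size_t sp = line.find(' ');
  std::size_t colon = line.find(" : (");
  if (sp == std::string::npos || colon == std::string::npos ||
      line.back() != ')')
    return line;
  std::string op = line.substr(0, sp);
  std::vector<std::string> vals =
      splitTopLevel(line.substr(sp + 1, colon - sp - 1));
  std::size_t tbeg = colon + 4; // first character after " : ("
  std::vector<std::string> types =
      splitTopLevel(line.substr(tbeg, line.size() - 1 - tbeg));
  if (vals.size() != types.size() || vals.size() < 2)
    return line;
  auto join = [](const std::vector<std::string> &xs, std::size_t b,
                 std::size_t e) {
    std::string out;
    for (std::size_t i = b; i < e; ++i)
      out += (i > b ? ", " : "") + xs[i];
    return out;
  };
  std::size_t n = vals.size();
  return op + " ins(" + join(vals, 0, n - 1) + " : " + join(types, 0, n - 1) +
         ") outs(" + vals[n - 1] + " : " + types[n - 1] + ")";
}
```

Splitting on every comma would break the strided layouts used in `fusion.mlir` and `fusion-2-level.mlir`; tracking bracket depth keeps each type whole.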
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 7fb5a7ab4e85..38e1b43be668 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -14,10 +14,9 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
%0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, 1]>
%1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
%2 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, 1]>,
- memref<?x?xf32, offset: 0, strides: [?, 1]>,
- memref<?x?xf32, offset: 0, strides: [?, 1]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, 1]>,
+ memref<?x?xf32, offset: 0, strides: [?, 1]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, 1]>)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
@@ -30,10 +29,9 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
%8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, 1]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8: memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -61,10 +59,9 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C: memref<?x?xf32, offset: 0, strides: [?, ?]>)
%0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -80,10 +77,9 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -113,10 +109,9 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
%0 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -132,10 +127,9 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -165,14 +159,12 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
- linalg.matmul %A, %B, %D :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
%0 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
@@ -188,10 +180,9 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -227,14 +218,12 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%0 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %D, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
- linalg.matmul %C, %B, %D :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%C, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
scf.for %arg5 = %c0 to %1 step %c2 {
scf.for %arg6 = %c0 to %0 step %c3 {
scf.for %arg7 = %c0 to %2 step %c4 {
@@ -247,10 +236,9 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -275,9 +263,9 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
// CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK: linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]]
-// CHECK: linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]]
-// CHECK: linalg.matmul %[[D_IK]], %[[B_KJ]], %[[E_IJ]]
+// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
+// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
+// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
@@ -297,14 +285,12 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c3 = constant 3 : index
%c2 = constant 2 : index
%0 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
- linalg.matmul %A, %C, %E :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%E : memref<?x?xf32, offset: 0, strides: [?, ?]>)
%1 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
scf.for %arg5 = %c0 to %1 step %c2 {
@@ -322,10 +308,9 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -359,14 +344,12 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%2 = dim %C, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%3 = dim %C, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%4 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul %A, %C, %E :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%E : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
@@ -379,10 +362,9 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %7, %9, %10 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%7, %9 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%10 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -398,10 +380,9 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %7, %9, %10 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%7, %9 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%10 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -414,7 +395,7 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: linalg.matmul %[[A]], %[[C]], %[[E]]
+// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
@@ -445,14 +426,12 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c2 = constant 2 : index
%0 = dim %A, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %A, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul %A, %C, %D :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %C : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%D : memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
%2 = dim %D, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
@@ -469,10 +448,9 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %5, %7, %8 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%5, %7 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%8 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -742,10 +720,9 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
%B = alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%C = alloc(%dim, %dim)[%s0, %s1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul %A, %B, %C :
- (memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%C : memref<?x?xf32, offset: 0, strides: [?, ?]>)
scf.for %i = %c0 to %dim step %c2 {
scf.for %j = %c0 to %dim step %c3 {
@@ -759,10 +736,9 @@ func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
%2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %0, %1, %2 :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul ins(%0, %1 : memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%2 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 3774aed7ad1f..dce5c21db252 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,13 +428,6 @@ func @generic(%arg0: memref<?x?xi4>) {
// -----
-func @generic_result_0_element_type(%arg0: memref<?xf32>) {
- // expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}}
- linalg.dot %arg0, %arg0 : (memref<?xf32>, memref<?xf32>)
-}
-
-// -----
-
func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
// expected-error @+1 {{expects memref ranks to be greater than 2}}
linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
@@ -511,7 +504,8 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
// expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref<?x?xf32>'}}
- linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
+ linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
+ outs(%c3 : memref<?x?x?xf32>)
return
}
@@ -531,3 +525,52 @@ func @generic(%arg0: tensor<?x?xi4>) {
} : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
return
}
+
+// -----
+
+func @empty_init_expected(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+ // expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}}
+ linalg.matmul ins(%m, %m: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%m : memref<?x?xf32>)
+ init(%t : tensor<?x?xf32>)
+ return
+}
+
+// -----
+
+func @incorrect_region_arg_count(%m: memref<?x?xf32>) {
+ // expected-error @+3 {{region expects 3 args, got 4}}
+ %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+ -> tensor<?x?xf32>, tensor<?x?xf32>
+ return
+}
+
+// -----
+
+func @single_tensor_result(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+ // expected-error @+1 {{expected single tensor result when reduction present}}
+ %res:2 = linalg.matmul ins(%m : memref<?x?xf32>)
+ init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
+ -> tensor<?x?xf32>, tensor<?x?xf32>
+ return
+}
+
+// -----
+
+func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+ // expected-error @+1 {{expected #init tensors to match #results when reduction present}}
+ %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+ init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return
+}
+
+// -----
+
+func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
+ // expected-error @+1 {{expected init tensor #0 of the same type as result #0}}
+ %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
+ init(%t : tensor<?x?xf32>)
+ -> tensor<?xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 1e10e036ee2d..04ca27b8e175 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -39,7 +39,8 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
%B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
%C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
- linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
@@ -83,7 +84,8 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
%2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
%3 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
%4 = view %arg0[%c0][%N] : memref<?xi8> to memref<?xf32>
- linalg.matvec %2, %3, %4 : (memref<?x?xf32>, memref<?xf32>, memref<?xf32>)
+ linalg.matvec ins(%2, %3: memref<?x?xf32>, memref<?xf32>)
+ outs(%4 : memref<?xf32>)
return
}
// CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
@@ -123,7 +125,8 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
%1 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
%2 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
%3 = view %arg0[%c0][] : memref<?xi8> to memref<f32>
- linalg.dot %1, %2, %3 : (memref<?xf32>, memref<?xf32>, memref<f32>)
+ linalg.dot ins(%1, %2 : memref<?xf32>, memref<?xf32>)
+ outs(%3 : memref<f32>)
return
}
// CHECKLOOP-LABEL: func @dot(%{{.*}}: memref<?xi8>,
@@ -154,9 +157,9 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
- linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>)
+ linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg2: memref<f32>)
return
}
// CHECKLOOP-LABEL: func @dot_view(
@@ -880,7 +883,8 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
// Named ops to loops.
//----------------------------------------------------------------------------//
func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
- linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+ linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C : memref<?x?x?xf32>)
return
}
// CHECKLOOP-LABEL: @named_batch_matmul
@@ -1288,7 +1292,8 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : m
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
- linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+ linalg.conv_1d ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+ outs(%out : memref<?xf32>)
return
}
@@ -1330,7 +1335,8 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
- linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.conv_2d ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%out: memref<?x?xf32>)
return
}
// CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1382,7 +1388,8 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
- linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_3d ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%out : memref<?x?x?xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index ecf8e20e7ef4..7988abbeaf43 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -27,10 +27,10 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
%11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
- linalg.matmul %11, %14, %17 :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul
+ ins(%11, %14: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%17: memref<?x?xf32, offset: ?, strides: [?, 1]>)
}
}
}
@@ -67,10 +67,7 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D_dynamic]]>
//
-// CHECK: linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] :
-// CHECK: memref<?x?xf32, #[[$strided2D_dynamic]]>,
-// CHECK: memref<?x?xf32, #[[$strided2D_dynamic]]>,
-// CHECK: memref<?x?xf32, #[[$strided2D_dynamic]]>
+// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) :
// CHECK: memref<?x?xf32, #[[$strided2D_dynamic]]>,
@@ -103,10 +100,10 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
%11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
%14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
%17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
- linalg.matmul %11, %14, %17 :
- (memref<?x?xf64, offset: ?, strides: [?, 1]>,
- memref<?x?xf64, offset: ?, strides: [?, 1]>,
- memref<?x?xf64, offset: ?, strides: [?, 1]>)
+ linalg.matmul
+ ins(%11, %14: memref<?x?xf64, offset: ?, strides: [?, 1]>,
+ memref<?x?xf64, offset: ?, strides: [?, 1]>)
+ outs(%17: memref<?x?xf64, offset: ?, strides: [?, 1]>)
}
}
}
@@ -140,10 +137,7 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D_dynamic]]>
//
-// CHECK: linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] :
-// CHECK: memref<?x?xf64, #[[$strided2D_dynamic]]>,
-// CHECK: memref<?x?xf64, #[[$strided2D_dynamic]]>,
-// CHECK: memref<?x?xf64, #[[$strided2D_dynamic]]>
+// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
// CHECK: memref<?x?xf64, #[[$strided2D_dynamic]]>,
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 270a63cf8609..0f38c904eb5a 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -2,8 +2,9 @@
func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "START"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
@@ -26,7 +27,7 @@ func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: linalg.copy(%[[T7]], %[[T19]])
// CHECK: linalg.fill(%[[T21]], %[[C42]])
// CHECK: linalg.copy(%[[T17]], %[[T21]])
-// CHECK: linalg.matmul %[[T19]], %[[T12]], %[[T21]]
+// CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]]
// CHECK-NOT: linalg.fill
// CHECK: linalg.copy(%[[T21]], %[[T17]])
// CHECK: dealloc %[[T18]]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 404c978fa61b..1d58259a578e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -83,30 +83,30 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?xf32, offset: ?, strides: [1]>,
%arg2: memref<?xf32, offset: ?, strides: [1]>,
%arg3: memref<f32>) {
- linalg.matmul %arg0, %arg0, %arg0 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>)
- linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>)
- linalg.dot %arg1, %arg2, %arg3 : (memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>)
+ linalg.matmul ins(%arg0, %arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ linalg.matvec ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg2: memref<?xf32, offset: ?, strides: [1]>)
+ linalg.dot ins(%arg1, %arg2: memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg3: memref<f32>)
return
}
// CHECK-LABEL: func @ops(%
-// CHECK-NEXT: linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref<?x?xf32, #[[$strided2D]]>,
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>,
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>)
-// CHECK-NEXT: linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref<?x?xf32, #[[$strided2D]]>,
-// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
-// CHECK-SAME: memref<?xf32, #[[$strided1D]]>)
-// CHECK-NEXT: linalg.dot %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref<?xf32, #[[$strided1D]]>,
-// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
-// CHECK-SAME: memref<f32>)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : memref<?x?xf32, #[[$strided2D]]>,
+// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>)
+// CHECK-SAME: outs(%{{.*}} : memref<?x?xf32, #[[$strided2D]]>)
+// CHECK: linalg.matvec
+// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref<?x?xf32, #[[$strided2D]]>,
+// CHECK-SAME: memref<?xf32, #[[$strided1D]]>)
+// CHECK-SAME: outs(%{{.*}}: memref<?xf32, #[[$strided1D]]>)
+// CHECK: linalg.dot
+// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref<?xf32, #[[$strided1D]]>,
+// CHECK-SAME: memref<?xf32, #[[$strided1D]]>)
+// CHECK-SAME: outs(%{{.*}}: memref<f32>)
// -----
@@ -619,17 +619,27 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x?x?xf32, #[[$strided3D]]>
-
-// TODO: Return tensors need a semantics convention update.
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
- %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) {
- linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
- linalg.batch_matmul %ta3, %tb3, %c3 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, memref<?x?x?xf32>) -> ()
- return
+ %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+{
+ linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%c3: memref<?x?x?xf32>)
+ linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%c3: memref<?x?x?xf32>)
+ %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ init(%tc3: tensor<?x?x?xf32>)
+ -> tensor<?x?x?xf32>
+ %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
+ init(%tc3: tensor<?x?x?xf32>)
+ -> tensor<?x?x?xf32>
+ return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @named_ops
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.batch_matmul
// -----
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
index 60b348110c4f..638fdb885162 100644
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -13,9 +13,9 @@
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>,
%arg1: memref<?xf32, offset: ?, strides: [1]>,
%arg2: memref<f32>) {
- linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>)
+ linalg.dot ins(%arg0, %arg1: memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg2: memref<f32>)
return
}
// CHECK-LABEL: func @dot(
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index 08f6d19fe6d6..6ff4be0169fb 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -2,8 +2,9 @@
func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute1"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute1"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -21,14 +22,15 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute2"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute2"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c:memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -52,14 +54,15 @@ func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute3"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute3"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -80,14 +83,15 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute4"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute4"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -108,14 +112,15 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute5"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute5"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -138,14 +143,15 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute6"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "distribute6"}
+ ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -165,4 +171,4 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index 9a1bbfc1dc18..cd20d4f9e537 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -31,10 +31,10 @@
func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul %arg0, %arg1, %arg2 :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul
+ ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>)
return
}
// TILE-2-LABEL: func @matmul(
@@ -50,10 +50,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-2: %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]]
// TILE-2: %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] :
-// TILE-2: (memref<?x?xf32, #[[$strided2D]]>,
-// TILE-2: memref<?x?xf32, #[[$strided2D]]>,
-// TILE-2: memref<?x?xf32, #[[$strided2D]]>)
+// TILE-2: linalg.matmul ins(%[[sAi]]{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul(
// TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -68,10 +65,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-02: %[[localK:.*]] = dim %{{.*}}, %c1
// TILE-02: %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]]
// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-// TILE-02: (memref<?x?xf32, #[[$strided2D]]>,
-// TILE-02: memref<?x?xf32, #[[$strided2D]]>,
-// TILE-02: memref<?x?xf32, #[[$strided2D]]>)
+// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul(
// TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -86,10 +80,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-002: %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]]
// TILE-002: %[[N:.*]] = dim %{{.*}}, %c1 : memref<?x?xf32, #[[$strided2D]]>
// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-// TILE-002: (memref<?x?xf32, #[[$strided2D]]>,
-// TILE-002: memref<?x?xf32, #[[$strided2D]]>,
-// TILE-002: memref<?x?xf32, #[[$strided2D]]>)
+// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul(
// TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -118,10 +109,7 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]]
// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref<?x?xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
//
-// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-// TILE-234: (memref<?x?xf32, #[[$strided2D]]>,
-// TILE-234: memref<?x?xf32, #[[$strided2D]]>,
-// TILE-234: memref<?x?xf32, #[[$strided2D]]>)
+// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
// When the buffer shapes are known at compile time, it is possible to avoid
// the "min" in subview size computation. This test uses buffer sizes divisible
@@ -130,10 +118,10 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul %arg0, %arg1, %arg2 :
- (memref<10x16xf32, offset: ?, strides: [?, 1]>,
- memref<16x12xf32, offset: ?, strides: [?, 1]>,
- memref<10x12xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul
+ ins(%arg0, %arg1: memref<10x16xf32, offset: ?, strides: [?, 1]>,
+ memref<16x12xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>)
return
}
// TILE-2-LABEL: func @matmul_static(
@@ -148,7 +136,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]]
+// TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul_static(
// TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -159,10 +147,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-// TILE-02: (memref<10x16xf32, #[[$strided2D]]>,
-// TILE-02: memref<16x?xf32, #[[$strided2D]]>,
-// TILE-02: memref<10x?xf32, #[[$strided2D]]>)
+// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul_static(
// TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -173,10 +158,7 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-// TILE-002: (memref<10x?xf32, #[[$strided2D]]>,
-// TILE-002: memref<?x12xf32, #[[$strided2D]]>,
-// TILE-002: memref<10x12xf32, #[[$strided2D]]>)
+// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul_static(
// TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -193,16 +175,13 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
//
-// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-// TILE-234: (memref<?x?xf32, #[[$strided2D]]>,
-// TILE-234: memref<?x?xf32, #[[$strided2D]]>,
-// TILE-234: memref<?x?xf32, #[[$strided2D]]>)
+// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec %arg0, %arg1, %arg2 : (
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>)
+ linalg.matvec
+ ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg2: memref<?xf32, offset: ?, strides: [1]>)
return
}
// TILE-2-LABEL: func @matvec(
@@ -220,7 +199,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-2: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]]
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-2: linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+// TILE-2: linalg.matvec ins(%[[sAi]], %{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matvec(
// TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -237,7 +216,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-02: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]]
// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-02: linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+// TILE-02: linalg.matvec ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-002-LABEL: func @matvec(
// TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -268,12 +247,12 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
//
-// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
+// TILE-234: linalg.matvec ins(%[[sAij]], %[[sBj]]{{.*}} outs(%[[sCi]]
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
- linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>)
+ linalg.dot
+ ins(%arg0, %arg1: memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>)
+ outs(%arg2: memref<f32>)
return
}
// TILE-2-LABEL: func @dot(
@@ -287,7 +266,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
// TILE-2: %[[localM:.*]] = dim %{{.*}}, %c0
// TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]]
// TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-2: linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
+// TILE-2: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
// TILE-02-LABEL: func @dot(
// TILE-02-NOT: scf.for
@@ -306,7 +285,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
// TILE-234: %[[localM:.*]] = dim %{{.*}}, %c0
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-234: linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
+// TILE-234: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32
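The TILE-2 and TILE-234 checks above rely on the tiled loop clamping its last tile. Each iteration takes a `subview` whose size is an `affine.min` of the tile size and the remaining extent. The following is a minimal standalone C++ sketch of that pattern for a 1-D dot product. It assumes that `#bound_map` computes `min(tileSize, N - i)`, because the map's definition is outside this excerpt. `tiledDot` and its parameters are hypothetical names, not Linalg API. Tiling the reduction loop of `linalg.dot` is only valid because the op accumulates into its `outs` operand, and the sketch mirrors that by threading `acc` through every tile.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration: tile the single (reduction) loop of a dot
// product with step `tileSize`, mirroring scf.for + affine.min + subview.
// Like linalg.dot with outs(%c), it accumulates into `acc`.
float tiledDot(const std::vector<float> &a, const std::vector<float> &b,
               float acc, std::size_t tileSize) {
  assert(tileSize > 0 && a.size() == b.size());
  const std::size_t n = a.size();  // plays the role of `dim %a, %c0`
  for (std::size_t i = 0; i < n; i += tileSize) {
    // Assumed bound map: size of this tile = min(tileSize, n - i).
    const std::size_t sz = std::min(tileSize, n - i);
    // Untiled dot on the "subviews" a[i, i + sz) and b[i, i + sz).
    for (std::size_t k = i; k < i + sz; ++k)
      acc += a[k] * b[k];
  }
  return acc;
}
```

With five elements and a tile size of 2, the loop visits tiles of size 2, 2 and 1, so the last `subview` is smaller than the others.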
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
index 9d86e5e3f50c..a9733cf5b04e 100644
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -6,8 +6,8 @@ func @gemm(%arg0 : memref<?x?xf32>,
%arg1 : memref<?x?xf32>,
%arg2 : memref<?x?xf32>)
{
- linalg.matmul %arg0, %arg1, %arg2
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
// CHECK-LABEL: func @gemm
@@ -21,7 +21,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// TILE1-LABEL: func @gemm
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
@@ -30,7 +30,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
-// TILE1: linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]]
+// TILE1: linalg.matmul ins(%[[SV1]], %{{.*}} outs(%[[SV3]]
// TILE2-LABEL: func @gemm
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -40,7 +40,7 @@ func @gemm(%arg0 : memref<?x?xf32>,
// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// TILE2: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// TILE2: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 683aeb241318..bc3b7477885a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -5,10 +5,10 @@
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
- linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} :
- (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
- memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
- memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+ linalg.matmul {__internal_linalg_transform__ = "START"}
+ ins(%A, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+ outs(%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
return
}
@@ -31,7 +31,8 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
- linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref<f32>)
+ linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
+ outs(%C: memref<f32>)
return
}
@@ -39,8 +40,8 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
- linalg.matvec %A, %B, %C :
- (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+ linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
+ outs(%C: memref<1584xf32>)
return
}
@@ -48,8 +49,8 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
- linalg.matmul %A, %B, %C :
- (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+ linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+ outs(%C: memref<1584x1584xf32>)
return
}
@@ -57,7 +58,8 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
- linalg.batch_matmul %A, %B, %C :
- (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ linalg.batch_matmul
+ ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ outs(%C: memref<1584x1584x1584xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 1a4100403b00..6d0039676b04 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,10 +14,11 @@
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
%v: memref<f32>) {
- linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } :
- (memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>)
+ linalg.dot { __internal_linalg_transform__ = "MEM" }
+ ins(%x, %y: memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%v: memref<f32>)
+
return
}
// CHECK-LABEL: func @dot
@@ -36,10 +37,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec %A, %x, %y :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>)
+ linalg.matvec
+ ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
+ outs(%y: memref<?xf32, offset: ?, strides: [1]>)
return
}
// CHECK-LABEL: func @matvec
@@ -48,15 +49,17 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.parallel {{.*}} step (%[[c5]])
// CHECK: scf.for {{.*}} step %[[c6]]
-// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK: linalg.matvec
+// CHECK: ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK: outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul { __internal_linalg_transform__ = "MEM" }
+ ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
return
}
// CHECK-LABEL: func @matmul
@@ -85,10 +88,9 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
-// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK: linalg.matmul
+// CHECK: ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK: outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
#matmul_trait = {
args_in = 2,
@@ -137,8 +139,9 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
- linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
- (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)
+ linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"}
+ ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C: memref<8x32xf32>)
return
}
// CHECK-LABEL: func @vectorization_test_2
@@ -236,10 +239,10 @@ func @permute_generic_indexed(
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
+ linalg.matvec {__internal_linalg_transform__ = "__with_perm__"}
+ ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?xf32, offset: ?, strides: [1]>)
+ outs(%y: memref<?xf32, offset: ?, strides: [1]>)
return
}
// CHECK-LABEL: func @matvec_perm
@@ -248,15 +251,17 @@ func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
-// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK: linalg.matvec
+// CHECK: ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
+// CHECK: outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} :
- (memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ linalg.matmul {__internal_linalg_transform__ = "__with_perm__"}
+ ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>)
return
}
// CHECK-LABEL: func @matmul_perm
@@ -279,10 +284,9 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
-// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK: linalg.matmul
+// CHECK: ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK: outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -304,10 +308,10 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul {__internal_linalg_transform__ = "_promote_views_"}
+ ins(%3, %4: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%5: memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -336,8 +340,9 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK: linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK: linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
-// CHECK: linalg.matmul %[[v0]], %[[v1]], %[[v2]] :
-// CHECK: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)
func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -359,10 +364,10 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} :
- (memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"}
+ ins(%3, %4: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%5: memref<?x?xf32, offset: ?, strides: [?, ?]>)
}
}
}
@@ -391,10 +396,9 @@ func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
-// CHECK: linalg.matmul %[[v0]], %[[s1]], %[[s2]] :
-// CHECK: (memref<?x?xf32>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>,
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, #[[$STRIDED_2D]]>)
+// CHECK-SAME: outs(%[[s2]] : memref<?x?xf32, #[[$STRIDED_2D]]>)
func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
%c2000 = constant 2000 : index
@@ -421,8 +425,9 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
- linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "par__with_perm__"}
- : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"}
+ ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
// CHECK-LABEL: func @tile_permute_parallel_loop
diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
index 731f3872f67d..68ddeb6ad9d0 100644
--- a/mlir/test/IR/slice.mlir
+++ b/mlir/test/IR/slice.mlir
@@ -5,8 +5,10 @@ func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
%b = alloc(%arg2, %arg1) : memref<?x?xf32>
%c = alloc(%arg0, %arg1) : memref<?x?xf32>
%d = alloc(%arg0, %arg1) : memref<?x?xf32>
- linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
- linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c : memref<?x?xf32>)
+ linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%d : memref<?x?xf32>)
dealloc %c : memref<?x?xf32>
dealloc %b : memref<?x?xf32>
dealloc %a : memref<?x?xf32>
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index d75422e84124..6602060db8dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -308,6 +308,25 @@ parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
return failure();
return success();
}
+static ParseResult
+parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
+ Type optOperandType,
+ const SmallVectorImpl<Type> &varOperandTypes) {
+ if (parser.parseKeyword("type_refs_capture"))
+ return failure();
+
+ Type operandType2, optOperandType2;
+ SmallVector<Type, 1> varOperandTypes2;
+ if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
+ varOperandTypes2))
+ return failure();
+
+ if (operandType != operandType2 || optOperandType != optOperandType2 ||
+ varOperandTypes != varOperandTypes2)
+ return failure();
+
+ return success();
+}
static ParseResult parseCustomDirectiveOperandsAndTypes(
OpAsmParser &parser, OpAsmParser::OperandType &operand,
Optional<OpAsmParser::OperandType> &optOperand,
@@ -365,6 +384,14 @@ static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
printer << ", " << optOperandType;
printer << " -> (" << varOperandTypes << ")";
}
+static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+ Type operandType,
+ Type optOperandType,
+ TypeRange varOperandTypes) {
+ printer << " type_refs_capture ";
+ printCustomDirectiveResults(printer, operandType, optOperandType,
+ varOperandTypes);
+}
static void
printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
Value optOperand, OperandRange varOperands,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ae36ed1710c..2a3f5929f08e 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -53,8 +53,6 @@ def ComplexTensorOp : TEST_Op<"complex_f64_tensor"> {
let results = (outs TensorOf<[ComplexF64]>);
}
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
def TupleOp : TEST_Op<"tuple_32_bit"> {
let results = (outs TupleOf<[I32, F32]>);
}
@@ -1518,6 +1516,22 @@ def FormatCustomDirectiveResults
}];
}
+def FormatCustomDirectiveResultsWithTypeRefs
+ : TEST_Op<"format_custom_directive_results_with_type_refs",
+ [AttrSizedResultSegments]> {
+ let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+ Variadic<AnyType>:$varResults);
+ let assemblyFormat = [{
+ custom<CustomDirectiveResults>(
+ type($result), type($optResult), type($varResults)
+ )
+ custom<CustomDirectiveWithTypeRefs>(
+ type_ref($result), type_ref($optResult), type_ref($varResults)
+ )
+ attr-dict
+ }];
+}
+
def FormatCustomDirectiveSuccessors
: TEST_Op<"format_custom_directive_successors", [Terminator]> {
let successors = (successor AnySuccessor:$successor,
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index ba2ea59cb22d..846011a0f488 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -51,7 +51,8 @@ func @dot() -> f32 {
%B = view %bB[%c0][%c16] : memref<?xi8> to memref<?xf32>
%C = view %bC[%c0][] : memref<?xi8> to memref<f32>
- linalg.dot %A, %B, %C : (memref<?xf32>, memref<?xf32>, memref<f32>)
+ linalg.dot ins(%A, %B : memref<?xf32>, memref<?xf32>)
+ outs(%C : memref<f32>)
%res = load %C[] : memref<f32>
dealloc %bC : memref<?xi8>
@@ -83,7 +84,8 @@ func @matmul() -> f32 {
%B = view %bB[%c0][%c16, %c2] : memref<?xi8> to memref<?x?xf32>
%C = view %bC[%c0][%c2, %c2] : memref<?xi8> to memref<?x?xf32>
- linalg.matmul %A, %B, %C : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C : memref<?x?xf32>)
%res = load %C[%c0, %c1] : memref<?x?xf32>
dealloc %bC : memref<?xi8>
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index aad983eb85d2..9183f3a85b48 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -1,9 +1,9 @@
// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
-// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [
+// ODS-NEXT: NamedStructuredOpTrait
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
@@ -25,9 +25,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
}
-// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
+// ODS-NEXT: NamedStructuredOpTrait
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
@@ -49,9 +49,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
}
-// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
+// ODS-NEXT: NamedStructuredOpTrait
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
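The comprehensions in this test, such as `C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)))`, describe a reduction over `k` into `C`. This patch splits the operands into `inputs`, `output_buffers` and `init_tensors` groups. Under that split, the same reduction can either update a buffer in place or produce a new tensor from an initial value. Below is a minimal plain C++ sketch of the two readings, assuming row-major storage. `Matrix`, `matmulOuts` and `matmulInit` are hypothetical names used only for illustration. They are not part of Linalg.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix used only for this illustration.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float v = 0.0f)
      : rows(r), cols(c), data(r * c, v) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Buffer form, analogous to `linalg.matmul ins(%a, %b ...) outs(%c ...)`:
// C(m, n) += sum_k A(m, k) * B(k, n), written in place.
// m and n are parallel dimensions; k is the reduction dimension.
void matmulOuts(const Matrix &a, const Matrix &b, Matrix &c) {
  assert(a.cols == b.rows && c.rows == a.rows && c.cols == b.cols);
  for (std::size_t m = 0; m < a.rows; ++m)
    for (std::size_t n = 0; n < b.cols; ++n)
      for (std::size_t k = 0; k < a.cols; ++k)
        c.at(m, n) += a.at(m, k) * b.at(k, n);
}

// Tensor form, analogous to `%d = linalg.matmul ins(...) init(%c ...)`:
// the init value seeds the reduction and a new value is returned, while
// `init` itself is left untouched (value semantics).
Matrix matmulInit(const Matrix &a, const Matrix &b, const Matrix &init) {
  Matrix result = init;  // copy: tensors behave as immutable values
  matmulOuts(a, b, result);
  return result;
}
```

The buffer form produces no result. This matches the first generated builder in `mlir-linalg-ods-gen.cpp`, which adds operands but no result types.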
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 60189943ddab..0addff9f35fb 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -230,8 +230,66 @@ def DirectiveTypeZOperandInvalidH : TestFormat_Op<"type_operand_invalid_h", [{
def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{
type($result) type($result)
}]>, Results<(outs I64:$result)>;
+
+//===----------------------------------------------------------------------===//
+// type_ref
+
+// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{
+ type_ref($operand) type(operands)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{
+ type_ref(operands) type($operand)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{
+ type_ref($operand) type($operand)
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{
+ type_ref($result) type(results)
+}]>, Results<(outs I64:$result)>;
+// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{
+ type_ref(results) type($result)
+}]>, Results<(outs I64:$result)>;
+// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
+def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{
+ type_ref($result) type($result)
+}]>, Results<(outs I64:$result)>;
+
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{
+ type_ref(operands) attr-dict
+}]>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{
+ type(operands) type_ref($operand) attr-dict
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{
+ type($operand) type_ref($operand) attr-dict
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{
+ type(results) type_ref(results) attr-dict
+}]>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{
+ type($result) type_ref(results) attr-dict
+}]>, Results<(outs I64:$result)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{
+ type(results) type_ref($result) attr-dict
+}]>, Results<(outs I64:$result)>;
+// CHECK-NOT: error
+def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{
+ type($result) type_ref($result) attr-dict
+}]>, Results<(outs I64:$result)>;
+
// CHECK-NOT: error:
-def DirectiveTypeZOperandValid : TestFormat_Op<"type_operand_valid", [{
+def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{
type(operands) type(results) attr-dict
}]>;
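Taken together, the new `type_ref` format tests encode one rule. A `type_ref(x)` is accepted only after an earlier `type` directive has bound the type of `x`. That binding can be done individually (`type($x)`) or through `type(operands)` / `type(results)`. Conversely, `type_ref(operands)` or `type_ref(results)` needs every covered argument to be bound already, which is vacuously true for an op with no operands (`type_ref_operand_valid_b`). The following is a simplified C++ model of that rule, written only to make the test expectations explicit. It is not OpFormatGen's implementation. `Directive` and `typeRefsAreBound` are hypothetical, and `attr-dict` is ignored.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Simplified, hypothetical model of a format directive.
struct Directive {
  bool isTypeRef;      // false: `type(...)`, true: `type_ref(...)`
  std::string target;  // "operands", "results", or an argument name
};

// Returns true if every `type_ref` only refers to types that an earlier
// `type` directive has already bound.
bool typeRefsAreBound(const std::vector<std::string> &operandNames,
                      const std::vector<std::string> &resultNames,
                      const std::vector<Directive> &format) {
  std::set<std::string> bound;
  // Expand "operands" / "results" into the argument names they cover.
  auto expand = [&](const std::string &target) -> std::vector<std::string> {
    if (target == "operands") return operandNames;
    if (target == "results") return resultNames;
    return {target};
  };
  for (const Directive &d : format) {
    assert(!d.target.empty() && "directive needs a target");
    std::vector<std::string> names = expand(d.target);
    if (!d.isTypeRef) {  // `type(...)` binds the covered types
      bound.insert(names.begin(), names.end());
      continue;
    }
    bool allBound = std::all_of(
        names.begin(), names.end(),
        [&](const std::string &n) { return bound.count(n) != 0; });
    if (!allBound) return false;  // "is not bound by a prior 'type' directive"
  }
  return true;
}
```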
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 9f7c9c0f4809..5066fe5a24e6 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -237,6 +237,12 @@ test.format_custom_directive_results : i64, i64 -> (i64)
// CHECK: test.format_custom_directive_results : i64 -> (i64)
test.format_custom_directive_results : i64 -> (i64)
+// CHECK: test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64)
+test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
+test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
+
func @foo() {
// CHECK: test.format_custom_directive_successors ^bb1, ^bb2
test.format_custom_directive_successors ^bb1, ^bb2
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 59d655684f48..99b0c03ce521 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -980,7 +980,7 @@ class TCParser {
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
- StringRef linalgOpName);
+ StringRef linalgOpName, ComprehensionParsingState &state);
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
@@ -1419,7 +1419,8 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
return failure();
}
if (genODSDecl) {
- printODS(os, cppOpName, tcName);
+ auto &state = perComprehensionStates.back();
+ printODS(os, cppOpName, tcName, state);
os << "\n";
}
if (genODSImpl) {
@@ -1442,31 +1443,72 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
- StringRef linalgOpName) {
- const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [
- NInputs<{2}>,
- NOutputs<{3}>,
+ StringRef linalgOpName,
+ ComprehensionParsingState &state) {
+ const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
+ NamedStructuredOpTrait,
+ AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"YieldOp">]> {
- let arguments = (ins Variadic<LinalgOperand>:$views);
+ let arguments = (ins Variadic<AnyShaped>:$inputs,
+ Variadic<AnyMemRef>:$output_buffers,
+ Variadic<AnyRankedTensor>:$init_tensors);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
- let regions = (region SizedRegion<1>:$region);
- let builders = [OpBuilder<
- "OpBuilder &b, OperationState &result, TypeRange outputTypes, "
- # "ValueRange views",
+ let regions = (region AnyRegion:$region);
+
+ let builders = [ OpBuilder<
+ "OpBuilder &b, OperationState &result,"
+ "ValueRange inputs, ValueRange outputBuffers",
+ [{{
+ result.addOperands(inputs);
+ result.addOperands(outputBuffers);
+ result.addAttribute(
+ "operand_segment_sizes",
+ b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputBuffers.size()),
+ static_cast<int32_t>(0)}));
+ buildNamedStructuredOpRegionAndAttributes<{0}>(
+ b,
+ result,
+ TypeRange(inputs),
+ TypeRange(outputBuffers),
+ TypeRange(),
+ TypeRange());
+ }]>, OpBuilder<
+ "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors",
[{{
- result.addOperands(views);
- result.addTypes(outputTypes);
+ result.addOperands(inputs);
+ result.addOperands(outputBuffers);
+ result.addOperands(initTensors);
+ result.addTypes(resultTensorTypes);
+ result.addAttribute(
+ "operand_segment_sizes",
+ b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputBuffers.size()),
+ static_cast<int32_t>(initTensors.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
- b, result, TypeRange(views), outputTypes);
+ b,
+ result,
+ TypeRange(inputs),
+ TypeRange(outputBuffers),
+ TypeRange(initTensors),
+ resultTensorTypes);
}]>
];
- let parser = [{
- return ::parseNamedStructuredOp<{0}>(parser, result);
- }];
+ let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
+ let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
+ let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{{
+ // Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
+
+ // Generic methods.
+ static unsigned getNumRegionArgs() {{ return {4}; }
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
@@ -1481,7 +1523,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
nInputs++;
}
- os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
+ os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
+ state.orderedTensorArgs.size());
}
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -1680,7 +1723,7 @@ int main(int argc, char **argv) {
}
// Include the proper Linalg header for end-to-end tblgen testing without
- // resorting to non-portable shgell manipulations.
+ // resorting to non-portable shell manipulations.
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
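In the builders above, the operands are appended as three consecutive groups: inputs, output buffers, and init tensors. `operand_segment_sizes` records the length of each group, so the groups can be recovered from the flat operand list. The sketch below is a hypothetical, self-contained C++ model of that recovery. `Operand` and `splitSegments` are illustrative names, not MLIR APIs, and plain strings stand in for `mlir::Value`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value; the real builder works on mlir::Value.
using Operand = std::string;

// Splits a flat operand list into consecutive segments whose lengths are given
// by `segmentSizes`, in the order {#inputs, #outputBuffers, #initTensors} that
// the builders above use when filling `operand_segment_sizes`.
std::vector<std::vector<Operand>>
splitSegments(const std::vector<Operand> &operands,
              const std::array<int32_t, 3> &segmentSizes) {
  int64_t total = 0;
  for (int32_t size : segmentSizes) {
    if (size < 0)
      throw std::invalid_argument("segment sizes must be non-negative");
    total += size;
  }
  if (total != static_cast<int64_t>(operands.size()))
    throw std::invalid_argument("segment sizes do not sum to the operand count");

  std::vector<std::vector<Operand>> segments;
  auto it = operands.begin();
  for (int32_t size : segmentSizes) {
    // Each segment is the next `size` operands of the flat list.
    segments.emplace_back(it, it + size);
    it += size;
  }
  return segments;
}
```

For the tensor-returning `linalg.matmul` shown in the commit message, `ins(%a, %b ...) init(%c ...)` corresponds to segment sizes {2, 0, 1}. The buffer-writing builder always records 0 for the init-tensor group.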
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 01877855802d..7658963d6f9b 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -53,6 +53,7 @@ class Element {
ResultsDirective,
SuccessorsDirective,
TypeDirective,
+ TypeRefDirective,
/// This element is a literal.
Literal,
@@ -230,7 +231,19 @@ class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
/// The operand that is used to format the directive.
std::unique_ptr<Element> operand;
};
-} // end anonymous namespace
+
+/// This class represents the `type_ref` directive.
+class TypeRefDirective
+ : public DirectiveElement<Element::Kind::TypeRefDirective> {
+public:
+ TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+ Element *getOperand() const { return operand.get(); }
+
+private:
+ /// The operand that is used to format the directive.
+ std::unique_ptr<Element> operand;
+};
+} // namespace
//===----------------------------------------------------------------------===//
// LiteralElement
@@ -805,6 +818,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
<< llvm::formatv(
" ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
name);
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+ ArgumentLengthKind lengthKind;
+ StringRef name = getTypeListName(dir->getOperand(), lengthKind);
+ // Refer to the previously encountered TypeDirective for name.
+ // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
+ // to properly track the types that will be parsed and pushed later on.
+ if (lengthKind != ArgumentLengthKind::Single)
+ body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name
+ << "TypesRef(" << name << "Types);\n";
+ else
+ body << llvm::formatv(
+ " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
+ name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
ArgumentLengthKind ignored;
body << " ::llvm::ArrayRef<::mlir::Type> "
@@ -844,6 +870,15 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
else
body << llvm::formatv("{0}Successor", name);
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv("{0}TypesRef", listName);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv("{0}TypeRef", listName);
+ else
+ body << formatv("{0}RawTypesRef[0]", listName);
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -876,6 +911,16 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
"{0}Operand;\n",
operand->getVar()->name);
}
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+ // Reference to an optional which may or may not have been set.
+ // Retrieve from vector if not empty.
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(
+ " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
+ "? Type() : {0}TypesRef[0];\n",
+ listName);
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -907,6 +952,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+ // In the `type_ref` case, do not parse a new Type that needs to be added.
+ // Just do nothing here.
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1101,6 +1149,15 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv(variadicTypeParserCode, listName);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(optionalTypeParserCode, listName);
+ else
+ body << formatv(typeParserCode, listName);
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1431,6 +1488,17 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
body << successor->getVar()->name << "()";
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
+ auto *typeOperand = dir->getOperand();
+ auto *operand = dyn_cast<OperandVariable>(typeOperand);
+ auto *var = operand ? operand->getVar()
+ : cast<ResultVariable>(typeOperand)->getVar();
+ if (var->isVariadic())
+ body << var->name << "().getTypes()";
+ else if (var->isOptional())
+ body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+ else
+ body << var->name << "().getType()";
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
auto *typeOperand = dir->getOperand();
auto *operand = dyn_cast<OperandVariable>(typeOperand);
@@ -1604,6 +1672,9 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
+ } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
+ body << " p << ";
+ genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
body << " p.printFunctionalType(";
genTypeOperandPrinter(dir->getInputs(), body) << ", ";
@@ -1670,6 +1741,7 @@ class Token {
kw_results,
kw_successors,
kw_type,
+ kw_type_ref,
keyword_end,
// String valued tokens.
@@ -1874,6 +1946,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
.Case("results", Token::kw_results)
.Case("successors", Token::kw_successors)
.Case("type", Token::kw_type)
+ .Case("type_ref", Token::kw_type_ref)
.Default(Token::identifier);
return Token(kind, str);
}
@@ -1994,8 +2067,9 @@ class FormatParser {
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
- bool isTopLevel);
- LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element);
+ bool isTopLevel, bool isTypeRef = false);
+ LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
+ bool isTypeRef = false);
//===--------------------------------------------------------------------===//
// Lexer Utilities
@@ -2440,6 +2514,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_successors:
return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
+ case Token::kw_type_ref:
+ return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
case Token::kw_type:
return parseTypeDirective(element, dirTok, isTopLevel);
@@ -2505,7 +2581,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
return ::mlir::success();
};
for (auto &ele : elements) {
- if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+ if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
+ if (failed(checkTypeOperand(typeEle->getOperand())))
+ return failure();
+ } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
if (failed(checkTypeOperand(typeEle->getOperand())))
return ::mlir::failure();
} else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
@@ -2565,7 +2644,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
// Literals, custom directives, and type directives may be used,
// but they can't anchor the group.
.Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
- OptionalElement, TypeDirective>([&](Element *) {
+ OptionalElement, TypeRefDirective, TypeDirective>([&](Element *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
@@ -2628,6 +2707,13 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
// After parsing all of the elements, ensure that all type directives refer
// only to variables.
for (auto &ele : elements) {
+ if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
+ if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
+ return emitError(curLoc,
+ "type_ref directives within a custom directive "
+ "may only refer to variables");
+ }
+ }
if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
return emitError(curLoc, "type directives within a custom directive "
@@ -2649,8 +2735,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return ::mlir::failure();
// Verify that the element can be placed within a custom directive.
- if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
- SuccessorVariable>(parameters.back().get())) {
+ if (!isa<TypeRefDirective, TypeDirective, AttributeVariable, OperandVariable,
+ RegionVariable, SuccessorVariable>(parameters.back().get())) {
return emitError(childLoc, "only variables and types may be used as "
"parameters to a custom directive");
}
@@ -2727,22 +2813,26 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
LogicalResult
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
- bool isTopLevel) {
+ bool isTopLevel, bool isTypeRef) {
llvm::SMLoc loc = tok.getLoc();
if (!isTopLevel)
return emitError(loc, "'type' is only valid as a top-level directive");
std::unique_ptr<Element> operand;
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
- failed(parseTypeDirectiveOperand(operand)) ||
+ failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
- element = std::make_unique<TypeDirective>(std::move(operand));
+ if (isTypeRef)
+ element = std::make_unique<TypeRefDirective>(std::move(operand));
+ else
+ element = std::make_unique<TypeDirective>(std::move(operand));
return ::mlir::success();
}
LogicalResult
-FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element) {
+FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
+ bool isTypeRef) {
llvm::SMLoc loc = curToken.getLoc();
if (failed(parseElement(element, /*isTopLevel=*/false)))
return ::mlir::failure();
@@ -2752,23 +2842,36 @@ FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element) {
if (auto *var = dyn_cast<OperandVariable>(element.get())) {
unsigned opIdx = var->getVar() - op.operand_begin();
- if (fmt.allOperandTypes || seenOperandTypes.test(opIdx))
+ if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
return emitError(loc, "'type' of '" + var->getVar()->name +
"' is already bound");
+ if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
+ return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+ "' is not bound by a prior 'type' directive");
seenOperandTypes.set(opIdx);
} else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
unsigned resIdx = var->getVar() - op.result_begin();
- if (fmt.allResultTypes || seenResultTypes.test(resIdx))
+ if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
return emitError(loc, "'type' of '" + var->getVar()->name +
"' is already bound");
+ if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
+ return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+ "' is not bound by a prior 'type' directive");
seenResultTypes.set(resIdx);
} else if (isa<OperandsDirective>(&*element)) {
- if (fmt.allOperandTypes || seenOperandTypes.any())
+ if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
return emitError(loc, "'operands' 'type' is already bound");
+ if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
+ return emitError(
+ loc,
+ "'operands' 'type_ref' is not bound by a prior 'type' directive");
fmt.allOperandTypes = true;
} else if (isa<ResultsDirective>(&*element)) {
- if (fmt.allResultTypes || seenResultTypes.any())
+ if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
return emitError(loc, "'results' 'type' is already bound");
+ if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
+ return emitError(
+ loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
fmt.allResultTypes = true;
} else {
return emitError(loc, "invalid argument to 'type' directive");
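The binding rules that `parseTypeDirectiveOperand` enforces after this change are easy to state. A `type` directive must be the first to bind a given operand's type. A `type_ref` directive may only refer to a type that an earlier `type` directive (or `type(operands)`) has already bound. The sketch below is a hypothetical, simplified model written for illustration only. `TypeBindingTracker`, `bindOperand`, and `bindAllOperands` are not part of mlir-tblgen, and the model tracks operands only; per the diff, results follow the same pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified model of the operand-type binding checks performed
// by FormatParser::parseTypeDirectiveOperand. This is not the mlir-tblgen code.
class TypeBindingTracker {
public:
  explicit TypeBindingTracker(std::size_t numOperands)
      : seenOperandTypes(numOperands, false) {}

  // Models `type($name)` (isTypeRef == false) or `type_ref($name)`
  // (isTypeRef == true) for the operand at `idx`. Returns "" on success,
  // otherwise the diagnostic text.
  std::string bindOperand(std::size_t idx, const std::string &name,
                          bool isTypeRef) {
    bool bound = allOperandTypes || seenOperandTypes[idx];
    if (!isTypeRef && bound)
      return "'type' of '" + name + "' is already bound";
    if (isTypeRef && !bound)
      return "'type_ref' of '" + name +
             "' is not bound by a prior 'type' directive";
    seenOperandTypes[idx] = true;
    return "";
  }

  // Models `type(operands)` / `type_ref(operands)`. Note the asymmetry:
  // `type` fails if *any* operand type is already bound, while `type_ref`
  // requires *all* of them to be bound.
  std::string bindAllOperands(bool isTypeRef) {
    bool anySeen = false, allSeen = true;
    for (bool seen : seenOperandTypes) {
      anySeen = anySeen || seen;
      allSeen = allSeen && seen;
    }
    if (!isTypeRef && (allOperandTypes || anySeen))
      return "'operands' 'type' is already bound";
    if (isTypeRef && !(allOperandTypes || allSeen))
      return "'operands' 'type_ref' is not bound by a prior 'type' directive";
    allOperandTypes = true;
    return "";
  }

private:
  std::vector<bool> seenOperandTypes;
  bool allOperandTypes = false;
};
```

Under these rules, a format in which `type_ref($x)` follows `type($x)` passes the check. A format in which `type_ref($x)` precedes every `type($x)` is rejected. A second `type($x)` is still reported as already bound.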