[llvm-branch-commits] [mlir] b7ae1d3 - [mlir][Linalg] Revisit the Linalg on tensors abstraction
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 21 12:34:21 PST 2020
Author: nicolasvasilache
Date: 2020-12-21T12:29:10-08:00
New Revision: b7ae1d3d2b1b1d73374a0583150c452273318268
URL: https://github.com/llvm/llvm-project/commit/b7ae1d3d2b1b1d73374a0583150c452273318268
DIFF: https://github.com/llvm/llvm-project/commit/b7ae1d3d2b1b1d73374a0583150c452273318268.diff
LOG: [mlir][Linalg] Revisit the Linalg on tensors abstraction
This revision drops the init_tensor arguments from Linalg on tensors and instead
unifies output buffers and output tensors into a single, consistent list of output
operands. This significantly simplifies the use of Linalg on tensors and is a stepping
stone for its evolution towards a mixed tensor and shape abstraction, discussed in
https://llvm.discourse.group/t/linalg-and-shapes/2421/19.
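For illustration (a sketch; operand names are placeholders), the change for a
`linalg.matmul` returning a tensor looks roughly as follows:

```mlir
// Before this revision: the output tensor was provided via a dedicated `init` clause.
%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                  init(%c : tensor<?x?xf32>)
                    -> tensor<?x?xf32>

// After this revision: output buffers and output tensors are both passed via `outs`.
%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                  outs(%c : tensor<?x?xf32>)
                    -> tensor<?x?xf32>
```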
Differential Revision: https://reviews.llvm.org/D93469
Added:
Modified:
mlir/docs/Dialects/Linalg.md
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/IR/OpBase.td
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/parallel-loops.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/sparse_1d.mlir
mlir/test/Dialect/Linalg/sparse_2d.mlir
mlir/test/Dialect/Linalg/sparse_3d.mlir
mlir/test/Dialect/Linalg/sparse_invalid.mlir
mlir/test/Dialect/Linalg/sparse_parallel.mlir
mlir/test/Dialect/Linalg/sparse_storage.mlir
mlir/test/Dialect/Linalg/tile-and-distribute.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/Dialect/Linalg/tile-tensors.mlir
mlir/test/EDSC/builder-api-test.cpp
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 02508a81b63a..18473f4cb796 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -21,8 +21,8 @@ Linalg. They are all implemented in terms of the properties of the
one-off op knowledge.
The textual form description of these transformations is left for future work.
-Still, it is useful to at least the key transformations that are performed on
-the Linalg IR and that have influenced its design:
+Still, it is useful to list the key transformations that are performed on the
+Linalg IR and that have influenced its design:
1. Progressive Buffer Allocation.
1. Parametric Tiling.
@@ -42,8 +42,25 @@ Linalg takes at least some inspiration from all previously
[key transformations](#key_transformations), including lowering to scalar
load/store and other operations or to external library calls and intrinsics.
-These ops can have ***either tensor or buffer operands***, subject to
-[conventions and limitations](#tensors_and_buffers).
+These ops can have ***either tensor or buffer*** as both input and output
+operands. Output tensor operands serve the purpose of providing a unifying
+abstraction and give a shape to the results. Output tensors come in two
+flavors and are always associated with a corresponding op result:
+
+1. an "init tensor" output value which provides an initial value for a tensor
+   that is created by iteratively updating the result (also called "destructive
+   updates"). Such a tensor is always materialized in some form. If enough
+   fusion occurs, it may end up being materialized only as a register-level SSA
+   value. It is expected (but not required) that the destructive update pattern
+   can be rewritten as an in-place update on buffers.
+
+2. a "shape-only" tensor output value whose underlying elements are not used in
+   the payload computation and which only serves to carry shape information to
+   lower levels of abstraction. In the future, this will be replaced by an
+   appropriate shape type when it is available as a builtin type (see the
+   discourse discussion
+   [Linalg and Shapes](https://llvm.discourse.group/t/linalg-and-shapes/2421)
+   for more details).
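+
+For example (a sketch in which the trait attributes and regions are elided, and
+operand names are placeholders), both flavors are passed through `outs`:
+
+```mlir
+// Flavor 1, "init tensor": the payload reads %acc, the value reduced into.
+%sum = linalg.generic #reduction_trait
+  ins(%in : tensor<?x?xf32>)
+  outs(%acc : tensor<?xf32>)
+  {region}
+  -> (tensor<?xf32>)
+
+// Flavor 2, "shape-only": the elements of %shape are not read by the payload;
+// the operand only provides the shape of the result.
+%add = linalg.generic #elementwise_trait
+  ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
+  outs(%shape : tensor<?xf32>)
+  {region}
+  -> (tensor<?xf32>)
+```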
### Payload-Carrying Ops<a name="payload_ops"></a>
@@ -125,14 +142,15 @@ instance, it guarantees no out-of bounds access can occur by construction
(assuming dynamic operand dimensions agree with each other, which is the purpose
of the `assert` runtime check).
-Before lowering to loop form, loop induction variables and iterators are *not
-yet materialized*. This is a necessary property if we want an abstraction that
-works on both tensor values and buffers because ***values don’t escape
-loops/nesting***.
+Before lowering to loop form, loop induction variables and iterators are
+implicit (i.e. *not yet materialized*).
-The main implications are that: 1. The semantics of the ops are *restricted to
-operate on structured data types*, on which we can define an iterator. 2. This
-does not model arbitrary code with side-effects.
+The main implications are that:
+
+1. The semantics of the ops are *restricted to operate on structured data
+ types*, on which we can define an iterator.
+
+2. This does not model arbitrary code with side-effects.
We do not think these are serious limitations in practice because MLIR is all
about mixing different levels of abstractions in the same IR. As long as Linalg
@@ -483,76 +501,6 @@ because of empirical evidence building and working on multiple high-level
compilers. As we lay those down and engage more with the community, we expect
multiple rounds of discussions and design changes to the original architecture.
-### Tensors and Buffers: Conventions and Limitations <a name="tensors_and_buffers"></a>
-
-Tensors are immutable SSA values, buffers are mutable regions of memory subject
-to side-effects and aliasing. As a consequence, output buffers are passed as
-operands whereas output tensors are new SSA values corresponding to op results.
-Inputs can be arbitrary tensors or buffers and are always passed as operands.
-
-The following convention is currently in-flight and is in the process of
-replacing other existing conventions. The following convention currently applies
-to "named" structured ops which are auto-generated by the linalg-ods tool.
-
-The convention adopted is as follows:
-
-1. A first block of `ins` op operands hold read-only inputs of ShapedType.
-2. An optional second block of `outs` op operands hold read-write output
- buffers of MemRefType.
-3. An optional third block of `init` operands hold initialization tensors of
- RankedTensorType. Such tensors can appear when the op performs a reduction
- and returns a tensor.
-
-Structured ops with fully parallel semantics, have empty `init`. They may either
-write in-place into `outs` buffers or return new tensors.
-
-Structured ops with reduction semantics and output tensor(s) however have
-additional restrictions:
-
-1. They can only return a single tensor for now.
-2. They cannot have any output buffer operand (i.e. `outs` is empty).
-3. They have exactly one `init` tensor of the same type as the unique output
- tensor. Such an `init` tensor does not have an explicit associate indexing
- map. Instead the map of the result tensor is used to signify that the `init`
- and the `result` are "tied".
-
-Points 1. and 2. keep complexity of the representation in check by allowing only
-a single result tensor, when reductions are present.
-
-Point 3. is related to the fact that SSA values cannot represent in-place
-updates. Instead, linalg adopts a similar convention that exists in e.g.
-`vector.outerproduct`: the value that is reduced into is passed as an explicit
-argument and a new result of the same shape is produced.
-
-It is expected buffer allocation will fold this last input onto the result in a
-single output buffer argument, which is why the same indexing map is required:
-the last input operand is said to be "tied" to the result.
-
-Alternative, more complex representations, would allow for:
-
-1. Multiple results and `init` tensors in arbitrary orders, which could be
- captured by an extra ArrayAttr of position pairs.
-2. Relaxing the conditions on the indexing map equalities on the each pair and
- e.g. allow implicit broadcasts of the input.
-
-These representations are deemed unnecessarily complex for now and are left for
-future discussion.
-
-As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:
-
-```
-linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
- outs(%c : memref<?x?xf32>)
-```
-
-, whereas the syntax for a `linalg.matmul` returning a new tensor is:
-
-```
-%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
- init(%c : tensor<?x?xf32>)
- -> tensor<?x?xf32>
-```
-
### Data Representation: Views<a name="views"></a>
The current implementation uses the
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 4ee5fac7f677..9aa50c25cd79 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -45,19 +45,17 @@ class Aliases {
class LinalgDependenceGraph {
public:
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
- struct LinalgOpView {
- Operation *op;
- unsigned operandIndex;
- };
+ // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
+ // need an extension to use OpResult.
struct LinalgDependenceGraphElem {
// dependentOpView may be either:
// 1. src in the case of dependencesIntoGraphs.
// 2. dst in the case of dependencesFromDstGraphs.
- LinalgOpView dependentOpView;
+ OpOperand *dependentOpView;
// View in the op that is used to index in the graph:
// 1. src in the case of dependencesFromDstGraphs.
// 2. dst in the case of dependencesIntoGraphs.
- LinalgOpView indexingOpView;
+ OpOperand *indexingOpView;
// Type of the dependence.
DependenceType dependenceType;
};
@@ -161,8 +159,8 @@ class LinalgDependenceGraph {
// Uses std::pair to keep operations and view together and avoid usage errors
// related to src/dst and producer/consumer terminology in the context of
// dependences.
- void addDependenceElem(DependenceType dt, LinalgOpView indexingOpView,
- LinalgOpView dependentOpView);
+ void addDependenceElem(DependenceType dt, OpOperand *indexingOpView,
+ OpOperand *dependentOpView);
/// Implementation detail for findCoveringxxx.
SmallVector<Operation *, 8>
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index ac9ca9581f0d..43dff8150f77 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -30,8 +30,8 @@ class ParallelOp;
namespace edsc {
inline void defaultRegionBuilder(ValueRange args) {}
-/// Build a `linalg.generic` op with the specified `inputs`, `outputBuffers`,
-/// `initTensors`, `resultTensorsTypes` and `region`.
+/// Build a `linalg.generic` op with the specified `inputs`, `outputs`,
+/// `resultTensorsTypes` and `region`.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
@@ -41,15 +41,12 @@ inline void defaultRegionBuilder(ValueRange args) {}
///
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
/// tensor values.
-/// 2. `outputsBuffers` may contain StructuredIndexed that capture buffer
-/// values.
-/// 3. `initTensors` contain tensor values, without indexing maps.
-/// 4. `resultTensorTypes` may contain StructuredIndexed that capture return
-/// tensor types.
+/// 2. `outputs` may contain StructuredIndexed that capture either buffer or
+/// tensor values. In the future this will be extended with ranked shape values.
+/// 4. `resultTensorTypes` may contain return tensor types.
Operation *makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
- ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
- ArrayRef<StructuredIndexed> resultTensorTypes,
+ ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder = defaultRegionBuilder,
ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
index d842069f6570..0b53fc7573a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -18,6 +18,7 @@ namespace intrinsics {
using linalg_copy = OperationBuilder<linalg::CopyOp>;
using linalg_dot = OperationBuilder<linalg::DotOp>;
using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_init_tensor = ValueBuilder<linalg::InitTensorOp>;
using linalg_matmul = OperationBuilder<linalg::MatmulOp>;
using linalg_matvec = OperationBuilder<linalg::MatvecOp>;
using linalg_vecmat = OperationBuilder<linalg::VecmatOp>;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 2438338a534f..b1ac1a3b48b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_
#define MLIR_DIALECT_LINALG_LINALGOPS_H_
-#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -111,9 +110,17 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
void getDimsOfType(Operation *op, StringRef iteratorTypeName,
SmallVectorImpl<AffineExpr> &res);
+namespace detail {
+LogicalResult verifyStructuredOpInterface(Operation *op);
+} // namespace detail
} // namespace linalg
} // namespace mlir
+namespace mlir {
+namespace linalg {
+class IndexedGenericOp;
+} // namespace linalg
+} // namespace mlir
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 66f39104d7e7..26db4c2f6735 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -19,26 +19,6 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-// The Linalg `NInputs` trait provides the API for ops that are known
-// to have a specified number of inputs, all passed as operands.
-// See Linalg/LinalgTraits.h for implementation details and usage.
-class NInputs<int n> :
- NativeOpTrait<"linalg::NInputs<" # !cast<string>(n) # ">::Impl"> {}
-
-// The Linalg `ZeroInitTensors` trait provides the API for ops that are known
-// to not have input tensor operands.
-// See Linalg/LinalgTraits.h for implementation details and usage.
-def ZeroInitTensors : NativeOpTrait<"linalg::ZeroInitTensors"> {}
-
-// The Linalg `NOutputs` trait provides the API for ops that are known
-// to have a specified number of outputs, all passed as operands.
-// See Linalg/LinalgTraits.h for implementation details and usage.
-class NOutputs<int n> :
- NativeOpTrait<"linalg::NOutputs<" # !cast<string>(n) # ">::Impl"> {}
-
-def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
-def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
-
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on ShapedType as their
// first operands. These may be optionally followed by non-view operands
@@ -50,7 +30,6 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
: LinalgStructuredBase_Op<mnemonic,
!listconcat(props, [
- StructuredOpTraits,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>])> {
code libraryCallName = [{
std::string getLibraryCallName() {
@@ -65,12 +44,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
//===----------------------------------------------------------------------===//
// At the moment these are not declarative and require a bunch of C++ code.
// In the future, these should be migrated to a declarative specification.
-def CopyOp : LinalgStructured_Op<"copy", [
- CopyOpInterface,
- NInputs<1>,
- ZeroInitTensors,
- NOutputs<1>
- ]> {
+def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
let description = [{
Copies the data in the input view into the output view.
@@ -137,6 +111,9 @@ def CopyOp : LinalgStructured_Op<"copy", [
}]>];
let extraClassDeclaration = libraryCallName # [{
+ ValueRange inputs() { return getOperands().take_front(); }
+ ValueRange outputs() { return getOperands().take_back(); }
+
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
ArrayAttr iterator_types() {
@@ -170,14 +147,13 @@ def CopyOp : LinalgStructured_Op<"copy", [
let hasCanonicalizer = 1;
}
-def FillOp : LinalgStructured_Op<"fill", [
- NInputs<0>,
- ZeroInitTensors,
- NOutputs<1>]> {
-
+def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyStridedMemRef:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let extraClassDeclaration = libraryCallName # [{
+ ValueRange inputs() { return {}; }
+ ValueRange outputs() { return getOperands().take_front(); }
+
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
ArrayAttr iterator_types() {
@@ -276,13 +252,8 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
}];
}
-def ConvOp : PoolingBase_Op<"conv", [
- NInputs<2>,
- // Despite having reductions, this manually defined ConvOp may only take
- // memref operands and can never have init tensors.
- ZeroInitTensors,
- NOutputs<1>]> {
-
+// Only support buffer semantics.
+def ConvOp : PoolingBase_Op<"conv", []> {
let description = [{
Generic n-D convolution as described in the TF documentation:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
@@ -313,6 +284,9 @@ def ConvOp : PoolingBase_Op<"conv", [
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = commonUtils # [{
+ ValueRange inputs() { return getOperands().slice(0, 2); }
+ ValueRange outputs() { return getOperands().take_back(); }
+
// TODO: extend to support more than 1 dimensions and potentially grouping
// too.
unsigned getNumBatchDimensions() { return 1; }
@@ -335,6 +309,12 @@ def ConvOp : PoolingBase_Op<"conv", [
// parallelized across; i.e. [zs] in the TF notation above whose number
// match `xs` (i.e. 1 window loop per "image" dimension).
// This may evolve in the future.
+ // Conditionally check nPar is large enough for cases of ill-formed op:
+ // this avoids overflows before hitting the verifier.
+ assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() &&
+ "expected at least one window dimension (i.e. memref ranks greater "
+ "than 2). See 'func @conv_rank_limit' in "
+ "mlir/test/Dialect/Linalg/invalid.mlir");
unsigned nWin =
nPar - getNumBatchDimensions() - getNumInputFeatureDimensions();
SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
@@ -352,7 +332,8 @@ def ConvOp : PoolingBase_Op<"conv", [
ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto nWin = getNumWindowLoops();
- assert(nWin > 0 && "expected at least one window dimension");
+ assert(nWin > 0 && "expected at least one window dimension (i.e. memref "
+ "ranks greater than 2)");
unsigned idx = 0;
// In the following, AffineDimExprs are indexed in loop order:
// [ b, xs, k, q, zs]
@@ -394,13 +375,9 @@ def ConvOp : PoolingBase_Op<"conv", [
let hasCanonicalizer = 1;
}
+// Only support buffer semantics.
class SingleInputPoolingBase_Op<string mnemonic>
- : PoolingBase_Op<mnemonic, [
- NInputs<2>,
- // Despite having reductions, this manually defined ConvOp may only take
- // memref operands and can never have init tensors.
- ZeroInitTensors,
- NOutputs<1>]> {
+ : PoolingBase_Op<mnemonic, []> {
let description = [{
A base class for single input pooling function.
@@ -420,6 +397,9 @@ class SingleInputPoolingBase_Op<string mnemonic>
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = commonUtils# [{
+ ValueRange inputs() { return getOperands().slice(0, 2); }
+ ValueRange outputs() { return getOperands().take_back(); }
+
ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions.
unsigned nPar = getOutputShapedType(0).getRank();
@@ -493,11 +473,9 @@ class LinalgOperandOfRank<int rank>: Type<
class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- NamedStructuredOpTrait,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins Variadic<AnyShaped>:$inputs,
- Variadic<AnyMemRef>:$output_buffers,
- Variadic<AnyRankedTensor>:$init_tensors,
+ Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
OptionalAttr<StrAttr>:$doc,
@@ -622,34 +600,26 @@ def GenericOp : GenericOpBase<"generic"> {
```mlir
%C = linalg.generic #trait_attribute
ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
- init(%C : tensor<?x?xf32>)
+ outs(%C : tensor<?x?xf32>)
{other-optional-attributes}
{region}
-> (tensor<?x?xf32>)
```
-
- The `init` operand and the conventions around mixing tensors and buffers are
- described in more detail in the "Tensors and Buffers: Conventions and
- Limitations" section in the [Linalg Document](../docs/Linalg.md)
-
- Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. Such legalizations move tensor return values
- into output buffer operands and updates the region arguments accordingly.
}];
let builders = [
OpBuilderDAG<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputBuffers, "ValueRange":$initTensors,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
- "StringRef":$doc, "StringRef":$libraryCall,
+ "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
+ "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
+ "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
OpBuilderDAG<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
"StringRef":$doc, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
OpBuilderDAG<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputBuffers, "ValueRange":$initTensors,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+ "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
+ "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
OpBuilderDAG<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
@@ -714,8 +684,8 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
```mlir
linalg.indexed_generic #matmul_trait
- ins(%A, %B : memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>)
+ ins(%A, %B : memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
outs(%C : memref<?x?xf32, stride_specification>) {
(%offset_m: index, %offset_n: index, %offset_k: index,
%a: f32, %b: f32, %c: f32) :
@@ -761,27 +731,19 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
```mlir
%C = linalg.indexed_generic #trait_attribute
- ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
- init(%C : tensor<?x?xf32>)
+ ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
+ outs(%C : tensor<?x?xf32>)
{other-optional-attributes}
{region_with_index_arguments}
-> (tensor<?x?xf32>)
```
-
- The `init` operand and the conventions around mixing tensors and buffers are
- described in more detail in the "Tensors and Buffers: Conventions and
- Limitations" section in the [Linalg Document](../docs/Linalg.md)
-
- Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. Such legalizations move tensor return values
- into output buffer operands and update the region arguments accordingly.
}];
let builders = [
OpBuilderDAG<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputBuffers, "ValueRange":$initTensors,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
- "StringRef":$doc, "StringRef":$libraryCall,
+ "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
+ "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
+ "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilderDAG<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
@@ -790,8 +752,8 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilderDAG<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputBuffers, "ValueRange":$initTensors,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+ "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
+ "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilderDAG<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 74ca666d63a5..3fc3fa4a5556 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -20,6 +20,24 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td"
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let cppNamespace = "::mlir::linalg";
let methods = [
+ //===------------------------------------------------------------------===//
+ // Loop types handling.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of induction variables in the basic block. This should
+ always be 0 for index-free linalg ops. For IndexedGeneric, this must be
+ equal to numLoops
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumPayloadInductionVariables",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return isa<IndexedGenericOp>(this->getOperation()) ?
+ $_op.getNumLoops() : 0;
+ }]
+ >,
//===------------------------------------------------------------------===//
// Loop types handling.
//===------------------------------------------------------------------===//
@@ -125,42 +143,60 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
}]>,
//===------------------------------------------------------------------===//
- // Num input/output/initTensors arguments handling.
+ // Num input/output arguments handling.
//===------------------------------------------------------------------===//
- // These special methods must be defined by each op that wants to implement
- // the LinalgStructuredInterface. For now, this is either:
- // - Explicitly specified in the op definition.
- // - Derived from variadic attributes (for "named" ops, linalg.generic and
- // linalg.indexed_generic ops).
+ // `inputs` must be defined by each op that wants to implement the
+ // LinalgStructuredInterface.
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input shape operands.
+ }],
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"inputs",
+ /*args=*/(ins)
+ >,
+ // These special methods rely on `inputs` and `outputs` being defined by
+ // each op that wants to implement the LinalgStructuredInterface.
InterfaceMethod<
/*desc=*/[{
Return the number of inputs.
}],
/*retTy=*/"unsigned",
- /*methodName=*/"getNumInputs"
+ /*methodName=*/"getNumInputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.inputs().size();
+ }]
>,
+ // `outputs` must be defined by each op that wants to implement the
+ // LinalgStructuredInterface.
InterfaceMethod<
/*desc=*/[{
- Return the number of init tensors.
+ Return the output shape operands.
}],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumInitTensors"
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"outputs",
+ /*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Return the number of outputs.
}],
/*retTy=*/"unsigned",
- /*methodName=*/"getNumOutputs"
+ /*methodName=*/"getNumOutputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.outputs().size();
+ }]
>,
//===------------------------------------------------------------------===//
- // Input arguments handling.
+ // Input operands handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return the `i`-th input value.
- The `i^th` input argument is always the `i^th` operand regardless of
- whether we have tensors or buffers.
+ Return the `i`-th input operand.
}],
/*retTy=*/"Value",
/*methodName=*/"getInput",
@@ -173,24 +209,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the index of the given input value `v`, or `None` if the value is
- not an input.
- }],
- /*retTy=*/"llvm::Optional<unsigned>",
- /*methodName=*/"getIndexOfInput",
- /*args=*/(ins "Value":$value),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto it = llvm::find(getInputs(), value);
- if (it != getInputs().end())
- return it - getInputs().begin();
- return llvm::None;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the `i`-th input shaped type, irrespective of buffer or tensor
- type.
+ Return the `i`-th input shaped type
}],
/*retTy=*/"ShapedType",
/*methodName=*/"getInputShapedType",
@@ -202,7 +221,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the input operands.
+ Return the range of input operands.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getInputs",
@@ -215,7 +234,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the range over the input operands that are of buffer type.
+ Return the OpOperands for the input operands.
+ }],
+ /*retTy=*/" MutableArrayRef<OpOperand>",
+ /*methodName=*/"getInputOpOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return this->getOperation()->getOpOperands().take_front(getNumInputs());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of input operands that are of buffer type.
}],
/*retTy=*/"SmallVector<Value, 4>",
/*methodName=*/"getInputBuffers",
@@ -223,417 +254,504 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::to_vector<4>(llvm::make_filter_range(
- getInputs(), [](Value in){ return in.getType().isa<MemRefType>(); }));
+ getInputs(), [](Value in){ return in.getType().template isa<MemRefType>(); }));
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the subset of input operands that are of ranked tensor type.
+ Return the number of input buffer operands.
}],
- /*retTy=*/"SmallVector<RankedTensorType, 4>",
- /*methodName=*/"getInputTensorTypes" ,
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumInputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<RankedTensorType, 4> res;
- for (Type type : getInputs().getTypes())
- if (auto t = type.template dyn_cast<RankedTensorType>())
- res.push_back(t);
- return res;
+ return $_op.getInputBuffers().size();
}]
>,
- //===------------------------------------------------------------------===//
- // Output arguments handling.
- //===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return the output buffer at the given index, asserts that this is a
- buffer operand and not a tensor result.
- The `i^th` output argument is an operand (resp. a return value) iff it
- is a value of buffer type (resp. a return value of tensor type).
+ Return the `index`^th input buffer.
}],
/*retTy=*/"Value",
- /*methodName=*/"getOutputBuffer",
- /*args=*/(ins "unsigned":$i),
+ /*methodName=*/"getInputBuffer",
+ /*args=*/(ins "unsigned":$index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- // Output buffers are passed as output buffer operands (side-effecting).
- // Output tensors are results.
- // The union of the 2 are all the outputs and we want to ensure i does
- // not overflow the buffer operands.
- assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
- && "overflowing output buffer index");
- return this->getOperation()->getOperand($_op.getNumInputs() + i);
+ assert(index < getNumInputBuffers());
+ return getInputBuffers()[index];
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the index of the given buffer value, or `None` if the value is
- not part of the output buffers.
+ Return the subset of input operands that are of buffer type.
}],
- /*retTy=*/"llvm::Optional<unsigned>",
- /*methodName=*/"getIndexOfOutputBuffer",
- /*args=*/(ins "Value":$value),
+ /*retTy=*/"SmallVector<OpOperand*, 4>",
+ /*methodName=*/"getInputBuffersOpOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto it = llvm::find(getOutputBuffers(), value);
- if (it != getOutputBuffers().end())
- return it - getOutputBuffers().begin();
- return llvm::None;
+ SmallVector<OpOperand*, 4> res;
+ res.reserve(getNumInputs());
+ for (OpOperand &o : getInputOpOperands())
+ if (o.get().getType().isa<MemRefType>())
+ res.push_back(&o);
+ return res;
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the type of the output buffer at the given index.
+ Return the subset of input operands that are of tensor type.
}],
- /*retTy=*/"MemRefType",
- /*methodName=*/"getOutputBufferType",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"SmallVector<Value, 4>",
+ /*methodName=*/"getInputTensors",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getOutputBuffer(i).getType().template cast<MemRefType>();
- }]>,
+ return llvm::to_vector<4>(llvm::make_filter_range(
+ getInputs(),
+ [](Value in){ return in.getType().template isa<RankedTensorType>(); }));
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
- Return the `i`-th output shaped type, irrespective of buffer or tensor
- type.
+ Return the subset of op operands that are of tensor type.
}],
- /*retTy=*/"ShapedType",
- /*methodName=*/"getOutputShapedType",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"SmallVector<OpOperand*, 4>",
+ /*methodName=*/"getInputTensorsOpOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getShapedType(i + $_op.getNumInputs());
- }]>,
+ SmallVector<OpOperand*, 4> res;
+ res.reserve(getNumInputs());
+ for (OpOperand &o : getInputOpOperands())
+ if (o.get().getType().isa<RankedTensorType>())
+ res.push_back(&o);
+ return res;
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
- Return the results that are of ranked tensor type.
+ Return the types of the subset of input operands that are of buffer type.
}],
- /*retTy=*/"SmallVector<RankedTensorType, 4>",
- /*methodName=*/"getOutputTensorTypes",
+ /*retTy=*/"SmallVector<MemRefType, 4>",
+ /*methodName=*/"getInputBufferTypes" ,
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<RankedTensorType, 4> res;
- for (Type type : this->getOperation()->getResults().getTypes())
- res.push_back(type.template cast<RankedTensorType>());
- return res;
- }]>,
+ return llvm::to_vector<4>(
+ llvm::map_range(
+ llvm::make_filter_range(
+ ValueRange(getInputs()).getTypes(),
+ [](Type in){ return in.isa<MemRefType>(); }),
+ [](Type in){ return in.cast<MemRefType>(); }));
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
- Return the output buffers (operands).
+ Return the types of the subset of input operands that are of ranked
+ tensor type.
}],
- /*retTy=*/"Operation::operand_range",
- /*methodName=*/"getOutputBuffers",
+ /*retTy=*/"SmallVector<RankedTensorType, 4>",
+ /*methodName=*/"getInputTensorTypes" ,
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = this->getOperation()->getOperands();
- return {range.begin() + $_op.getNumInputs(),
- range.begin() + getNumInputsAndOutputBuffers()};
+ return llvm::to_vector<4>(
+ llvm::map_range(
+ llvm::make_filter_range(
+ ValueRange(getInputs()).getTypes(),
+ [](Type in){ return in.isa<RankedTensorType>(); }),
+ [](Type in){ return in.cast<RankedTensorType>(); }));
}]
>,
//===------------------------------------------------------------------===//
- // Input and Output arguments handling.
+ // Output operands handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return one single buffer at position `$i`.
+ Return the `i`-th output operand.
}],
/*retTy=*/"Value",
- /*methodName=*/"getBuffer",
+ /*methodName=*/"getOutput",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
- return this->getOperation()->getOperand(i);
+ assert(i < $_op.getNumOutputs());
+ return this->getOperation()->getOperand(i + $_op.getNumInputs());
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of output buffers
+ Return the `i`-th output shaped type
}],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumOutputBuffers",
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getOutputShapedType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getOutput(i).getType().template cast<ShapedType>();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the range of output operands.
+ }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getNumOutputs() - this->getOperation()->getNumResults();
+ auto start =
+ this->getOperation()->getOperands().begin() + $_op.getNumInputs();
+ return {start, start + $_op.getNumOutputs()};
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of inputs and outputs, irrespective of their buffer or
- tensor type.
+ Return the OpOperands for the output operands.
}],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumInputsAndOutputs",
+ /*retTy=*/" MutableArrayRef<OpOperand>",
+ /*methodName=*/"getOutputOpOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getNumInputs() + $_op.getNumOutputs();
+ return this->getOperation()->getOpOperands().slice(
+ getNumInputs(), getNumOutputs());
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of inputs, irrespective of their buffer or tensor type
- and output buffers
+ Return the subset of output operands that are of buffer type.
}],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumInputsAndOutputBuffers",
+ /*retTy=*/"SmallVector<Value, 4>",
+ /*methodName=*/"getOutputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getNumInputs() + $_op.getNumOutputs() -
- this->getOperation()->getNumResults();
+ return llvm::to_vector<4>(llvm::make_filter_range(
+ getOutputs(), [](Value in){ return in.getType().template isa<MemRefType>(); }));
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the range over inputs (irrespective of type) and output buffers.
+ Return the `index`^th output buffer.
}],
- /*retTy=*/"Operation::operand_range",
- /*methodName=*/"getInputsAndOutputBuffers",
+ /*retTy=*/"Value",
+ /*methodName=*/"getOutputBuffer",
+ /*args=*/(ins "unsigned":$index),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(index < getNumOutputBuffers());
+ return getOutputBuffers()[index];
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of output operands that are of buffer type.
+ }],
+ /*retTy=*/"SmallVector<OpOperand*, 4>",
+ /*methodName=*/"getOutputBuffersOpOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
+ SmallVector<OpOperand*, 4> res;
+ res.reserve(getNumOutputs());
+ for (OpOperand &o : getOutputOpOperands())
+ if (o.get().getType().isa<MemRefType>())
+ res.push_back(&o);
+ return res;
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the range over init tensors.
+ Return the number of output buffer operands.
}],
- /*retTy=*/"Operation::operand_range",
- /*methodName=*/"getInitTensors",
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumOutputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = this->getOperation()->getOperands();
- auto base = range.begin() + getNumInputsAndOutputBuffers();
- return {base, base + $_op.getNumInitTensors()};
+ return $_op.getOutputBuffers().size();
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return one single init tensor at position `$i`.
+ Return the subset of output operands that are of tensor type.
}],
- /*retTy=*/"Value",
- /*methodName=*/"getInitTensor",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"SmallVector<Value, 4>",
+ /*methodName=*/"getOutputTensors",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i < $_op.getNumInitTensors() && "overflowing init tensor index");
- return getInitTensors()[i];
+ return llvm::to_vector<4>(llvm::make_filter_range(
+ getOutputs(),
+ [](Value in){ return in.getType().template isa<RankedTensorType>(); }));
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return true if the shaped operand index `i` is the index of an init
- tensor.
+ Return the subset of output operands that are of tensor type.
}],
- /*retTy=*/"bool",
- /*methodName=*/"isIndexOfAnInitTensor",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"SmallVector<OpOperand*, 4>",
+ /*methodName=*/"getOutputTensorsOpOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i < $_op.getNumShapedOperands() && "overflowing shaped operand index");
- return i >= $_op.getNumInputs() + getNumOutputBuffers();
+ SmallVector<OpOperand*, 4> res;
+ res.reserve(getNumOutputs());
+ for (OpOperand &o : getOutputOpOperands())
+ if (o.get().getType().isa<RankedTensorType>())
+ res.push_back(&o);
+ return res;
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the relative init tensor index of the shaped operand index.
+ Return the number of output tensor operands.
}],
/*retTy=*/"unsigned",
- /*methodName=*/"getInitTensorIndexFromShapedIndex",
- /*args=*/(ins "unsigned":$i),
+ /*methodName=*/"getNumOutputTensors",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(isIndexOfAnInitTensor(i) && "expected an init tensor index");
- return i - $_op.getNumInputs() - getNumOutputBuffers();
+ return $_op.getOutputTensors().size();
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the index of the given init tensor value, or `None` if the value
- is not part of the init tensors.
+ Return the types of the subset of output operands that are of buffer type.
}],
- /*retTy=*/"llvm::Optional<unsigned>",
- /*methodName=*/"getIndexOfInitTensor",
- /*args=*/(ins "Value":$value),
+ /*retTy=*/"SmallVector<MemRefType, 4>",
+ /*methodName=*/"getOutputBufferTypes" ,
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto it = llvm::find(getInitTensors(), value);
- if (it != getInitTensors().end())
- return it - getInitTensors().begin();
- return llvm::None;
+ return llvm::to_vector<4>(
+ llvm::map_range(
+ llvm::make_filter_range(
+ ValueRange(getOutputs()).getTypes(),
+ [](Type in){ return in.isa<MemRefType>(); }),
+ [](Type in){ return in.cast<MemRefType>(); }));
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of inputs, output buffers and init tensors operands.
+ Return the types of the subset of output operands that are of ranked
+ tensor type.
}],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumShapedOperands",
+ /*retTy=*/"SmallVector<RankedTensorType, 4>",
+ /*methodName=*/"getOutputTensorTypes" ,
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
+ return llvm::to_vector<4>(
+ llvm::map_range(
+ llvm::make_filter_range(
+ ValueRange(getOutputs()).getTypes(),
+ [](Type in){ return in.isa<RankedTensorType>(); }),
+ [](Type in){ return in.cast<RankedTensorType>(); }));
}]
>,
+
+ //===------------------------------------------------------------------===//
+ // Input and Output arguments handling.
+ //===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return the `i`-th shaped operand value, which can be an arbitrary input
- tensor/buffer, init tensor or output buffer.
+ Return true if the payload uses the value loaded from `opOperand`. This
+ is useful to avoid loading from "write-only" memory that may be
+ uninitialized, as well as properly cloning "read-write" operands.
}],
- /*retTy=*/"Value",
- /*methodName=*/"getShapedOperand",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"bool",
+ /*methodName=*/"payloadUsesValueFromOpOperand",
+ /*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i < $_op.getNumShapedOperands());
- return this->getOperation()->getOperand(i);
+ unsigned bbArgNumber =
+ getNumPayloadInductionVariables() + opOperand->getOperandNumber();
+ // Safeguard against the named linalg ops that are manually defined and
+ // that only support buffer semantics: we should not be there.
+ // Such ops have an empty regionBuilder and are not constructed with a
+ // region for now. In the future they are slated to disappear.
+ assert(this->getOperation()->getNumRegions() == 1 && "unexpected "
+ "missing region (calling `payloadUsesValueFromOpOperand` on "
+ "manually defined named Linalg op?)");
+ Block &block = this->getOperation()->getRegion(0).front();
+ // Init tensors have uses.
+ return !block.getArgument(bbArgNumber).use_empty();
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the range over inputs, output buffers and init tensors.
+ Return true if the payload uses the value loaded from input operand
+ `index`.
}],
- /*retTy=*/"Operation::operand_range",
- /*methodName=*/"getShapedOperands",
- /*args=*/(ins),
+ /*retTy=*/"bool",
+ /*methodName=*/"payloadUsesValueFromInputOperandIndex",
+ /*args=*/(ins "unsigned":$index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + getNumShapedOperands()};
+ return payloadUsesValueFromOpOperand(&getInputOpOperands()[index]);
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the `i`-th shaped type, there are 3 cases:
- 1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
- otherwise
- 2. if `i < getNumInputsAndOutputBuffers()` then return the
- `getOutputBufferType(i - $_op.getNumInputs())`; otherwise
- 3. return the `i - getNumInputsAndOutputBuffers()` result type.
+ Return true if the payload uses the value loaded from output operand
+ `index`.
}],
- /*retTy=*/"ShapedType",
- /*methodName=*/"getShapedType",
- /*args=*/(ins "unsigned":$i),
+ /*retTy=*/"bool",
+ /*methodName=*/"payloadUsesValueFromOutputOperandIndex",
+ /*args=*/(ins "unsigned":$index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (i < $_op.getNumInputs())
- return getInputShapedType(i);
- if (i < getNumInputsAndOutputBuffers())
- return getOutputBufferType(i - $_op.getNumInputs());
- return this->getOperation()->getResult(
- i - getNumInputsAndOutputBuffers()).
- getType().template cast<ShapedType>();
- }]>,
+ return payloadUsesValueFromOpOperand(&getOutputOpOperands()[index]);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
- Return the shaped types for all the inputs and outputs
+ Return true if `opOperand` is an init tensor. This is true when it is
+ an output tensor operand whose value is used in the payload region.
}],
- /*retTy=*/"SmallVector<ShapedType, 4>",
- /*methodName=*/"getInputOutputShapedTypes",
+ /*retTy=*/"bool",
+ /*methodName=*/"isInitTensor",
+ /*args=*/(ins "OpOperand *":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (!opOperand->get().getType().template isa<RankedTensorType>())
+ return false;
+ if (opOperand->getOperandNumber() < $_op.getNumInputs())
+ return false;
+ return payloadUsesValueFromOpOperand(opOperand);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if the operand at output index `index` is an init tensor.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isIndexOfInitTensor",
+ /*args=*/(ins "unsigned":$index),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(index < getNumOutputs());
+ return isInitTensor(
+ &this->getOperation()->getOpOperands()[$_op.getNumInputs() + index]);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the output operands that are init tensors.
+ }],
+ /*retTy=*/"SmallVector<Value, 4>",
+ /*methodName=*/"getInitTensors",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<Type, 4> inputOutputTypes(
- this->getOperation()->operand_type_begin(),
- this->getOperation()->operand_type_end());
- inputOutputTypes.append(this->getOperation()->result_type_begin(),
- this->getOperation()->result_type_end());
+ auto start =
+ this->getOperation()->getOpOperands().begin() + $_op.getNumInputs();
return llvm::to_vector<4>(
- llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
- return type.cast<ShapedType>();
- }));
+ llvm::map_range(
+ llvm::make_filter_range(
+ llvm::make_range(start, start + $_op.getNumOutputs()),
+ [&](OpOperand &opOperand) {
+ return $_op.isInitTensor(&opOperand);
+ }),
+ [&](OpOperand &opOperand) {
+ return opOperand.get();
+ }));
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the first position of the shaped operand in the operand list.
+ Return the number of init tensor operands.
}],
- /*retTy=*/"Optional<unsigned>",
- /*methodName=*/"getIndexOfShapedOperand",
- /*args=*/(ins "Value":$value),
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumInitTensors",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getInitTensors().size();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of input and output operands.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumShapedOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- Optional<unsigned> inputIndex = getIndexOfInput(value);
- if (inputIndex.hasValue()) return inputIndex.getValue();
- Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
- if (outputIndex.hasValue())
- return $_op.getNumInputs() + outputIndex.getValue();
- Optional<unsigned> initTensorIndex = getIndexOfInitTensor(value);
- if (initTensorIndex.hasValue())
- return $_op.getNumInputs() + $_op.getNumOutputBuffers() + initTensorIndex.getValue();
- return llvm::None;
+ return $_op.getNumInputs() + $_op.getNumOutputs();
}]
>,
InterfaceMethod<
/*desc=*/[{
- Returns the operand index given the input index. Returns None
- of the input index is invalid.
+ Return the `i`-th shaped operand value.
}],
- /*retTy=*/"Optional<unsigned>",
- /*methodName=*/"getOperandIndexForInputIndex",
- /*args=*/(ins "unsigned":$input_index),
+ /*retTy=*/"Value",
+ /*methodName=*/"getShapedOperand",
+ /*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (input_index >= $_op.getNumInputs())
- return llvm::None;
- return input_index;
+ assert(i < $_op.getNumShapedOperands());
+ return this->getOperation()->getOperand(i);
}]
>,
InterfaceMethod<
/*desc=*/[{
- Returns the operand index given the output index. Returns None
- of the output index is invalid.
+ Return the range over input and output operands.
}],
- /*retTy=*/"Optional<unsigned>",
- /*methodName=*/"getOperandIndexForOutputIndex",
- /*args=*/(ins "unsigned":$output_index),
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getShapedOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (output_index >= $_op.getNumOutputs())
- return llvm::None;
- return output_index + $_op.getNumInputs();
+ auto range = this->getOperation()->getOperands();
+ return {range.begin(), range.begin() + getNumShapedOperands()};
}]
>,
InterfaceMethod<
/*desc=*/[{
- Returns the input index given the operand index. Return None
- if the operand index doesnt corresponding to an input.
+ Return the OpOperands for all the shaped operands.
}],
- /*retTy=*/"Optional<unsigned>",
- /*methodName=*/"getInputIndex",
- /*args=*/(ins "unsigned":$operand_index),
+ /*retTy=*/" MutableArrayRef<OpOperand>",
+ /*methodName=*/"getShapedOpOperands",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (operand_index >= $_op.getNumInputs())
- return llvm::None;
- return operand_index;
+ return this->getOperation()->getOpOperands().take_front(
+ getNumShapedOperands());
}]
>,
InterfaceMethod<
/*desc=*/[{
- Returns the output index given the operand index. Return None
- if the operand index doesnt corresponding to an output.
+ Return the range over input and output operands.
}],
- /*retTy=*/"Optional<unsigned>",
- /*methodName=*/"getOutputIndex",
- /*args=*/(ins "unsigned":$operand_index),
+ /*retTy=*/"SmallVector<ShapedType, 4>",
+ /*methodName=*/"getShapedOperandTypes",
+ /*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (operand_index < $_op.getNumInputs() ||
- operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
- return llvm::None;
- return operand_index - $_op.getNumInputs();
+ return llvm::to_vector<4>(
+ llvm::map_range(
+ getShapedOperands(),
+ [](Value v) { return v.getType().cast<ShapedType>(); }));
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `i`-th shaped type
+ }],
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getShapedType",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getShapedOperand(i).getType().template cast<ShapedType>();
+ }]>,
//===------------------------------------------------------------------===//
// Other interface methods.
@@ -679,7 +797,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i < getNumInputsAndOutputs());
+ assert(i < $_op.getNumShapedOperands());
return getIndexingMaps()[i];
}]
>,
@@ -719,8 +837,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return this->getOperation()->getNumResults() == 0 &&
- llvm::all_of(getInputs(),
- [](Value v) { return v.getType().isa<MemRefType>(); });
+ llvm::all_of(getShapedOperands(), [](Value v) {
+ return v.getType().template isa<MemRefType>(); });
}]
>,
InterfaceMethod<
@@ -732,11 +850,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto isTensorType = [](Value v) {
- return v.getType().isa<RankedTensorType>();
- };
- return llvm::all_of(getInputs(), isTensorType) &&
- llvm::all_of(this->getOperation()->getResults(), isTensorType);
+ return llvm::all_of(getShapedOperands(), [](Value v) {
+ return v.getType().template isa<RankedTensorType>();
+ });
}]
>,
InterfaceMethod<
@@ -748,7 +864,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op->getAttr(getSparseAttrName()).template dyn_cast_or_null<ArrayAttr>() != nullptr;
+ return $_op->getAttr(getSparseAttrName()).
+ template dyn_cast_or_null<ArrayAttr>() != nullptr;
}]
>,
InterfaceMethod<
@@ -871,7 +988,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
];
let extraClassDeclaration = [{
- /// Return the flat list of all operand dimension sizes in the order they
+ /// Return the flat list of all operand dimension sizes in the order they
/// appear in the operands.
SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
@@ -893,7 +1010,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
for (unsigned i = 0; i < nExtraOperands; ++i) {
res.push_back(getOperation()->getOperand(numShapedOperands + i));
assert((res.back().getType().isSignlessIntOrIndexOrFloat()
- || res.back().getType().isa<VectorType>()) &&
+ || res.back().getType().template isa<VectorType>()) &&
"expected scalar or vector type");
}
return res;
@@ -904,7 +1021,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//========================================================================//
void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
- void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }
private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
@@ -916,6 +1032,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
getOperation()->setAttr("operand_segment_sizes", newAttr);
}
}];
+
+ let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
}
#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
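
As a side note on the new surface API: the interface methods added above expose inputs and outputs as one contiguous range of shaped operands, with one indexing map per shaped operand. The following is a minimal C++ sketch (not part of this patch) of how a pass could query that API; the function name dumpShapedOperandInfo and the FuncOp parameter are illustrative assumptions.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Walk every structured op in a function and print, for each shaped operand,
// its type and the indexing map that governs it.
static void dumpShapedOperandInfo(FuncOp funcOp) {
  funcOp.walk([](linalg::LinalgOp linalgOp) {
    for (auto en : llvm::enumerate(linalgOp.getShapedOperandTypes()))
      llvm::errs() << "shaped operand #" << en.index() << " : " << en.value()
                   << " indexed by " << linalgOp.getIndexingMap(en.index())
                   << "\n";
    // Buffer vs tensor semantics are now both derived from the same range.
    llvm::errs() << "buffer semantics: " << linalgOp.hasBufferSemantics()
                 << ", tensor semantics: " << linalgOp.hasTensorSemantics()
                 << "\n";
  });
}
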
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
deleted file mode 100644
index adfa6a6f1af9..000000000000
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ /dev/null
@@ -1,166 +0,0 @@
-//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
-#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
-
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace OpTrait {
-namespace linalg {
-
-/// This class provides the API for ops that are known to have a specified
-/// number of inputs, all passed as operands. Use as a trait as follows:
-///
-/// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
-///
-template <unsigned N> class NInputs {
-public:
- template <typename ConcreteType>
- class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
- public:
- static unsigned getNumInputs() { return N; }
- };
-};
-
-/// This class provides the API for ops that are known to not have init tensor
-/// operands. Use as a trait as follows:
-///
-/// class CopyOp : public Op<CopyOp, OpTrait::ZeroInitTensors> {
-///
-template <typename ConcreteType>
-class ZeroInitTensors : public TraitBase<ConcreteType, ZeroInitTensors> {
-public:
- static unsigned getNumInitTensors() { return 0; }
-};
-
-/// This class provides the API for ops that are known to have a specified
-/// number of outputs, all passed as operands. Use as a trait as follows:
-///
-/// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
-///
-template <unsigned N> class NOutputs {
-public:
- template <typename ConcreteType>
- class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
- public:
- static unsigned getNumOutputs() { return N; }
- };
-};
-
-/// This class provides a verifier for structured ops that are known to operate
-/// on buffers or tensors. This trait must be used in conjunction with an op
-/// definition or a trait that provides the methods `getNumInputs` and
-/// `getNumOutputs`. Use as a trait as follows:
-///
-/// class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
-///
-template <typename ConcreteType>
-class StructuredOpTraits
- : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
-public:
- static LogicalResult verifyTrait(Operation *op) {
- ConcreteType concreteOp = cast<ConcreteType>(op);
- auto nOperands = concreteOp.getNumInputsAndOutputBuffers();
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
- return failure();
- if (op->getNumResults() > concreteOp.getNumOutputs())
- return op->emitError("unexpected #results > #outputs");
- return success();
- }
-};
-
-/// This class provides a verifier for structured ops that are known to operate
-/// on buffers or tensors and that support `ins`, `outs` and `init` arguments.
-/// This trait must be used in conjunction with an op definition or a trait that
-/// provides the methods `getNumInputs` and `getNumOutputs`.
-///
-/// Use as a trait as follows:
-///
-/// class MatmulOp : public Op<MatmulOp, OpTrait::NamedStructuredOpTrait> {
-///
-template <typename ConcreteType>
-class NamedStructuredOpTrait
- : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTrait> {
-public:
- unsigned getNumInputs() {
- return cast<ConcreteType>(this->getOperation()).inputs().size();
- }
- unsigned getNumInitTensors() {
- return cast<ConcreteType>(this->getOperation()).init_tensors().size();
- }
- unsigned getNumOutputs() {
- ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
- return concreteOp.output_buffers().size() +
- concreteOp.result_tensors().size();
- }
- static LogicalResult verifyTrait(Operation *op) {
- ConcreteType concreteOp = cast<ConcreteType>(op);
- unsigned nInputAndBufferOperands =
- concreteOp.getNumInputsAndOutputBuffers();
- if (failed(
- OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
- return failure();
-
- SmallVector<AffineExpr, 4> redDims;
- concreteOp.getReductionDims(redDims);
- // If no result and no reduction, only check there is no init tensor and we
- // are done.
- if (redDims.empty() || op->getNumResults() == 0) {
- if (!concreteOp.init_tensors().empty())
- return op->emitError("expected empty `init` when op has no "
- "results or no reduction dims");
- return success();
- }
-
- // Only a single tensor result supported atm.
- if (op->getNumResults() != 1)
- return op->emitError(
- "expected single tensor result when reduction present");
-
- if (concreteOp.init_tensors().size() != op->getNumResults())
- return op->emitError(
- "expected #init tensors to match #results when reduction present");
-
- for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx)
- if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx])
- return op->emitError("expected init tensor #")
- << idx << " of the same type as result #" << idx;
-
- // Output tensor indexing map may not depend on reduction index.
- // TODO: this is not yet tested. Add a test when linalg.generic switches to
- // this representation.
- for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) {
- AffineMap outputMap = concreteOp.getOutputIndexingMap(idx);
- for (auto expr : outputMap.getResults()) {
- for (auto dim : redDims) {
- unsigned pos = dim.cast<AffineDimExpr>().getPosition();
- if (expr.isFunctionOfDim(pos))
- return op->emitError(
- "unexpected single tensor output indexing map ")
- << "is function of reduction dim @" << pos;
- }
- }
- }
-
- return success();
- }
-};
-
-} // namespace linalg
-} // namespace OpTrait
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_
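
With this header gone, the checks that previously lived in StructuredOpTraits and NamedStructuredOpTrait are centralized in the interface verifier wired up through `let verify` above. Below is a minimal sketch of invoking that verifier directly from a transform, assuming `detail::verifyStructuredOpInterface` is declared alongside the interface (its definition appears in LinalgOps.cpp further down in this diff).

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"

// Re-verify a structured op after an in-place rewrite. Emits diagnostics and
// returns failure if the op no longer satisfies the structured-op invariants.
static mlir::LogicalResult recheckStructuredOp(mlir::linalg::LinalgOp op) {
  return mlir::linalg::detail::verifyStructuredOpInterface(op.getOperation());
}
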
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 552ac75bfee5..0f060b2b1a0a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -673,6 +673,11 @@ class AnyStridedMemRefOfRank<int rank> :
MemRefRankOf<[AnyType], [rank]>.predicate]>,
AnyStridedMemRef.description # " of rank " # rank>;
+class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
+ Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+ StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
+ MemRefOf<allowedTypes>.description>;
+
// This represents a generic tuple without any constraints on element type.
def AnyTuple : Type<IsTupleTypePred, "tuple">;
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
index 38d97332f0d7..9e4b9f39f7fb 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -22,7 +22,7 @@ func @main() {
%C = constant dense<1000.0> : tensor<2x4xf32>
%D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
- init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+ outs(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
%unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index ca2d16e8de86..1042930b1ef7 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -113,15 +114,16 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
}
void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
- LinalgOpView indexingOpView,
- LinalgOpView dependentOpView) {
+ OpOperand *indexingOpView,
+ OpOperand *dependentOpView) {
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
- << *indexingOpView.op << ", " << indexingOpView.operandIndex
- << ") -> \n\t\t(" << *dependentOpView.op << ", "
- << dependentOpView.operandIndex << ")");
- dependencesFromGraphs[dt][indexingOpView.op].push_back(
+ << indexingOpView->get() << " @"
+ << indexingOpView->getOperandNumber() << ") -> \n\t\t("
+ << dependentOpView->get() << " @"
+ << dependentOpView->getOperandNumber() << ")");
+ dependencesFromGraphs[dt][indexingOpView->getOwner()].push_back(
LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
- dependencesIntoGraphs[dt][dependentOpView.op].push_back(
+ dependencesIntoGraphs[dt][dependentOpView->getOwner()].push_back(
LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt});
}
@@ -156,57 +158,25 @@ LinalgDependenceGraph::getDependencesInto(
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
- for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
- unsigned srcIndex =
- src.getOperandIndexForOutputIndex(srcView.index()).getValue();
+ for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W
// RAW graph
- for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
- if (aliases.alias(srcView.value(),
- dstView.value())) { // if alias, fill RAW
- unsigned dstIndex =
- dst.getOperandIndexForInputIndex(dstView.index()).getValue();
- addDependenceElem(DependenceType::RAW,
- LinalgOpView{src.getOperation(), srcIndex},
- LinalgOpView{dst.getOperation(), dstIndex});
- }
- }
+ for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R
+ if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
+ addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
// WAW graph
- for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
- if (aliases.alias(srcView.value(),
- dstView.value())) { // if alias, fill WAW
- unsigned dstIndex =
- dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
- addDependenceElem(DependenceType::WAW,
- LinalgOpView{src.getOperation(), srcIndex},
- LinalgOpView{dst.getOperation(), dstIndex});
- }
- }
+ for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W
+ if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
+ addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
}
- for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
- unsigned srcIndex =
- src.getOperandIndexForInputIndex(srcView.index()).getValue();
+ for (OpOperand *srcOpOperand : src.getInputBuffersOpOperands()) { // R
// RAR graph
- for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
- if (aliases.alias(srcView.value(),
- dstView.value())) { // if alias, fill RAR
- unsigned dstIndex =
- dst.getOperandIndexForInputIndex(dstView.index()).getValue();
- addDependenceElem(DependenceType::RAR,
- LinalgOpView{src.getOperation(), srcIndex},
- LinalgOpView{dst.getOperation(), dstIndex});
- }
- }
+ for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R
+ if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
+ addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
// WAR graph
- for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
- if (aliases.alias(srcView.value(),
- dstView.value())) { // if alias, fill WAR
- unsigned dstIndex =
- dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
- addDependenceElem(DependenceType::WAR,
- LinalgOpView{src.getOperation(), srcIndex},
- LinalgOpView{dst.getOperation(), dstIndex});
- }
- }
+ for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W
+ if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
+ addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
}
}
@@ -248,17 +218,15 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
// TODO: we are not considering paths yet, just interleaved positions.
for (auto dt : types) {
for (auto dependence : getDependencesFrom(src, dt)) {
- auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
+ auto interimPos =
+ linalgOpPositions.lookup(dependence.dependentOpView->getOwner());
// Skip if not interleaved.
if (interimPos >= dstPos || interimPos <= srcPos)
continue;
- linalg::LinalgOp consumer =
- cast<linalg::LinalgOp>(dependence.indexingOpView.op);
- Value consumerView =
- consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
+ Value consumerView = dependence.indexingOpView->get();
if (view && !aliases.alias(view, consumerView))
continue;
- auto *op = dependence.dependentOpView.op;
+ auto *op = dependence.dependentOpView->getOwner();
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
<< *op << " on " << consumerView);
@@ -271,12 +239,10 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
bool LinalgDependenceGraph::hasDependenceFrom(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
- for (auto dep : depTypes) {
- for (auto dependence : getDependencesInto(dstLinalgOp, dep)) {
- if (dependence.dependentOpView.op == srcLinalgOp)
+ for (auto dep : depTypes)
+ for (auto dependence : getDependencesInto(dstLinalgOp, dep))
+ if (dependence.dependentOpView->getOwner() == srcLinalgOp)
return true;
- }
- }
return false;
}
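
For clients of the analysis, both ends of a dependence are now OpOperand pointers instead of (op, operand index) pairs, so the owning op and operand number come straight from the operand. A minimal sketch (not part of the patch), assuming a LinalgDependenceGraph has already been built:

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

// Print all RAW dependences rooted at `producer`. Per addDependenceElem above,
// indexingOpView is the operand on the producer side and dependentOpView the
// operand on the dependent op's side.
static void printRAWDependences(LinalgDependenceGraph &graph,
                                LinalgOp producer) {
  for (auto dependence : graph.getDependencesFrom(
           producer, LinalgDependenceGraph::DependenceType::RAW)) {
    OpOperand *from = dependence.indexingOpView;
    OpOperand *to = dependence.dependentOpView;
    llvm::errs() << "RAW: " << *from->getOwner() << " @"
                 << from->getOperandNumber() << " -> " << *to->getOwner()
                 << " @" << to->getOperandNumber() << "\n";
  }
}
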
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 0ae1efe10b7f..3c3b2777d6c1 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -23,36 +23,25 @@ using namespace mlir::scf;
Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
- ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
- ArrayRef<StructuredIndexed> resultTensorTypes,
+ ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) {
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
// Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
- exprsList.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
- for (auto container : {inputs, outputBuffers, resultTensorTypes})
+ exprsList.reserve(inputs.size() + outputs.size());
+
+ for (auto container : {inputs, outputs})
for (const StructuredIndexed &s : container)
exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
auto maps = AffineMap::inferFromExprList(exprsList);
- SmallVector<Type, 4> types;
- assert(llvm::all_of(resultTensorTypes, [](const StructuredIndexed &s) {
- return !s.hasValue();
- }));
- std::copy(resultTensorTypes.begin(), resultTensorTypes.end(),
- std::back_inserter(types));
-
- SmallVector<Value, 4> inputValues, outputBufferValues, initTensorValues;
+ SmallVector<Value, 4> inputValues, outputValues;
inputValues.reserve(inputs.size());
- outputBufferValues.reserve(outputBuffers.size());
- initTensorValues.reserve(initTensors.size());
+ outputValues.reserve(outputs.size());
std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
- std::copy(outputBuffers.begin(), outputBuffers.end(),
- std::back_inserter(outputBufferValues));
- std::copy(initTensors.begin(), initTensors.end(),
- std::back_inserter(initTensorValues));
+ std::copy(outputs.begin(), outputs.end(), std::back_inserter(outputValues));
auto iteratorStrTypes =
llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
@@ -61,10 +50,9 @@ Operation *mlir::edsc::makeGenericLinalgOp(
edsc::ScopedContext::getBuilderRef()
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
- types,
+ resultTensorTypes,
inputValues,
- outputBufferValues,
- initTensorValues,
+ outputValues,
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
@@ -77,12 +65,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
using namespace edsc;
SmallVector<Type, 4> blockTypes;
- blockTypes.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
- for (auto container : {inputs, outputBuffers})
+ blockTypes.reserve(inputs.size() + outputs.size());
+ for (auto container : {inputs, outputs})
for (const StructuredIndexed &s : container)
blockTypes.push_back(getElementTypeOrSelf(s.getType()));
- for (Value v : initTensors)
- blockTypes.push_back(getElementTypeOrSelf(v.getType()));
assert(op->getNumRegions() == 1);
assert(op->getRegion(0).empty());
@@ -119,11 +105,10 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
linalg_yield(unaryOp(a));
};
if (O.getType().isa<RankedTensorType>())
- return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{},
- /*initTensors=*/{}, /*resultTensorTypes=*/{O},
- fun);
- return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{O},
- /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
+ /*resultTensorTypes=*/{O}, fun);
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
+ /*resultTensorTypes=*/{}, fun);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
@@ -144,12 +129,10 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
linalg_yield(binaryOp(a, b));
};
if (O.getType().isa<RankedTensorType>())
- return makeGenericLinalgOp(
- iterTypes, /*inputs=*/{I1, I2}, /*outputBuffers=*/{},
- /*initTensors=*/{}, /*resultTensorTypes=*/{O}, fun);
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, /*outputs=*/{O},
+ /*resultTensorTypes=*/{O}, fun);
return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
- /*outputBuffers=*/{O},
- /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
+ /*outputs=*/{O}, /*resultTensorTypes=*/{}, fun);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
@@ -181,8 +164,7 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
/*inputs=*/{A({m, k}), B({k, n})},
- /*outputBuffers=*/{C({m, n})},
- /*initTensors=*/{},
+ /*outputs=*/{C({m, n})},
/*resultTensorTypes=*/{},
regionBuilder);
// clang-format on
@@ -199,8 +181,7 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
/*inputs=*/{A({m, k}), B({k, n})},
- /*outputBuffers=*/{},
- /*initTensors=*/{C({m, n})},
+ /*outputs=*/{C({m, n})},
/*resultTensorTypes=*/{D({m, n})},
regionBuilder);
// clang-format on
@@ -236,8 +217,7 @@ Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
W({kh, kw, c, f}) },
- /*outputBuffers=*/{ O({b, h, w, f}) },
- /*initTensors=*/{},
+ /*outputs=*/{ O({b, h, w, f}) },
/*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
@@ -272,9 +252,8 @@ Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
W({kh, kw, c, dm})},
- /*outputBuffers=*/{
+ /*outputs=*/{
O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
- /*initTensors=*/{},
/*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3a7249df8e79..bcbd6d903612 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -88,22 +88,20 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
/// Forward declarations.
template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributes(
- OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
- TypeRange outputBufferTypes, TypeRange initTensorTypes,
- TypeRange resultTypes);
+static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputTypes);
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
- SmallVectorImpl<Type> &outputBufferTypes,
- SmallVectorImpl<Type> &initTensorTypes);
+ SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputBufferTypes,
- TypeRange initTensorTypes, TypeRange resultTypes);
+ TypeRange inputTypes, TypeRange outputTypes);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
@@ -122,9 +120,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
-template <typename NamedStructuredOpType>
-static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
-
/// This is a common class used for patterns of the form
/// ```
/// someop(memrefcast) -> someop
@@ -152,11 +147,10 @@ static LogicalResult foldMemRefCast(Operation *op) {
//===----------------------------------------------------------------------===//
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
- ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
- StringRef doc, StringRef libraryCall,
+ ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
+ build(builder, result, resultTensorTypes, inputs, outputs,
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
@@ -166,7 +160,7 @@ void GenericOp::build(
return;
SmallVector<Type, 4> blockArgTypes;
- for (ValueRange container : {inputs, outputBuffers, initTensors})
+ for (ValueRange container : {inputs, outputs})
for (Value v : container)
blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
@@ -178,41 +172,40 @@ void GenericOp::build(
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
- indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
+ build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
+ iteratorTypes, doc, libraryCall, bodyBuild);
}
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
+ build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
- ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
- indexingMaps, iteratorTypes,
+ build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
+ iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
- ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
- StringRef doc, StringRef libraryCall,
+ ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
+ build(builder, result, resultTensorTypes, inputs, outputs,
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
@@ -223,7 +216,7 @@ void IndexedGenericOp::build(
unsigned nLoops = iteratorTypes.size();
SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
- for (ValueRange container : {inputs, outputBuffers, initTensors})
+ for (ValueRange container : {inputs, outputs})
for (Value v : container)
blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
@@ -237,32 +230,32 @@ void IndexedGenericOp::build(
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
- build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
- indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
+ build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
+ iteratorTypes, doc, libraryCall, bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
- build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
+ build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
/*doc=*/"", /*libraryCall=*/"", bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
- ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
- indexingMaps, iteratorTypes,
+ build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
+ iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
@@ -327,9 +320,8 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
dictAttr.getValue().end());
// Parsing is shared with named ops, except for the region.
- SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
- if (parseCommonStructuredOpParts(parser, result, inputTypes,
- outputBufferTypes, initTensorTypes))
+ SmallVector<Type, 1> inputTypes, outputTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
// Optional attributes may be added.
@@ -360,7 +352,7 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
static void getGenericEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
- ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
+ ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
for (Value value : results) {
effects.emplace_back(MemoryEffects::Allocate::get(), value,
SideEffects::DefaultResource::get());
@@ -369,7 +361,7 @@ static void getGenericEffectsImpl(
effects.emplace_back(MemoryEffects::Read::get(), value,
SideEffects::DefaultResource::get());
}
- for (Value value : outputBuffers) {
+ for (Value value : outputs) {
effects.emplace_back(MemoryEffects::Read::get(), value,
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), value,
@@ -391,65 +383,150 @@ void IndexedGenericOp::getEffects(
getInputBuffers(), getOutputBuffers());
}
-namespace {
+LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ // Expect at least one shaped operand.
+ // This means an op that constructs a tensor out of indices cannot be a
+ // LinalgOp at the moment. For now this will have to be a special op until we
+ // have output shape operands that are not tensors.
+ auto nShapedOperands = linalgOp.getNumShapedOperands();
+ if (nShapedOperands == 0)
+ return linalgOp.emitOpError("expected at least 1 Shaped operand");
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands)))
+ return failure();
+ // Should have at least one output tensor per result tensor.
+ // Can also have output buffers that do not correspond to results.
+ if (op->getNumResults() > linalgOp.getNumOutputTensors())
+ return op->emitError("unexpected #results > #outputs");
+
+ // All shaped operands must be indexed.
+ if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
+ return linalgOp.emitOpError("expected the number of indexing_map (")
+ << linalgOp.indexing_maps().size()
+ << ") to be equal to the number of shaped operands ("
+ << linalgOp.getNumShapedOperands() << ")";
-template <typename GenericOpType>
-struct BlockArgsVerifier {
- static LogicalResult verify(GenericOpType op, Block &block);
-};
+ SmallVector<AffineMap, 4> indexingMaps;
+ indexingMaps.reserve(linalgOp.indexing_maps().size());
+ for (auto en : llvm::enumerate(linalgOp.indexing_maps())) {
+ auto idx = en.index();
+ auto m = en.value().template cast<AffineMapAttr>().getValue();
+ indexingMaps.push_back(m); // Save reference to map for further checks.
+ auto shapedValue = linalgOp.getShapedType(idx);
-template <typename GenericOpType>
-LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op,
- Block &block) {
- auto nOperands = op.getNumOperands();
- if (block.getNumArguments() != nOperands)
- return op.emitOpError("expected number of block arguments to match number "
- "of operands");
+ // Symbols disallowed.
+ if (m.getNumSymbols() != 0)
+ return linalgOp.emitOpError("unexpected symbols in indexing_map #")
+ << idx;
- // Note: the number and type of yield values are checked in the YieldOp.
- auto nInputViews = op.getNumInputs();
- for (unsigned i = 0; i < nOperands; ++i) {
- auto viewType = op.getShapedType(i);
- if (viewType.getElementType() != block.getArgument(i).getType())
- return op.emitOpError("expected block argument ")
- << (i + 1) << " of the same type as elemental type of "
- << ((i < nInputViews) ? "input " : "output ")
- << "operand: " << viewType;
+ // Domain must be consistent.
+ auto nLoops = linalgOp.getNumLoops();
+ if (m.getNumDims() != nLoops)
+ return linalgOp.emitOpError("expected indexing_map #")
+ << idx << " to have " << nLoops
+ << " dim(s) to match the number of loops";
+
+ if (m.getNumResults() != shapedValue.getRank())
+ return linalgOp.emitOpError("expected shaped value rank (")
+ << shapedValue.getRank()
+ << ") to match the result rank of indexing_map #" << idx << " ("
+ << m.getNumResults() << ")";
}
- return success();
-}
-template <>
-LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
- Block &block) {
- auto nInputViews = op.getNumInputs();
- auto nLoops = op.getNumLoops();
- auto nOperands = op.getNumOperands();
- if (block.getNumArguments() != nOperands + nLoops)
- return op.emitOpError(
- "expected number of block arguments to match number of operands + "
- "number of loops");
+ SmallVector<AffineExpr, 4> redDims;
+ linalgOp.getReductionDims(redDims);
+
+ // Simplifying assumption: either full tensor or full buffer mode.
+ // This allows simpler verification of output operands vs result types
+ // without premature tracking of which operand is what in mixed-mode.
+ // TODO: relax when mixed-mode needs to pass verification.
+ if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0)
+ return op->emitError("expected output operands to all have tensor type or "
+ "all have buffer type");
+
+ for (auto it :
+ llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) {
+ if (!std::get<0>(it).get().getType().isa<RankedTensorType>())
+ continue;
+ if (std::get<0>(it).get().getType() != std::get<1>(it))
+ return op->emitError("expected type of operand #")
+ << std::get<0>(it).getOperandNumber() << " ("
+ << std::get<0>(it).get().getType() << ")"
+ << " to match type of corresponding result (" << std::get<1>(it)
+ << ")";
+ }
+
+ // Output tensor indexing map may not depend on reduction indices.
+ for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) {
+ AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber());
+ for (auto expr : outputMap.getResults()) {
+ for (auto dim : redDims) {
+ unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+ if (expr.isFunctionOfDim(pos)) {
+ std::string exprStr;
+ {
+ llvm::raw_string_ostream os(exprStr);
+ os << expr;
+ }
+ return op->emitError(
+ "unexpected output tensor expression in indexing map #")
+ << (opOperand.getOperandNumber() - linalgOp.getNumInputs())
+ << " a.k.a '" << exprStr
+ << "' is function of reduction iterator 'd" << pos << "'";
+ }
+ }
+ }
+ }
+
+ // Named ops that are defined manually have a region builder but no region at
+ // this time. Assume the region is well-formed by specification.
+ // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
+ if (linalgOp->getNumRegions() == 0) {
+ assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
+ return success();
+ }
+
+ auto ®ion = linalgOp->getRegion(0);
+ if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
+ return op->emitOpError("expected 1 region with 1 block");
+
+ if (!linalgOp.getShapesToLoopsMap())
+ return op->emitOpError("expected the shape-to-loops map to be non-null");
+
+ // Simplifying assumption: bbargs match 1-1 with shape operands elemental
+ // types.
+ // TODO: once ranked shape types are plugged in, we may want to drop the
+ // corresponding bbargs, that can never be read from. This will be subject to
+ // consistency discussions (i.e. what to do with output tensors whose bbarg is
+ // not used).
+ Block &block = linalgOp->getRegion(0).front();
+ unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables();
+
+ if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments())
+ return op->emitError("expected as many non-induction variable region "
+ "arguments as the number of shaped operands");
// Note: the number and type of yield values are checked in the YieldOp.
- for (unsigned i = 0; i < nLoops; ++i)
+ for (unsigned i = 0; i < numBBIvs; ++i)
if (!block.getArgument(i).getType().isIndex())
- return op.emitOpError("expected block argument ")
- << (i + 1) << " to be an index";
-
- for (unsigned i = 0; i < nOperands; ++i) {
- unsigned memrefArgIndex = i + nLoops;
- auto viewType = op.getShapedType(i);
- if (viewType.getElementType() !=
- block.getArgument(memrefArgIndex).getType())
- return op.emitOpError("expected block argument ")
- << (memrefArgIndex + 1)
- << " of the same type as elemental type of "
- << ((i < nInputViews) ? "input " : "output ")
- << "operand: " << viewType;
+ return op->emitOpError("expected index block argument #") << i;
+
+ unsigned idx = 0;
+ for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(),
+ block.getArguments().drop_front(numBBIvs))) {
+ if (std::get<0>(it).getElementType() != std::get<1>(it).getType())
+ return op->emitError("expected type of bb argument #")
+ << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")"
+ << " to match element type of corresponding shaped operand ("
+ << std::get<0>(it).getElementType() << ")";
+ ++idx;
}
+
return success();
}
+namespace {
+
template <typename GenericOpType>
struct AnnotationsVerifier {
static LogicalResult verify(GenericOpType op) { return success(); }
@@ -465,7 +542,7 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
return op.emitOpError("expected sparse annotations on tensors only");
if (op.getNumOutputs() != 1)
return op.emitOpError("expected single output tensor");
- unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numTensors = op.getNumShapedOperands();
if (sparseAttr.size() != numTensors)
return op.emitOpError("expected one sparse annotation for each tensor");
for (unsigned t = 0; t < numTensors; t++) {
@@ -497,49 +574,6 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
- auto nLoops = op.getNumLoops();
-
- if (op.inputs().size() + op.output_buffers().size() +
- op.init_tensors().size() + op.getNumResults() ==
- 0)
- return op.emitOpError("expected at least 1 Shaped operand or return");
-
- auto ®ion = op.region();
- if (!llvm::hasSingleElement(region))
- return op.emitOpError("expected region with 1 block");
- if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
- return failure();
-
- if (op.indexing_maps().size() != op.getNumInputsAndOutputs())
- return op.emitOpError("expected the number of indexing_map (")
- << op.indexing_maps().size()
- << ") to be equal to the number of inputs and outputs ("
- << op.getNumInputsAndOutputs() << ")";
-
- SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.reserve(op.indexing_maps().size());
- for (auto en : llvm::enumerate(op.indexing_maps())) {
- auto idx = en.index();
- auto m = en.value().template cast<AffineMapAttr>().getValue();
- indexingMaps.push_back(m); // Save reference to map for further checks.
- auto view = op.getShapedType(idx);
-
- if (m.getNumSymbols() != 0)
- return op.emitOpError("unexpected symbols in indexing_map #") << idx;
-
- if (m.getNumDims() != nLoops)
- return op.emitOpError("expected indexing_map #")
- << idx << " to have " << nLoops
- << " dim(s) to match the number of loops";
-
- if (m.getNumResults() != view.getRank())
- return op.emitOpError("expected indexing_map #")
- << idx << " results to match view rank: " << view;
- }
-
- if (!op.getShapesToLoopsMap())
- return op.emitOpError("expected the shape-to-loops map to be non-null");
-
if (failed(AnnotationsVerifier<GenericOpType>::verify(op)))
return failure();
@@ -1380,8 +1414,6 @@ static LogicalResult verify(ConvOp op) {
return op.emitOpError("expects memref elemental types to match");
if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
return op.emitOpError("expects memref ranks to match");
- if (oType.getRank() <= 2)
- return op.emitOpError("expects memref ranks to be greater than 2");
if (auto strides = op.strides()) {
if (failed(
verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
@@ -1591,13 +1623,12 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributesImpl(
OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
- TypeRange outputBufferTypes, TypeRange initTensorTypes,
- TypeRange resultTypes,
+ TypeRange outputTypes,
std::function<void(unsigned, unsigned)> errorHandler) {
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
- for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
+ for (auto containers : {inputTypes, outputTypes})
for (auto t : containers)
argTypes.push_back(getElementTypeOrSelf(t));
@@ -1622,13 +1653,11 @@ template <typename NamedStructuredOpType>
void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
- TypeRange outputBufferTypes,
- TypeRange initTensorTypes,
- TypeRange resultTypes) {
+ TypeRange outputTypes) {
Region ®ion = *result.addRegion();
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
- resultTypes, [&](unsigned expected, unsigned actual) {
+ opBuilder, region, inputTypes, outputTypes,
+ [&](unsigned expected, unsigned actual) {
llvm::errs() << "region expects " << expected << " args, got "
<< actual;
assert(expected != actual && "incorrect number of arguments");
@@ -1638,13 +1667,12 @@ void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
- TypeRange inputTypes, TypeRange outputBufferTypes,
- TypeRange initTensorTypes, TypeRange resultTypes) {
+ TypeRange inputTypes, TypeRange outputTypes) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
- resultTypes, [&](unsigned expected, unsigned actual) {
+ opBuilder, region, inputTypes, outputTypes,
+ [&](unsigned expected, unsigned actual) {
res = parser.emitError(parser.getCurrentLocation(),
llvm::formatv("region expects {0} args, got {1}",
expected, actual));
@@ -1664,12 +1692,9 @@ parseNamedStructuredOpResults(OpAsmParser &parser,
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
- SmallVectorImpl<Type> &outputBufferTypes,
- SmallVectorImpl<Type> &initTensorTypes) {
- llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
- initTensorsOperandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
- outputBuffersOperands, initTensorsOperands;
+ SmallVectorImpl<Type> &outputTypes) {
+ llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
+ SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
parser.parseOptionalAttrDict(result.attributes);
@@ -1684,41 +1709,30 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
}
if (succeeded(parser.parseOptionalKeyword("outs"))) {
- outputBuffersOperandsLoc = parser.getCurrentLocation();
- if (parser.parseLParen() ||
- parser.parseOperandList(outputBuffersOperands) ||
- parser.parseColonTypeList(outputBufferTypes) || parser.parseRParen())
- return failure();
- }
- if (succeeded(parser.parseOptionalKeyword("init"))) {
- initTensorsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
- parser.parseColonTypeList(initTensorTypes) || parser.parseRParen())
+ outputsOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
+ parser.parseColonTypeList(outputTypes) || parser.parseRParen())
return failure();
}
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
result.operands) ||
- parser.resolveOperands(outputBuffersOperands, outputBufferTypes,
- outputBuffersOperandsLoc, result.operands) ||
- parser.resolveOperands(initTensorsOperands, initTensorTypes,
- initTensorsOperandsLoc, result.operands))
+ parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
+ result.operands))
return failure();
result.addAttribute("operand_segment_sizes",
parser.getBuilder().getI32VectorAttr(
{static_cast<int32_t>(inputsOperands.size()),
- static_cast<int32_t>(outputBuffersOperands.size()),
- static_cast<int32_t>(initTensorsOperands.size())}));
+ static_cast<int32_t>(outputsOperands.size())}));
return success();
}
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
- SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
- if (parseCommonStructuredOpParts(parser, result, inputTypes,
- outputBufferTypes, initTensorTypes))
+ SmallVector<Type, 1> inputTypes, outputTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
// TODO: consider merging results parsing into region parsing.
@@ -1730,8 +1744,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
- parser, *region, inputTypes, outputBufferTypes, initTensorTypes,
- outputTensorsTypes))
+ parser, *region, inputTypes, outputTypes))
return failure();
result.addRegion(std::move(region));
@@ -1750,12 +1763,8 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op) {
if (!op.inputs().empty())
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
- if (!op.output_buffers().empty())
- p << " outs(" << op.output_buffers() << " : "
- << op.output_buffers().getTypes() << ")";
- if (!op.init_tensors().empty())
- p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
- << ") ";
+ if (!op.outputs().empty())
+ p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}
template <typename NamedStructuredOpType>
@@ -1789,7 +1798,7 @@ struct EraseDeadLinalgOp : public RewritePattern {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
- for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+ for (Value v : linalgOp.getShapedOperands()) {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
@@ -1836,11 +1845,8 @@ struct FoldTensorCastOp : public RewritePattern {
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
}
- // Output buffers are memrefs, they don't fold.
- newOperands.append(linalgOp.getOutputBuffers().begin(),
- linalgOp.getOutputBuffers().end());
// Init tensors may fold, in which case the resultType must also change.
- for (Value v : linalgOp.getInitTensors()) {
+ for (Value v : linalgOp.getOutputs()) {
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
@@ -1904,8 +1910,7 @@ struct DeduplicateInputs : public RewritePattern {
for (auto v : llvm::enumerate(linalgOp.getInputs()))
if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
newOperands.push_back(v.value());
- llvm::append_range(newOperands, linalgOp.getOutputBuffers());
- llvm::append_range(newOperands, linalgOp.getInitTensors());
+ llvm::append_range(newOperands, linalgOp.getOutputs());
llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());
// Clone the old op with new operands.
@@ -1929,11 +1934,8 @@ struct DeduplicateInputs : public RewritePattern {
newLinalgOp.setNumInputs(canonicalInput.size());
// linalg.indexed_generic payloads have additional arguments prepended to
- // the block arg list. The number of such args is one per dimension of the
- // iteration space.
- int bbArgBaseOffset = 0;
- if (isa<IndexedGenericOp>(op))
- bbArgBaseOffset = newIndexingMaps[0].getNumInputs();
+ // the block arg list.
+ int bbArgBaseOffset = newLinalgOp.getNumPayloadInductionVariables();
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
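
To make the builder change concrete, here is a minimal sketch (not part of the patch) of creating a linalg.generic on tensors with the unified `outputs` operand list; `builder`, `loc`, the values `lhs`, `rhs`, `acc` and the `indexingMaps`/`iteratorTypes` arrays are assumed to be supplied by the caller.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

// Elementwise multiply-accumulate on tensors. The output tensor `acc` both
// provides the result type and seeds the accumulation; there is no separate
// init_tensors operand list anymore.
static Value createTensorMulAdd(OpBuilder &builder, Location loc, Value lhs,
                                Value rhs, Value acc,
                                ArrayRef<AffineMap> indexingMaps,
                                ArrayRef<StringRef> iteratorTypes) {
  SmallVector<Value, 2> inputs{lhs, rhs};
  SmallVector<Value, 1> outputs{acc};
  SmallVector<Type, 1> resultTypes{acc.getType()};
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, /*resultTensorTypes=*/TypeRange(resultTypes),
      /*inputs=*/ValueRange(inputs), /*outputs=*/ValueRange(outputs),
      indexingMaps, iteratorTypes,
      [](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // args holds the elemental values of (inputs..., outputs...).
        Value mul = b.create<MulFOp>(nestedLoc, args[0], args[1]);
        Value sum = b.create<AddFOp>(nestedLoc, mul, args[2]);
        b.create<linalg::YieldOp>(nestedLoc, sum);
      });
  return genericOp->getResult(0);
}
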
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b36d74bad3fb..a3ab6f45b26e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -21,21 +21,22 @@
using namespace ::mlir;
using namespace ::mlir::linalg;
-static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
- if (val.getType().isIndex())
- return val;
- return b.create<IndexCastOp>(loc, val, b.getIndexType());
-}
-
-static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
- auto memrefType = memref.getType().cast<MemRefType>();
+static SmallVector<Value, 4> getDynOperands(Location loc, Value val,
+ OpBuilder &b) {
SmallVector<Value, 4> dynOperands;
- for (auto dim : llvm::enumerate(memrefType.getShape())) {
+ auto shapedType = val.getType().cast<ShapedType>();
+ for (auto dim : llvm::enumerate(shapedType.getShape())) {
if (dim.value() == TensorType::kDynamicSize) {
- dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
+ dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
}
}
- auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+ return dynOperands;
+}
+
+static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
+ auto memrefType = memref.getType().cast<MemRefType>();
+ auto alloc =
+ b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
b.create<linalg::CopyOp>(loc, memref, alloc);
return alloc;
}
@@ -48,6 +49,7 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
SmallVector<Range, 4> loopRanges;
// Allocate a buffer for every tensor result.
+ assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
size_t resultIndex = en.index();
Type resultType = en.value();
@@ -60,46 +62,26 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
}
auto tensorShape = tensorType.getShape();
auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
+ Value resultTensor = adaptor.outputs()[resultIndex];
- // Allocate buffers for init tensors that are assumed to fold onto the first
- // results.
- // TODO: update this assumption because the reality is more complex
- // under linalg on tensor based transformations.
- bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
- if (hasInitTensor) {
- resultBuffers.push_back(
- cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
+ // Clone output buffers whose value is actually used.
+ if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
+ resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
continue;
}
+ if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
+ resultBuffers.push_back(resultTensor);
+ continue;
+ }
// Allocate buffers for statically-shaped results.
if (memrefType.hasStaticShape()) {
resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
continue;
}
- // Perform a naive shape inference for the dynamically-shaped results.
- // Extract the required element out of the vector.
- SmallVector<Value, 4> dynOperands;
- auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
- for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
- if (loopRanges.empty())
- loopRanges = linalgOp.createLoopRanges(b, loc);
- if (shapeElement.value() != ShapedType::kDynamicSize)
- continue;
- AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
- switch (expr.getKind()) {
- case AffineExprKind::DimId: {
- int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
- Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
- dynOperands.push_back(size);
- break;
- }
- default:
- return failure();
- }
- }
- resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
+ resultBuffers.push_back(b.create<AllocOp>(
+ loc, memrefType, getDynOperands(loc, resultTensor, b)));
}
return success();
}
@@ -119,8 +101,7 @@ finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
genericOp.getLoc(),
/*resultTensorTypes=*/llvm::None,
/*inputs=*/inputs,
- /*outputBuffers=*/outputs,
- /*initTensors=*/llvm::None, genericOp.indexing_maps(),
+ /*outputs=*/outputs, genericOp.indexing_maps(),
genericOp.iterator_types(), genericOp.docAttr(),
genericOp.library_callAttr(), genericOp.sparseAttr());
@@ -130,10 +111,6 @@ finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
oldBlock->getArgumentTypes());
- // Add the result arguments to the new block.
- for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
- newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
-
// Clone the body of the old block to the new block.
BlockAndValueMapping mapping;
mapping.map(oldBlock->getArguments(), newBlock->getArguments());
@@ -159,12 +136,8 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
newOperands.append(outputs.begin(), outputs.end());
auto otherOperands = linalgOp.getAssumedNonShapedOperands();
newOperands.append(otherOperands.begin(), otherOperands.end());
- LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
- /*resultTypes=*/ArrayRef<Type>{},
- newOperands));
- // Need to mutate the operands_segment_sizes in the resulting op.
- res.setNumOutputBuffers(outputs.size());
- res.setNumInitTensors(0);
+ linalgOp.clone(rewriter, linalgOp.getLoc(),
+ /*resultTypes=*/ArrayRef<Type>{}, newOperands);
// Replace the results of the old op with the new output buffers.
rewriter.replaceOp(linalgOp, outputs);
}
@@ -174,6 +147,24 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
//===----------------------------------------------------------------------===//
namespace {
+
+/// Conversion pattern that bufferizes `linalg.init_tensor` into an allocation
+/// of the converted memref type, forwarding its dynamic sizes.
+class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
+public:
+ using OpConversionPattern<InitTensorOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
+ rewriter.replaceOpWithNewOp<AllocOp>(
+ op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
+ adaptor.sizes());
+ return success();
+ }
+};
+
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
@@ -190,13 +181,12 @@ class BufferizeAnyLinalgOp : public ConversionPattern {
return failure();
// We abuse the GenericOpAdaptor here.
- // TODO: Manually create an Adaptor that captures inputs, output_buffers and
- // init_tensors for all linalg::LinalgOp interface ops.
+ // TODO: Manually create an Adaptor that captures inputs and outputs for all
+ // linalg::LinalgOp interface ops.
linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
Location loc = linalgOp.getLoc();
- SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
- adaptor.output_buffers().end());
+ SmallVector<Value, 2> newOutputBuffers;
if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
newOutputBuffers, rewriter))) {
@@ -327,7 +317,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
// Mark all Standard operations legal.
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
- target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
+ target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();
// Mark all Linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {
@@ -354,10 +344,11 @@ void mlir::linalg::populateLinalgBufferizePatterns(
OwningRewritePatternList &patterns) {
patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
// TODO: Drop this once tensor constants work in standard.
+ // clang-format off
patterns.insert<
- // clang-format off
+ BufferizeInitTensorOp,
SubTensorOpConverter,
SubTensorInsertOpConverter
- // clang-format on
- >(typeConverter, context);
+ >(typeConverter, context);
+ // clang-format on
}
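
For reference, the BufferizeInitTensorOp pattern added above turns the shape-carrying
init_tensor into a plain allocation. A minimal sketch of the rewrite (SSA names are
illustrative; the full expected IR is checked in the bufferize.mlir test further down):

  %init = linalg.init_tensor [%size] : tensor<?xf32>
  // becomes, after -linalg-bufferize:
  %buf = alloc(%size) : memref<?xf32>
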
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index bf488f827f89..8d09d58b9d7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -189,7 +189,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
if (!invertedMap)
return failure();
SmallVector<int64_t, 4> dims;
- for (ShapedType shapedType : op.getInputOutputShapedTypes())
+ for (ShapedType shapedType : op.getShapedOperandTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
DenseSet<unsigned> unitDims;
ArrayAttr iteratorTypes = op.iterator_types();
@@ -295,7 +295,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
// TODO: support init_tensors and reductions.
- if (!op.hasTensorSemantics() || !op.init_tensors().empty())
+ if (!op.hasTensorSemantics() || op.getNumInitTensors() != 0)
return failure();
MLIRContext *context = rewriter.getContext();
@@ -306,7 +306,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
for (auto it :
- llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
+ llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
auto replacementInfo = replaceUnitExtents(
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
context);
@@ -342,19 +342,16 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
};
SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
- SmallVector<Value, 4> newOutputBuffers =
- insertReshapes(op.output_buffers());
- SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
+ SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());
- // If any result type change, insert a reshape to convert from the original
+ // If any result type changes, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(op.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
- loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
- newIndexingMaps,
+ loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
llvm::to_vector<4>(
op.iterator_types().template getAsValueRange<StringAttr>()));
rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
@@ -364,7 +361,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
// the original shape.
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
- unsigned index = result.index() + replacementOp.getNumOperands();
+ unsigned index = result.index() + replacementOp.getNumInputs();
RankedTensorType origResultType = op.getResult(result.index())
.getType()
.template cast<RankedTensorType>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 8ee1b389dee8..ada9f8c02b89 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -25,6 +25,61 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
[](Type type) { return type.isa<RankedTensorType>(); });
}
+/// Given `op`, assumed to satisfy `isElementwiseMappableOpOnRankedTensors`,
+/// iterate over the result types and return a list of values such that, for
+/// each result type `t` and value `v` at the same index `idx`:
+/// 1. `v.getType() == t`
+/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
+///    such operand. Then `v == operand_first`.
+/// 3. Otherwise, `v` is a newly created `linalg::InitTensorOp` with:
+/// a. Static and dynamic dims extracted from the first operand of `op`.
+/// b. Elemental type equal to the elemental type of `t`.
+///
+/// This is sufficient because ElementwiseMappable guarantees that "The static
+/// types of all vector (resp. tensor) operands and results must have the same
+/// shape".
+static SmallVector<Value, 4>
+getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
+ assert(isElementwiseMappableOpOnRankedTensors(op));
+ Location loc = op->getLoc();
+ ValueRange operands = op->getOperands();
+ TypeRange rankedTensorTypes = op->getResultTypes();
+ SmallVector<Value, 4> res;
+ res.reserve(rankedTensorTypes.size());
+ for (Type t : rankedTensorTypes) {
+ // Try to find an operand with type matching the result tensor.
+ bool found = false;
+ for (Value v : operands) {
+ if (v.getType() == t) {
+ found = true;
+ res.push_back(v);
+ break;
+ }
+ }
+ if (found)
+ continue;
+
+ // Extract static / dynamic shape mix from the first operand.
+ Value firstOperand = operands.front();
+ auto rankedTensorType = t.cast<RankedTensorType>();
+ SmallVector<Value, 8> dynamicShape;
+ SmallVector<int64_t, 8> staticShape;
+ dynamicShape.reserve(rankedTensorType.getRank());
+ staticShape.reserve(rankedTensorType.getRank());
+ unsigned idx = 0;
+ for (auto shape : rankedTensorType.getShape()) {
+ staticShape.push_back(shape);
+ if (rankedTensorType.isDynamicDim(idx))
+ dynamicShape.push_back(b.create<DimOp>(loc, firstOperand, idx));
+ ++idx;
+ }
+ // Create init tensor.
+ res.push_back(b.create<linalg::InitTensorOp>(
+ loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
+ }
+ return res;
+}
+
namespace {
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors()
@@ -41,18 +96,19 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
rewriter.getMultiDimIdentityMap(rank));
SmallVector<StringRef, 6> iteratorTypes(rank,
getParallelIteratorTypeName());
+ auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
- /*outputBuffers=*/ValueRange(),
- /*initTensors=*/ValueRange(),
+ /*outputs=*/outputs,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
OperationState state(loc, op->getName());
state.addAttributes(op->getAttrs());
- state.addOperands(regionArgs);
+ // Only take the input operands in the cloned elementwise op.
+ state.addOperands(regionArgs.take_front(op->getNumOperands()));
auto resultTypes = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return type.cast<TensorType>().getElementType();
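
The net effect of getOrCreateOperandsMatchingResultTypes above is that converted
elementwise ops now carry an outs operand tied to the first type-matching input
(or to a fresh init_tensor when none matches). Roughly, for the rank-1 case
exercised in convert-elementwise-to-linalg.mlir below (block-argument names
illustrative):

  #map = affine_map<(d0) -> (d0)>
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  // is rewritten to:
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
      outs(%arg0 : tensor<?xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      %1 = addf %lhs, %rhs : f32
      linalg.yield %1 : f32
  } -> tensor<?xf32>
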
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d9ea7d8ccb29..b525108d22ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -169,8 +169,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
auto maps = op.indexing_maps();
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
- for (auto en : llvm::enumerate(ios)) {
+ for (auto en : llvm::enumerate(op.getShapedOperands())) {
// The method `getRangeFromOperandShape` requires using SubViewOp or
// SubTensorOps. If the value isn't defined from there, continue.
// TODO: The method should be adapted to get the values from
@@ -381,6 +380,8 @@ static bool isSameSubView(Value a, Value b) {
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
const LinalgDependenceGraph &dependenceGraph) {
+ assert(consumer.hasBufferSemantics() && "revisit usage of shaped operand");
+
// Only consider RAW and WAW atm.
for (auto depType : {
LinalgDependenceGraph::DependenceType::RAW,
@@ -390,26 +391,25 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
dependenceGraph.getDependencesInto(consumer, depType),
[consumerIdx](
LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
- return elem.indexingOpView.operandIndex == consumerIdx;
+ return elem.indexingOpView->getOperandNumber() == consumerIdx;
})) {
- auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
// Check that the dependence is indeed on the input `consumerIdx` view.
- auto consumedView =
- consumer.getBuffer(dependence.indexingOpView.operandIndex);
- if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
+ Value consumedView = dependence.indexingOpView->get();
+ if (!isSameSubView(consumer.getShapedOperand(consumerIdx), consumedView))
continue;
// Consumer consumes this view, `isStructurallyFusableProducer` also
// checks whether it is a strict subview of the producer view.
- auto producedView =
- producer.getBuffer(dependence.dependentOpView.operandIndex);
+ auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
+ Value producedView = dependence.dependentOpView->get();
LLVM_DEBUG(llvm::dbgs()
<< "\n"
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
- << "producer: " << *producer.getOperation()
- << " view: " << producedView << " output index: "
- << dependence.dependentOpView.operandIndex -
+ << "producer: " << *dependence.dependentOpView->getOwner()
+ << " view: " << dependence.dependentOpView->get()
+ << " output index: "
+ << dependence.dependentOpView->getOperandNumber() -
producer.getNumInputs()
<< "\n");
(void)producedView;
@@ -433,13 +433,15 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
if (!fusableDependence)
return {};
- LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ LinalgOp producerOp =
+ cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
// If producer is already in the same block as consumer, we are done.
if (consumer->getBlock() == producerOp->getBlock())
return {};
- unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
- producerOp.getNumInputs();
+ unsigned producerIdx =
+ fusableDependence->dependentOpView->getOperandNumber() -
+ producerOp.getNumInputs();
Value consumerView = consumer.getShapedOperand(consumerIdx);
// Must be a subview or a slice to guarantee there are loops we can fuse
@@ -548,12 +550,12 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
/// inverse(producerIndexMap).compose(consumerIndexMap)
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
- auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+ auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
AffineMap producerIndexingMap =
- producer.getIndexingMap(dependence.dependentOpView.operandIndex);
- auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
+ producer.getIndexingMap(dependence.dependentOpView->getOperandNumber());
+ auto consumer = cast<LinalgOp>(dependence.indexingOpView->getOwner());
AffineMap consumerIndexingMap =
- consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
+ consumer.getIndexingMap(dependence.indexingOpView->getOperandNumber());
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
producer.iterator_types().getValue(), producerIndexingMap);
@@ -733,14 +735,14 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
for (LinalgOp op : reverse(ops)) {
for (auto operandIndex :
- llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
+ llvm::seq<unsigned>(0, op.getNumShapedOperands())) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
fusableDependence =
findFusableProducer(op, operandIndex, dependenceGraph);
if (!fusableDependence)
continue;
LinalgOp producerOp =
- cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
// Do not fuse dependences that are to operations not in the same basic
// block. This avoids moving fused operations across loops that might
// themselves carry dependencies, making the fusion illegal.
@@ -750,7 +752,8 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
}
// Make sure that the indexing map of the view used for fusion in the
// producer is a projected permutation.
- unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+ unsigned producerIdx =
+ fusableDependence->dependentOpView->getOperandNumber();
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
if (!producerMap.isProjectedPermutation()) {
op.emitRemark(
@@ -760,7 +763,8 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
return FusableOpDependencesTy{};
}
- unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+ unsigned consumerIdx =
+ fusableDependence->indexingOpView->getOperandNumber();
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
if (!consumerMap.isProjectedPermutation()) {
op.emitRemark(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 22e03c1e2f92..b1ea07309b4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -128,7 +128,9 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
if (consumerArg.index() == consumerIdx + numConsumerIndices) {
// Map the arguments for the args from the producer.
- for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
+ for (auto producerArg :
+ llvm::enumerate(producerBlock.getArguments().take_front(
+ producer.getNumInputs() + numProducerIndices))) {
// If producer is an indexed_generic op, map the indices from consumer
// loop to producer loop (because the fusedOp is built based on
// consumer's perspective).
@@ -213,7 +215,6 @@ fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
consumerIndexMaps.end());
// Generate the fused op.
- // Tensor-level fusion is only on ops without initTensors and outputBuffers.
LinalgOp fusedOp;
if (isa<GenericOp>(producer.getOperation()) &&
isa<GenericOp>(consumer.getOperation())) {
@@ -221,8 +222,8 @@ fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
rewriter
.create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
/*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{},
+ // TODO: handle outputs.
+ consumer.getOutputs(),
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
@@ -230,18 +231,18 @@ fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
/*sparse=*/nullptr)
.getOperation();
} else {
- fusedOp = rewriter
- .create<IndexedGenericOp>(
- consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{},
- rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*sparse=*/nullptr)
- .getOperation();
+ fusedOp =
+ rewriter
+ .create<IndexedGenericOp>(
+ consumer.getLoc(), consumer->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ // TODO: handle outputs.
+ consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*sparse=*/nullptr)
+ .getOperation();
}
// Construct an AffineMap from consumer loops to producer loops.
@@ -430,6 +431,42 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
});
}
+// Get the output tensor to use for the expanded operation. Creates a
+// `linalg.init_tensor` operation to materialize the tensor that carries the
+// shape information.
+static Value getOutputValueForExpansion(
+ OpBuilder &builder, Location loc, AffineMap outputIndexingMap, Value result,
+ ArrayRef<SmallVector<int64_t, 4>> origDimToExpandedShapeMap) {
+ SmallVector<Value, 4> dynamicDims;
+ SmallVector<int64_t, 4> staticDims;
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ ArrayRef<int64_t> origShape = resultType.getShape();
+ for (AffineExpr expr : outputIndexingMap.getResults()) {
+ unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
+ ArrayRef<int64_t> expandedShape(origDimToExpandedShapeMap[origDimPos]);
+ bool foundDynamic = false;
+ int64_t linearizedShape = 1;
+ for (int64_t extent : expandedShape) {
+ if (ShapedType::isDynamic(extent)) {
+ assert(!foundDynamic &&
+ "Expanded dimensions of reshape can have only one dynamic dim");
+ staticDims.push_back(ShapedType::kDynamicSize);
+ foundDynamic = true;
+ continue;
+ }
+ staticDims.push_back(extent);
+ linearizedShape *= extent;
+ }
+ if (ShapedType::isDynamic(origShape[origDimPos])) {
+ Value origDim = builder.create<DimOp>(loc, result, origDimPos);
+ dynamicDims.push_back(builder.create<UnsignedDivIOp>(
+ loc, origDim, builder.create<ConstantIndexOp>(loc, linearizedShape)));
+ }
+ }
+ return builder.create<linalg::InitTensorOp>(loc, dynamicDims, staticDims,
+ resultType.getElementType());
+}
+
/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
/// conditions have been satisfied.
@@ -548,7 +585,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
expandedOpOperands.push_back(reshapeOp.src());
continue;
}
- AffineMap indexingMap = linalgOp.getIndexingMap(operand.index());
+ AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
SmallVector<ReassociationIndices, 4> reassociation;
SmallVector<int64_t, 4> expandedOperandShape;
getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
@@ -563,17 +600,17 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
expandedOpOperands.push_back(operand.value());
}
}
- SmallVector<Type, 1> resultTypes;
+
+ Location loc = linalgOp.getLoc();
+ SmallVector<Value, 1> outputs;
SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
- for (auto result : llvm::enumerate(linalgOp->getResults())) {
- AffineMap indexingMap =
- linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index());
+ for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
+ AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
SmallVector<ReassociationIndices, 4> reassociation;
SmallVector<int64_t, 4> expandedResultShape;
getReshapeInfo(indexingMap, reassociation, expandedResultShape);
- resultTypes.push_back(RankedTensorType::get(
- expandedResultShape,
- result.value().getType().cast<ShapedType>().getElementType()));
+ outputs.push_back(getOutputValueForExpansion(
+ rewriter, loc, indexingMap, result.value(), expandedDimsShape));
resultReassociation.emplace_back(std::move(reassociation));
}
@@ -581,11 +618,11 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
getParallelIteratorTypeName());
+ TypeRange resultTypes = ValueRange(outputs).getTypes();
LinalgOp fusedOp = createLinalgOpOfSameType(
linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
- /*inputs=*/expandedOpOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes);
+ /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
+ iteratorTypes);
Region &fusedRegion = fusedOp->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
@@ -656,6 +693,47 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
return resultVals;
}
+static Value
+getOutputValueForLinearization(OpBuilder &builder, Location loc,
+ Value origOutput,
+ ArrayRef<AffineMap> reassociationMaps) {
+ SmallVector<Value, 4> dynamicDims;
+ SmallVector<int64_t, 4> staticDims;
+ auto shapedType = origOutput.getType().cast<ShapedType>();
+ ArrayRef<int64_t> origShape = shapedType.getShape();
+ for (auto map : reassociationMaps) {
+ Optional<Value> dynamicDim;
+ int64_t staticLinearizedShape = 1;
+ for (AffineDimExpr expr :
+ llvm::map_range(map.getResults(), [](AffineExpr e) {
+ return e.cast<AffineDimExpr>();
+ })) {
+ unsigned pos = expr.getPosition();
+ if (ShapedType::isDynamic(origShape[pos])) {
+ Value dim = builder.create<DimOp>(loc, origOutput, pos);
+ if (dynamicDim) {
+ dynamicDim = builder.create<MulIOp>(loc, dynamicDim.getValue(), dim);
+ } else {
+ dynamicDim = dim;
+ }
+ } else {
+ staticLinearizedShape *= origShape[pos];
+ }
+ }
+ if (dynamicDim) {
+ dynamicDim = builder.create<MulIOp>(
+ loc, dynamicDim.getValue(),
+ builder.create<ConstantIndexOp>(loc, staticLinearizedShape));
+ dynamicDims.push_back(dynamicDim.getValue());
+ staticDims.push_back(ShapedType::kDynamicSize);
+ } else {
+ staticDims.push_back(staticLinearizedShape);
+ }
+ }
+ return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
+ shapedType.getElementType());
+}
+
namespace {
/// Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -704,6 +782,8 @@ struct FoldProducerReshapeOpByLinearization
// Compute the fused operands list,
SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
fusedOperands[operand.index()] = reshapeOp.src();
+ fusedOperands.append(linalgOp.getOutputs().begin(),
+ linalgOp.getOutputs().end());
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumers that aren't fused are the same.
@@ -736,7 +816,7 @@ struct FoldProducerReshapeOpByLinearization
rewriter.eraseOp(reshapeOp);
return success();
}
- return op.emitRemark("no fusion candidates found");
+ return failure();
}
};
@@ -816,12 +896,15 @@ struct FoldConsumerReshapeOpByLinearization
if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
return reshapeOp.emitRemark("fused op loop bound computation failed");
+ Location loc = producer.getLoc();
+ Value output =
+ getOutputValueForLinearization(rewriter, loc, producer.getOutputs()[0],
+ reshapeOp.getReassociationMaps());
LinalgOp fusedOp = createLinalgOpOfSameType(
- producer, rewriter, rewriter.getUnknownLoc(), reshapeOp.getResultType(),
+ producer, rewriter, loc, reshapeOp.getResultType(),
/*inputs=*/producer.getInputs(),
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, // no init tensors for now.
- rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ // TODO: handle outputs.
+ /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
@@ -902,8 +985,7 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
linalgOp, rewriter, rewriter.getUnknownLoc(),
linalgOp->getResultTypes(),
/*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, // no init tensors for now.
+ /*outputs=*/linalgOp.getOutputs(),
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
linalgOp.iterator_types(),
/*doc=*/nullptr,
@@ -915,7 +997,7 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
Region &linalgOpRegion = linalgOp->getRegion(0);
Block &entryBlock = *linalgOpRegion.begin();
unsigned argIndex = entryBlock.getNumArguments() -
- linalgOp.getNumInputs() + operand.index();
+ linalgOp.getNumShapedOperands() + operand.index();
BlockAndValueMapping mapping;
mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
Region &fusedRegion = fusedOp->getRegion(0);
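
Both getOutputValueForExpansion and getOutputValueForLinearization above materialize
the fused op's output shape as a fresh init_tensor that is then passed through outs.
A sketch for an output of type tensor<?x4xf32> whose dynamic dimension gets expanded
into [2, ?] (value names illustrative):

  %c0 = constant 0 : index
  %d0 = dim %result, %c0 : tensor<?x4xf32>
  %c2 = constant 2 : index
  %d0_expanded = divi_unsigned %d0, %c2 : index
  %init = linalg.init_tensor [2, %d0_expanded, 4] : tensor<2x?x4xf32>
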
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3496a7796988..454bbbe3578a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -45,8 +45,8 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
return builder.create<linalg::GenericOp>(
- namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
- namedOp.getInitTensors(), indexingMaps, iterators,
+ namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
+ indexingMaps, iterators,
[®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
edsc::ScopedContext scope(bodyBuilder, loc);
regionBuilder(*bodyBuilder.getBlock());
@@ -153,8 +153,8 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
return builder.create<linalg::GenericOp>(
convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
- convOp.getInputBuffers(), convOp.getOutputBuffers(),
- /*initTensors=*/ValueRange(), indexingMaps, iterators,
+ convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps,
+ iterators,
[](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
Value mul =
bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
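
With init_tensor arguments gone, named ops on tensors also take their output through
outs, e.g. (mirroring the updated canonicalize.mlir checks below):

  %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
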
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index a7f0660281b5..cac0ae0d081c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -64,7 +64,7 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
assert(permutationMap && "expected permutation to be invertible");
SmallVector<Attribute, 4> newIndexingMaps;
auto indexingMaps = op.indexing_maps().getValue();
- for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
+ for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 073673bc33f8..329cc88bd2ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -172,7 +172,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
LinalgOp linalgOp, const LinalgPromotionOptions &options)
: subViews(), dynamicBuffers(options.dynamicBuffers),
alignment(options.alignment) {
- unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
+ assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
+ unsigned nBuffers = linalgOp.getNumShapedOperands();
auto vUseFullTileBuffers =
options.useFullTileBuffers.getValueOr(llvm::SmallBitVector());
vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault);
@@ -180,7 +181,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
for (unsigned idx = 0; idx != nBuffers; ++idx) {
if (options.operandsToPromote && !options.operandsToPromote->count(idx))
continue;
- auto *op = linalgOp.getBuffer(idx).getDefiningOp();
+ auto *op = linalgOp.getShapedOperand(idx).getDefiningOp();
if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
subViews[idx] = sv;
useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
@@ -326,10 +327,10 @@ promoteSubViews(OpBuilder &b, LinalgOp op,
// operands are not views. This is to support cases such as FillOp taking
// extra scalars etc. Keep a reference to output buffers;
SmallVector<Value, 8> opViews;
- opViews.reserve(op.getNumInputsAndOutputs());
+ opViews.reserve(op.getNumShapedOperands());
SmallVector<std::pair<Value, Value>, 8> writebackViews;
writebackViews.reserve(promotedBuffersAndViews->size());
- for (auto view : llvm::enumerate(op.getInputsAndOutputBuffers())) {
+ for (auto view : llvm::enumerate(op.getShapedOperands())) {
if (options.subViews.count(view.index()) != 0) {
if (options.useFullTileBuffers[view.value()])
opViews.push_back(
@@ -371,7 +372,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
if (!linOp || !linOp.hasBufferSemantics())
return failure();
// Check that at least one of the requested operands is indeed a subview.
- for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
+ for (auto en : llvm::enumerate(linOp.getShapedOperands())) {
auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
if (sv) {
if (!options.operandsToPromote.hasValue() ||
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index fed2eedd41a4..eb940d0f769b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -334,7 +334,7 @@ struct CodeGen {
/// Helper method to inspect sparse annotations in the linalg operation.
/// Fills the per-dimension sparsity information for all tensors.
static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
- unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numTensors = op.getNumShapedOperands();
ArrayAttr sparseAttr = op.sparseAttr();
for (unsigned t = 0; t < numTensors; t++) {
auto map = op.getIndexingMap(t);
@@ -467,7 +467,7 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
// is set to a synthetic tensor with undefined indices only.
unsigned s = merger.addSet();
unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
- : op.getNumInputsAndOutputs();
+ : op.getNumShapedOperands() - 1;
merger.set(s).push_back(merger.addLat(t, idx, exp));
return s;
}
@@ -504,7 +504,7 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
static void genBuffers(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op) {
Location loc = op.getLoc();
- unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numTensors = op.getNumShapedOperands();
unsigned numInputs = op.getNumInputs();
assert(numTensors == numInputs + 1);
@@ -544,7 +544,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
up = codegen.sizes[i];
assert(up); // TODO: what else?
} else {
- Value arg = t < numInputs ? op.getInput(t) : op.getInitTensor(0);
+ Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0];
up = rewriter.create<DimOp>(loc, arg, d);
}
args.push_back(up);
@@ -597,7 +597,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned tensor, Value rhs) {
// Test if this is a scalarized reduction.
- unsigned lhs = op.getNumInputsAndOutputs() - 1;
+ unsigned lhs = op.getNumShapedOperands() - 1;
if (lhs == tensor && codegen.redVal) {
codegen.redVal = rhs;
return;
@@ -670,7 +670,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
atLevel = true;
}
// All exhausted at this level (atLevel denotes exactly at this level).
- unsigned lhs = op.getNumInputsAndOutputs() - 1;
+ unsigned lhs = op.getNumShapedOperands() - 1;
if (lhs == tensor) {
codegen.redExp = hoist ? exp : -1u;
} else if (atLevel) {
@@ -995,7 +995,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
unsigned exp, unsigned at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == topSort.size()) {
- unsigned lhs = op.getNumInputsAndOutputs() - 1;
+ unsigned lhs = op.getNumShapedOperands() - 1;
Value rhs = genExp(merger, codegen, rewriter, op, exp);
genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
return;
@@ -1073,7 +1073,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
Value red = codegen.redVal;
if (red) {
codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
- unsigned lhs = op.getNumInputsAndOutputs() - 1;
+ unsigned lhs = op.getNumShapedOperands() - 1;
genTensorStore(merger, codegen, rewriter, op, lhs, red);
}
codegen.loops[idx] = Value();
@@ -1095,7 +1095,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (!op.hasSparseSemantics())
return failure();
assert(op.getNumOutputs() == 1);
- unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numTensors = op.getNumShapedOperands();
unsigned numLoops = op.iterator_types().getValue().size();
Merger merger(numTensors, numLoops);
findSparseAnnotations(merger, op);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 423d687c1eb8..f323d2e50435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -375,9 +375,9 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
// 2. Create the tiled loops.
LinalgOp res = op;
SmallVector<Value, 4> ivs, tensorResults;
- auto initTensors = op.getInitTensors();
+ auto outputTensors = op.getOutputTensors();
GenerateLoopNest<LoopTy>::doit(
- loopRanges, /*iterArgInitValues*/ initTensors, iteratorTypes,
+ loopRanges, /*iterArgInitValues*/ outputTensors, iteratorTypes,
[&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
@@ -392,14 +392,16 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
else
interchangedIvs.assign(ivs.begin(), ivs.end());
- assert(op.getNumInitTensors() == iterArgs.size() &&
- "num init tensors must match number of loop iter arguments");
- // This uses knowledge about position of the init tensor in the list
- // of operands.
- auto operands = llvm::to_vector<4>(op.getShapedOperands());
- std::copy(iterArgs.begin(), iterArgs.end(),
- operands.begin() + op.getNumInputsAndOutputBuffers());
+ assert(op.getNumOutputTensors() == iterArgs.size() &&
+ "num output tensors must match number of loop iter arguments");
+ auto operands = llvm::to_vector<4>(op.getInputs());
+ SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
+ // TODO: thanks to simplifying assumption we do not need to worry about
+ // order of output buffers and tensors: there is only ever one kind.
+ assert(outputBuffers.empty() || iterArgs.empty());
+ operands.append(outputBuffers.begin(), outputBuffers.end());
+ operands.append(iterArgs.begin(), iterArgs.end());
SmallVector<Value, 4> tiledOperands =
makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
interchangedIvs, tileSizes, allShapeSizes);
@@ -407,41 +409,31 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
tiledOperands.append(nonShapedOperands.begin(),
nonShapedOperands.end());
- // If LinalgOp has results, they must all be tied to init tensors.
- // We enforce this to ensure all tiled ops have been rewritten in
- // "init tensor" form. This ensures tiling has anchor values into which
- // to subtensor / subtensor_insert. Otherwise tiling would need to
- // allocate which is not acceptable.
- // This would not be the case with a special terminator op that
- // generates the whole tensor (instead of inserting a subtensor). But
- // the generator-based abstraction has other issues.
- assert(op.getNumInitTensors() == op->getNumResults() &&
- "expected same number of init tensors as number of results");
-
- // Handle init tensor operands.
- // This uses knowledge about position of the init tensor in the list
- // of operands.
- // TODO: InterfaceAdaptor ?
+ // TODO: use an interface/adaptor to avoid leaking position in
+ // `tiledOperands`.
SmallVector<Type, 4> resultTensorTypes;
- for (auto idx : llvm::seq<unsigned>(0, op.getNumInitTensors()))
+ for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
resultTensorTypes.push_back(
- tiledOperands[op.getNumInputsAndOutputBuffers() + idx].getType());
+ tiledOperands[opOperand->getOperandNumber()].getType());
res = op.clone(b, loc, resultTensorTypes, tiledOperands);
- // Insert a subtensor_insert for each init subtensor.
- for (unsigned idx = 0, e = op.getNumInitTensors(); idx != e; ++idx) {
- Value initTensor =
- tiledOperands[op.getNumInputsAndOutputBuffers() + idx];
- if (auto subtensor = initTensor.getDefiningOp<SubTensorOp>()) {
+ // Insert a subtensor_insert for each output tensor.
+ unsigned resultIdx = 0;
+ for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
+ // TODO: use an interface/adaptor to avoid leaking position in
+ // `tiledOperands`.
+ Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+ if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
tensorResults.push_back(b.create<SubTensorInsertOp>(
- loc, subtensor.source().getType(), res->getResult(idx),
+ loc, subtensor.source().getType(), res->getResult(resultIdx),
subtensor.source(), subtensor.offsets(), subtensor.sizes(),
subtensor.strides(), subtensor.static_offsets(),
subtensor.static_sizes(), subtensor.static_strides()));
} else {
- tensorResults.push_back(res->getResult(idx));
+ tensorResults.push_back(res->getResult(resultIdx));
}
+ ++resultIdx;
}
return scf::ValueVector(tensorResults.begin(), tensorResults.end());
},
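
Since output tensors now play the role the init tensors used to, tiling threads them
through the loop nest as iter_args and stitches partial results back with
subtensor_insert. Roughly, for a matmul over tensor<4x4xf32> tiled by 2 along the
first dimension (a sketch; tile sizes and value names are illustrative):

  %res = scf.for %i = %c0 to %c4 step %c2 iter_args(%acc = %C) -> (tensor<4x4xf32>) {
    %stA = subtensor %A[%i, 0] [2, 4] [1, 1] : tensor<4x4xf32> to tensor<2x4xf32>
    %stC = subtensor %acc[%i, 0] [2, 4] [1, 1] : tensor<4x4xf32> to tensor<2x4xf32>
    %p = linalg.matmul ins(%stA, %B : tensor<2x4xf32>, tensor<4x4xf32>)
                      outs(%stC : tensor<2x4xf32>) -> tensor<2x4xf32>
    %r = subtensor_insert %p into %acc[%i, 0] [2, 4] [1, 1]
        : tensor<2x4xf32> into tensor<4x4xf32>
    scf.yield %r : tensor<4x4xf32>
  }
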
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 804ae6681f8c..c5d811c41edb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -125,17 +125,6 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
- // If LinalgOp has results, they must all be tied to init tensors.
- // We enforce this to ensure all tiled ops have been rewritten in
- // "init tensor" form. This ensures tiling has anchor values into which to
- // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
- // is not acceptable.
- // This would not be the case with a special terminator op that generates the
- // whole tensor (instead of inserting a subtensor). But the generator-based
- // abstraction has other issues.
- if (linalgOp.getNumInitTensors() != linalgOp->getNumResults())
- return failure();
-
Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
if (!res)
@@ -174,10 +163,10 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
producers.insert(linalgOp);
for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
if (!fusionOptions.indicesToFuse.count(
- dependence.indexingOpView.operandIndex))
+ dependence.indexingOpView->getOperandNumber()))
continue;
- if (isa<LinalgOp>(dependence.dependentOpView.op))
- producers.insert(dependence.dependentOpView.op);
+ if (isa<LinalgOp>(dependence.dependentOpView->getOwner()))
+ producers.insert(dependence.dependentOpView->getOwner());
}
SmallVector<LinalgOp, 1> fusionOps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7165ee775e9c..23e452df9184 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -199,9 +199,8 @@ class GenericVectorizer {
// block argument.
auto scalarArg = scalarValue.cast<BlockArgument>();
assert(scalarArg.getOwner() == &generic.region().front());
- Value vector_arg =
- generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()];
- Value vectorResult = transferReadVector(builder, vector_arg);
+ Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
+ Value vectorResult = transferReadVector(builder, vectorArg);
valueCache[scalarArg] = vectorResult;
return vectorResult;
}
@@ -277,7 +276,7 @@ static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
- for (Value operand : linalgOp.getInputsAndOutputBuffers())
+ for (Value operand : linalgOp.getShapedOperands())
if (!operand.getType().cast<ShapedType>().hasStaticShape())
return failure();
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f44bb6769e61..81bfbc6ecf52 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -104,12 +104,6 @@ SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
auto shape = v.getType().cast<ShapedType>().getShape();
res.append(shape.begin(), shape.end());
}
- if (linalgOp.getNumInitTensors())
- return res;
- for (Value v : linalgOp.getOperation()->getResults()) {
- auto shape = v.getType().cast<ShapedType>().getShape();
- res.append(shape.begin(), shape.end());
- }
return res;
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c0af06314086..30bf546807c4 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1477,12 +1477,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
return success();
}
};
-
} // end anonymous namespace.
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<DimOfMemRefReshape, DimOfCastOp<tensor::CastOp>>(context);
+ results.insert<DimOfMemRefReshape, DimOfCastOp<TensorToMemrefOp>,
+ DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
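
The extra DimOfCastOp<TensorToMemrefOp> entry lets dim fold through tensor_to_memref,
which is why the bufferize.mlir checks below can take dims of the original tensors.
Sketched:

  %m = tensor_to_memref %t : memref<?x?xf32>
  %d = dim %m, %c0 : memref<?x?xf32>
  // canonicalizes to:
  %d = dim %t, %c0 : tensor<?x?xf32>
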
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 368568bdcc4a..08d715f90b5e 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -linalg-bufferize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s
#map0 = affine_map<(d0) -> (d0)>
@@ -26,8 +26,9 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]
- } ins(%arg0 : tensor<4xf32>) {
- ^bb0(%gen_arg1: f32):
+ } ins(%arg0 : tensor<4xf32>)
+ outs(%arg0 : tensor<4xf32>) {
+ ^bb0(%gen_arg1: f32, %out: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1 : f32
} -> tensor<4xf32>
@@ -35,6 +36,35 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
}
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// Same as above but with linalg.init_tensor op.
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @init_tensor(
+// CHECK-SAME: %[[IN:.*]]: tensor<?xf32>, %[[SIZE:.*]]: index)
+// CHECK: %[[OUT_BUF:.*]] = alloc(%[[SIZE]]) : memref<?xf32>
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[IN]] : memref<?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[MEMREF]] : memref<?xf32>)
+// CHECK-SAME: outs(%[[OUT_BUF]] : memref<?xf32>) {
+func @init_tensor(%in : tensor<?xf32>, %size: index) -> tensor<?xf32> {
+ %init = linalg.init_tensor [%size] : tensor<?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>)
+ outs(%init : tensor<?xf32>) {
+ ^bb0(%gen_arg1: f32, %out: f32):
+ %tmp1 = exp %gen_arg1 : f32
+ linalg.yield %tmp1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+
// -----
#map0 = affine_map<(d0) -> (d0)>
@@ -50,8 +80,9 @@ func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0, %1 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel"]
- } ins(%arg0 : tensor<4xf32>) {
- ^bb0(%gen_arg1: f32):
+ } ins(%arg0 : tensor<4xf32>)
+ outs (%arg0, %arg0 : tensor<4xf32>, tensor<4xf32>) {
+ ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1, %tmp1 : f32, f32
} -> tensor<4xf32>, tensor<4xf32>
@@ -74,8 +105,9 @@ func @multiple_results_indexed(%arg0: tensor<4xi32>)
%0, %1 = linalg.indexed_generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel"]
- } ins(%arg0 : tensor<4xi32>) {
- ^bb0(%i: index, %gen_arg1: i32):
+ } ins(%arg0 : tensor<4xi32>)
+ outs (%arg0, %arg0 : tensor<4xi32>, tensor<4xi32>) {
+ ^bb0(%i: index, %gen_arg1: i32, %out1: i32, %out2: i32):
%i_i32 = index_cast %i : index to i32
%tmp1 = addi %gen_arg1, %i_i32 : i32
linalg.yield %tmp1, %tmp1 : i32, i32
@@ -86,32 +118,30 @@ func @multiple_results_indexed(%arg0: tensor<4xi32>)
// -----
#map_2d = affine_map<(d0, d1) -> (d0, d1)>
-#map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
// Check that the allocs properly consider the different shapes of the output
// operands. The permuted indexing maps translate to different output shapes.
-// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func @dynamic_results(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
-// CHECK: %[[MEMREF_ARG:.*]] = tensor_to_memref %[[ARG]] : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[DIM0:.*]] = dim %[[ARG]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[MEMREF_ARG:.*]] = tensor_to_memref %[[ARG]] : memref<?x?xf32>
+// CHECK: %[[DIM0:.*]] = dim %[[ARG]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM1:.*]] = dim %[[ARG]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[RESULT0:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
-// CHECK: %[[RESULT1:.*]] = alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
-// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1]
+// CHECK: %[[RESULT1:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK: linalg.generic
// CHECK-SAME: ins(%[[MEMREF_ARG]] : memref<?x?xf32>)
// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<?x?xf32>, memref<?x?xf32>)
func @dynamic_results(%arg0: tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0, %1 = linalg.generic {
- indexing_maps = [#map_2d, #map_2d, #map_2d_inv],
+ indexing_maps = [#map_2d, #map_2d, #map_2d],
iterator_types = ["parallel", "parallel"]
- } ins(%arg0 : tensor<?x?xf32>) {
- ^bb0(%gen_arg1: f32):
+ } ins(%arg0 : tensor<?x?xf32>)
+ outs (%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+ ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1, %tmp1 : f32, f32
} -> tensor<?x?xf32>, tensor<?x?xf32>
@@ -147,10 +177,9 @@ func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
%0 = linalg.generic #trait
ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
- init(%arg1 : tensor<3x2xf32>) {
+ outs(%arg1 : tensor<3x2xf32>) {
^bb(%v0: vector<3x4xi4>, %v1: f32) :
- %f0 = constant 0.0 : f32
- linalg.yield %f0 : f32
+ linalg.yield %v1 : f32
} -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
@@ -204,16 +233,16 @@ func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %
(tensor<?x?xf32>, tensor<?x?xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
- // CHECK: %[[IDX:.*]] = call @make_index() : () -> index
+ // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+ // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
%i0 = call @make_index() : () -> index
+ // CHECK: %[[IDX:.*]] = call @make_index() : () -> index
- // CHECK-DAG: %[[M0:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
- // CHECK-DAG: %[[SM0:.*]] = tensor_to_memref %[[ST0]] : memref<2x3xf32>
- // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
- // CHECK-NEXT: %[[DIM0:.*]] = dim %[[M0]], %[[C0]] : memref<?x?xf32>
- // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
- // CHECK-NEXT: %[[DIM1:.*]] = dim %[[M0]], %[[C1]] : memref<?x?xf32>
+ // CHECK-DAG: %[[M0:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
+ // CHECK-DAG: %[[SM0:.*]] = tensor_to_memref %[[ST0]] : memref<2x3xf32>
+ // CHECK-NEXT: %[[DIM0:.*]] = dim %[[T]], %[[C0]] : tensor<?x?xf32>
+ // CHECK-NEXT: %[[DIM1:.*]] = dim %[[T]], %[[C1]] : tensor<?x?xf32>
// CHECK-NEXT: %[[M0_COPY:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
// CHECK-NEXT: linalg.copy(%[[M0]], %[[M0_COPY]]) : memref<?x?xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[SUBVIEW0:.*]] = subview %[[M0_COPY]][0, 0] [2, 3] [1, 1]
@@ -224,10 +253,6 @@ func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %
// CHECK-DAG: %[[M1:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
// CHECK-DAG: %[[SM1:.*]] = tensor_to_memref %[[ST1]] : memref<2x?xf32>
- // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
- // CHECK-NEXT: %[[DIM0:.*]] = dim %[[M1]], %[[C0]] : memref<?x?xf32>
- // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
- // CHECK-NEXT: %[[DIM1:.*]] = dim %[[M1]], %[[C1]] : memref<?x?xf32>
// CHECK-NEXT: %[[M1_COPY:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
// CHECK-NEXT: linalg.copy(%[[M1]], %[[M1_COPY]]) : memref<?x?xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[SUBVIEW1:.*]] = subview %[[M1_COPY]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
@@ -239,3 +264,4 @@ func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %
// CHECK: return %[[RT0]], %[[RT1]]
return %t0, %t1: tensor<?x?xf32>, tensor<?x?xf32>
}
+
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
index 8c08fb390b9e..de894b9192fb 100644
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -8,10 +8,12 @@
// CHECK-LABEL: @basic
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]]]
- // CHECK: ^bb0(%[[BBARG:.*]]: f32):
+ // CHECK: ^bb0(%[[BBARG:.*]]: f32, %{{.*}}: f32):
// CHECK: addf %[[BBARG]], %[[BBARG]]
- %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+ ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%1 = addf %arg1, %arg2 : f32
linalg.yield %1 : f32
} -> tensor<?xf32>
@@ -31,8 +33,10 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: @distinct_affine_maps
func @distinct_affine_maps(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%1 = addf %arg1, %arg2 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
@@ -52,10 +56,12 @@ func @distinct_affine_maps(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-LABEL: @mixed_redundant_non_redundant
func @mixed_redundant_non_redundant(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
- // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+ // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
// CHECK: "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]])
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%1 = "test.elementwise_mappable"(%arg1, %arg2, %arg3) : (f32, f32, f32) -> f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
@@ -72,10 +78,12 @@ func @mixed_redundant_non_redundant(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-LABEL: @multiple_different_redundant_args
func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]], #[[$MAP]]]
- // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+ // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
// CHECK: "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]], %[[BBARG1]])
- %0 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg0, %arg1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32):
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel"]}
+ ins(%arg0, %arg1, %arg0, %arg1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%1 = "test.elementwise_mappable"(%arg2, %arg3, %arg4, %arg5) : (f32, f32, f32, f32) -> f32
linalg.yield %1 : f32
} -> tensor<?xf32>
@@ -93,10 +101,12 @@ func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf3
// CHECK-LABEL: @indexed_generic
func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.indexed_generic
- // CHECK: ^bb0(%{{.*}}: index, %[[BBARG:.*]]: f32):
+ // CHECK: ^bb0(%{{.*}}: index, %[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
// CHECK: addf %[[BBARG]], %[[BBARG]]
- %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%index: index, %arg1: f32, %arg2: f32):
+ %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+ ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?xf32>) {
+ ^bb0(%index: index, %arg1: f32, %arg2: f32, %arg3: f32):
%1 = addf %arg1, %arg2 : f32
linalg.yield %1 : f32
} -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 6c12070e07f1..f015d5fd64fd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -232,7 +232,6 @@ func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
// -----
#accesses = [
- affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>
]
@@ -246,7 +245,7 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
// tensor<0xf32> cannot be dce'ed
- %1 = linalg.generic #trait ins(%arg1 : tensor<0xf32>) {
+ %1 = linalg.generic #trait outs(%arg1 : tensor<0xf32>) {
^bb(%0: f32) :
linalg.yield %0 : f32
} -> tensor<0xf32>
@@ -326,9 +325,9 @@ func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf3
%tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>
// CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
- // CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
+ // CHECK-SAME: outs({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
%0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
- init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+ outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
@@ -344,7 +343,7 @@ func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf3
func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
// CHECK-NOT: %{{.*}} = linalg.matmul
%t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
- init(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+ outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: %{{.*}} = linalg.matmul
linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index 7ea78fef7add..8dca137843bb 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -1,14 +1,20 @@
// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
// In-depth checking of the linalg.generic op for a very trivial case.
-// CHECK: #map = affine_map<() -> ()>
-// CHECK-LABEL: func @addf_rank0
+// CHECK: #[[$MAP:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
- // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
- // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
- // CHECK: %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
- // CHECK: linalg.yield %[[YIELD]] : f32
- // CHECK: } -> tensor<f32>
+ // CHECK: %{{.*}} = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+ // CHECK-SAME: iterator_types = []
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+ // CHECK-SAME: outs(%[[ARG0]]
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
+ // CHECK: %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+ // CHECK: linalg.yield %[[YIELD]] : f32
+ // CHECK: } -> tensor<f32>
%0 = addf %arg0, %arg1 : tensor<f32>
return %0 : tensor<f32>
}
@@ -16,10 +22,14 @@ func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
// -----
// Check indexing maps and iterator types for the rank > 0 case.
-// CHECK: #map = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @addf_rank1
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?xf32>
func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
- // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel"]
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+ // CHECK-SAME: outs(%[[ARG0]]
%0 = addf %arg0, %arg1 : tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -28,9 +38,12 @@ func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// Check a unary op.
// CHECK-LABEL: func @exp
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
func @exp(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+ // CHECK-SAME: ins(%[[ARG0]]
+ // CHECK-SAME: outs(%[[ARG0]]
+ // CHECK: ^bb0(%[[SCALAR:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[YIELD:.*]] = exp %[[SCALAR]] : f32
// CHECK: linalg.yield %[[YIELD]] : f32
%0 = exp %arg0 : tensor<f32>
@@ -41,9 +54,14 @@ func @exp(%arg0: tensor<f32>) -> tensor<f32> {
// Check a case with varying operand types.
// CHECK-LABEL: func @select
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<i1>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<i32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<i32>
func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]]
+ // CHECK-SAME: outs(%[[ARG1]]
+ // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32, %{{.*}}: i32):
// CHECK: select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
%0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
return %0 : tensor<i32>
@@ -52,9 +70,41 @@ func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tenso
// -----
// Spot-check an op that requires copying attributes properly to the created scalar op.
+// Also checks proper init_tensor usage.
// CHECK-LABEL: func @cmpf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+ // CHECK: %[[INIT:.*]] = linalg.init_tensor [] : tensor<i1>
+ // CHECK: linalg.generic
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+ // CHECK-SAME: outs(%[[INIT]]
+ // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
// CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
%0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
return %0 : tensor<i1>
}
+
+// -----
+
+// Check proper init_tensor usage in a mixed case.
+// CHECK-LABEL: func @cmpf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
+func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) -> tensor<4x?x?x8x2x?xi1> {
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] : tensor<4x?x?x8x2x?xf32>
+ // CHECK: %[[C2:.*]] = constant 2 : index
+ // CHECK: %[[D2:.*]] = dim %[[ARG0]], %[[C2]] : tensor<4x?x?x8x2x?xf32>
+ // CHECK: %[[C5:.*]] = constant 5 : index
+ // CHECK: %[[D5:.*]] = dim %[[ARG0]], %[[C5]] : tensor<4x?x?x8x2x?xf32>
+ // CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[D1]], %[[D2]], 8, 2, %[[D5]]] : tensor<4x?x?x8x2x?xi1>
+ // CHECK: linalg.generic
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+ // CHECK-SAME: outs(%[[INIT]]
+ // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
+ // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+ %0 = cmpf "olt", %arg0, %arg1 : tensor<4x?x?x8x2x?xf32>
+ return %0 : tensor<4x?x?x8x2x?xi1>
+}
+
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index e04d03b4e493..17b8bda967b1 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims | FileCheck %s
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
@@ -11,12 +11,12 @@
library_call = "some_external_func"
}
-func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
-{
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
%0 = linalg.generic #trait
- ins(%arg0 : tensor<?x1x?xf32>) {
- ^bb0(%arg1 : f32) :
- linalg.yield %arg1 : f32
+ ins(%arg0 : tensor<?x1x?xf32>)
+ outs(%shape : tensor<?x1x?x1x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32) :
+ linalg.yield %arg2 : f32
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
@@ -48,12 +48,13 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
}
func @drop_one_trip_loops_indexed_generic
- (%arg0 : tensor<?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+ (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
{
%0 = linalg.indexed_generic #trait
- ins(%arg0 : tensor<?x1x?xi32>) {
+ ins(%arg0 : tensor<?x1x?xi32>)
+ outs(%shape: tensor<?x1x?x1x?xi32>) {
^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
- %arg5 : index, %arg6 : i32) :
+ %arg5 : index, %arg6 : i32, %arg7 : i32) :
%1 = addi %arg1, %arg2 : index
%2 = addi %1, %arg3 : index
%3 = addi %2, %arg4 : index
@@ -68,7 +69,7 @@ func @drop_one_trip_loops_indexed_generic
// CHECK: linalg.indexed_generic
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32)
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
// CHECK: %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
// CHECK: %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
@@ -88,8 +89,9 @@ func @drop_one_trip_loops_indexed_generic
func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
%0 = linalg.generic #trait
- ins(%arg0 : tensor<1x1xf32>) {
- ^bb0(%arg1: f32) :
+ ins(%arg0 : tensor<1x1xf32>)
+ outs(%arg0 : tensor<1x1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32) :
linalg.yield %arg1 : f32
} -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
@@ -112,11 +114,11 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
}
func @drop_all_loops_indexed_generic
- (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>
-{
+ (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
%0 = linalg.indexed_generic #trait
- ins(%arg0 : tensor<1x1xi32>) {
- ^bb0(%arg1 : index, %arg2 : index, %arg3: i32) :
+ ins(%arg0 : tensor<1x1xi32>)
+ outs(%arg0 : tensor<1x1xi32>) {
+ ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) :
%1 = addi %arg1, %arg2 : index
%2 = index_cast %1 : index to i32
%3 = addi %2, %arg3 : i32
@@ -127,7 +129,7 @@ func @drop_all_loops_indexed_generic
// CHECK-LABEL: func @drop_all_loops_indexed_generic
// CHECK: linalg.indexed_generic
-// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32)
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
// CHECK: linalg.yield %[[ARG1]] : i32
// -----
@@ -143,10 +145,11 @@ func @drop_all_loops_indexed_generic
library_call = "some_external_fn"
}
-func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> {
%0 = linalg.generic #trait
- ins(%arg0 : tensor<1x5xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ ins(%arg0 : tensor<1x5xf32>)
+ outs(%shape : tensor<5xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5xf32>
return %0 : tensor<5xf32>
@@ -172,16 +175,17 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
library_call = "some_external_fn"
}
-func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5xf32>
+func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32>
{
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
tensor<5xf32> into tensor<1x5xf32>
%1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
tensor<5xf32> into tensor<5x1xf32>
%2 = linalg.generic #trait
- ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- %3 = addf %arg2, %arg3 : f32
+ ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>)
+ outs(%shape : tensor<5x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %3 = addf %arg3, %arg4 : f32
linalg.yield %3 : f32
} -> tensor<5x5xf32>
return %2 : tensor<5x5xf32>
@@ -209,12 +213,13 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
library_call = "some_external_fn"
}
-func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
+func @broadcast_scalar(%arg0 : tensor<1x1xf32>, %shape : tensor<?x?xf32>) -> tensor<?x?xf32>
{
%0 = linalg.generic #trait
- ins(%arg0 : tensor<1x1xf32>) {
- ^bb0(%arg1 : f32):
- linalg.yield %arg1 : f32
+ ins(%arg0 : tensor<1x1xf32>)
+ outs(%shape : tensor<?x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
index 6d75c480b5c6..d0c526e441b6 100644
--- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
+++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" | FileCheck %s
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
@@ -11,11 +11,12 @@
library_call = "some_external_func"
}
-func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
{
%0 = linalg.generic #trait
- ins(%arg0 : tensor<?x1x?xf32>) {
- ^bb0(%arg1 : f32) :
+ ins(%arg0 : tensor<?x1x?xf32>)
+ outs(%shape : tensor<?x1x?x1x?xf32>) {
+ ^bb0(%arg1 : f32, %arg2 : f32) :
linalg.yield %arg1 : f32
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
@@ -40,8 +41,9 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
%0 = linalg.generic #trait
- ins(%arg0 : tensor<1x1xf32>) {
- ^bb0(%arg1: f32) :
+ ins(%arg0 : tensor<1x1xf32>)
+ outs(%arg0 : tensor<1x1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32) :
linalg.yield %arg1 : f32
} -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
@@ -91,10 +93,11 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
library_call = "some_external_fn"
}
-func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> {
%0 = linalg.generic #trait
- ins(%arg0 : tensor<1x5xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ ins(%arg0 : tensor<1x5xf32>)
+ outs(%shape : tensor<5xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5xf32>
return %0 : tensor<5xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index ff0394f18249..df7e59d59dde 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -6,29 +6,36 @@
// CHECK-LABEL: @add_mul_fusion
func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = addf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %4 = addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
} -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+ %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
- ^bb0(%arg5: f32, %arg6: f32): // no predecessors
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
// CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG0]], [[ARG1]]
// CHECK-NOT: linalg.yield
// CHECK: mulf [[T1]], [[ARG2]]
// CHECK: linalg.yield
- %3 = mulf %arg5, %arg6 : f32
- linalg.yield %3 : f32
+ %5 = mulf %arg5, %arg6 : f32
+ linalg.yield %5 : f32
} -> tensor<?x?xf32>
- return %2 : tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
}
// -----
@@ -41,21 +48,28 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : te
// CHECK-LABEL: @transpose_add_mul_fusion
func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = addf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %4 = addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
} -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg5: f32, %arg6: f32): // no predecessors
- %3 = mulf %arg5, %arg6 : f32
- linalg.yield %3 : f32
+ %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
+ %5 = mulf %arg5, %arg6 : f32
+ linalg.yield %5 : f32
} -> tensor<?x?xf32>
- return %2 : tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
}
// -----
@@ -68,21 +82,28 @@ func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-LABEL: @add_transpose_mul_fusion
func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = addf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %4 = addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
} -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg5: f32, %arg6: f32): // no predecessors
- %3 = mulf %arg5, %arg6 : f32
- linalg.yield %3 : f32
+ %4 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>){
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
+ %5= mulf %arg5, %arg6 : f32
+ linalg.yield %5 : f32
} -> tensor<?x?xf32>
- return %2 : tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
}
// -----
@@ -96,21 +117,29 @@ func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-LABEL: @add_broadcast_mul_fusion
func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]}
- ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = addf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?xf32>
+ %1 = linalg.init_tensor [%0] : tensor<?xf32>
+ %2 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%1 : tensor<?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %3 = addf %arg3, %arg4 : f32
+ linalg.yield %3 : f32
} -> tensor<?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP1]], [[$MAP0]], [[$MAP0]]
- %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : tensor<?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg5: f32, %arg6: f32): // no predecessors
- %3 = mulf %arg5, %arg6 : f32
- linalg.yield %3 : f32
+ %3 = dim %arg2, %c1 : tensor<?x?xf32>
+ %4 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
+ %5 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%2, %arg2 : tensor<?xf32>, tensor<?x?xf32>)
+ outs(%4 : tensor<?x?xf32>){
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
+ %6 = mulf %arg5, %arg6 : f32
+ linalg.yield %6 : f32
} -> tensor<?x?xf32>
- return %2 : tensor<?x?xf32>
+ return %5 : tensor<?x?xf32>
}
// -----
@@ -121,23 +150,26 @@ func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg
// CHECK-LABEL: @add_mul_scalar_fusion
func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
{
- %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
- ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = addf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
} -> tensor<f32>
// CHECK: linalg.generic {
// CHECK: addf
// CHECK: mulf
- %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
- ins(%0, %arg2 : tensor<f32>, tensor<f32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
- %1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%1, %arg2 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %3 = mulf %arg3, %arg4 : f32
+ linalg.yield %3 : f32
} -> tensor<f32>
- return %1 : tensor<f32>
+ return %2 : tensor<f32>
}
// -----
@@ -146,22 +178,29 @@ func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tenso
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
{
- %0 = constant dense<42.0> : tensor<5xf32>
- %1 = linalg.generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- %2 = mulf %arg1, %arg2 : f32
- linalg.yield %2 : f32
- } -> tensor<5x?x?xf32>
- return %1 : tensor<5x?x?xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %cst = constant dense<42.0> : tensor<5xf32>
+ %0 = dim %arg0, %c1 : tensor<5x?x?xf32>
+ %1 = dim %arg0, %c2 : tensor<5x?x?xf32>
+ %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
+ outs(%2 : tensor<5x?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %4 = mulf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<5x?x?xf32>
+ return %3 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @generic_op_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.generic
-// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: ^{{.+}}(%[[ARG1:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32):
// CHECK: mulf %[[CST]], %[[ARG1]]
// -----
@@ -171,16 +210,23 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-> tensor<5x?x?xf32>
{
- %0 = constant dense<42.0> : tensor<5xf32>
- %1 = linalg.indexed_generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32):
- %2 = mulf %arg4, %arg5 : f32
- linalg.yield %2 : f32
- } -> tensor<5x?x?xf32>
- return %1 : tensor<5x?x?xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %cst = constant dense<42.0> : tensor<5xf32>
+ %0 = dim %arg0, %c1 : tensor<5x?x?xf32>
+ %1 = dim %arg0, %c2 : tensor<5x?x?xf32>
+ %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
+ %3 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
+ outs(%2 : tensor<5x?x?xf32>) {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32, %arg6 : f32):
+ %4 = mulf %arg4, %arg5 : f32
+ linalg.yield %4 : f32
+ } -> tensor<5x?x?xf32>
+ return %3 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @indexed_generic_op_constant_fusion
@@ -190,7 +236,7 @@ func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index
-// CHECK-SAME: %[[ARG4:.*]]: f32)
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
// CHECK: mulf %[[CST]], %[[ARG4]]
// -----
@@ -200,22 +246,29 @@ func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-> tensor<5x?x?xf32>
{
- %0 = constant dense<42.0> : tensor<f32>
- %1 = linalg.generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg0 : tensor<f32>, tensor<5x?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- %2 = mulf %arg1, %arg2 : f32
- linalg.yield %2 : f32
- } -> tensor<5x?x?xf32>
- return %1 : tensor<5x?x?xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %cst = constant dense<42.0> : tensor<f32>
+ %0 = dim %arg0, %c1 : tensor<5x?x?xf32>
+ %1 = dim %arg0, %c2 : tensor<5x?x?xf32>
+ %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
+ outs(%2 : tensor<5x?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %4 = mulf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<5x?x?xf32>
+ return %3 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.generic
-// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
// CHECK: mulf %[[CST]], %[[ARG1]]
// -----
@@ -225,16 +278,23 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
func @indexed_generic_op_zero_dim_constant_fusion
(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
{
- %0 = constant dense<42.0> : tensor<f32>
- %1 = linalg.indexed_generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg0 : tensor<f32>, tensor<5x?x?xf32>) {
- ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32):
- %2 = mulf %arg4, %arg5 : f32
- linalg.yield %2 : f32
- } -> tensor<5x?x?xf32>
- return %1 : tensor<5x?x?xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %cst = constant dense<42.0> : tensor<f32>
+ %0 = dim %arg0, %c1 : tensor<5x?x?xf32>
+ %1 = dim %arg0, %c2 : tensor<5x?x?xf32>
+ %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
+ %3 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
+ outs(%2 : tensor<5x?x?xf32>) {
+ ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32, %arg6: f32):
+ %4 = mulf %arg4, %arg5 : f32
+ linalg.yield %4 : f32
+ } -> tensor<5x?x?xf32>
+ return %3 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
@@ -244,7 +304,7 @@ func @indexed_generic_op_zero_dim_constant_fusion
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index
-// CHECK-SAME: %[[ARG4:.*]]: f32)
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
// CHECK: mulf %[[CST]], %[[ARG4]]
// -----
@@ -252,26 +312,33 @@ func @indexed_generic_op_zero_dim_constant_fusion
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
- ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
%10 = addi %arg2, %arg3 : i32
linalg.yield %10 : i32
} -> tensor<?x?xi32>
- %1 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%0 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
- %3 = index_cast %arg3 : index to i32
- %4 = addi %arg4, %2 : i32
- %5 = subi %4, %3 : i32
- linalg.yield %5 : i32
+ %4 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%3 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
+ %5 = index_cast %arg2 : index to i32
+ %6 = index_cast %arg3 : index to i32
+ %7 = addi %arg4, %5 : i32
+ %8 = subi %7, %6 : i32
+ linalg.yield %8 : i32
} -> tensor<?x?xi32>
- return %1 : tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
@@ -295,26 +362,33 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %0 = linalg.indexed_generic {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.indexed_generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel"] }
- ins(%arg0 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
- %3 = index_cast %arg3 : index to i32
- %4 = addi %arg4, %2 : i32
- %5 = subi %4, %3 : i32
- linalg.yield %5 : i32
- } -> tensor<?x?xi32>
- %1 = linalg.generic {
+ ins(%arg0 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
+ %4 = index_cast %arg2 : index to i32
+ %5 = index_cast %arg3 : index to i32
+ %6 = addi %arg4, %4 : i32
+ %7 = subi %6, %5 : i32
+ linalg.yield %7 : i32
+ } -> tensor<?x?xi32>
+ %4 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"] }
- ins(%0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
- ^bb0(%arg2: i32, %arg3: i32): // no predecessors
- %10 = addi %arg2, %arg3 : i32
- linalg.yield %10 : i32
- } -> tensor<?x?xi32>
- return %1 : tensor<?x?xi32>
+ ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+ %10 = addi %arg2, %arg3 : i32
+ linalg.yield %10 : i32
+ } -> tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
@@ -339,29 +413,36 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %0 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%arg0 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
- %3 = index_cast %arg3 : index to i32
- %4 = addi %arg4, %2 : i32
- %5 = subi %4, %3 : i32
- linalg.yield %5 : i32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, %c0 : tensor<?x?xi32>
+ %1 = dim %arg0, %c1 : tensor<?x?xi32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+ %3 = linalg.indexed_generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
+ %4 = index_cast %arg2 : index to i32
+ %5 = index_cast %arg3 : index to i32
+ %6 = addi %arg4, %4 : i32
+ %7 = subi %5, %6 : i32
+ linalg.yield %7 : i32
} -> tensor<?x?xi32>
- %1 = linalg.indexed_generic {
- indexing_maps = [#map1, #map1],
- iterator_types = ["parallel", "parallel"] }
- ins(%0 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
- %3 = index_cast %arg3 : index to i32
- %4 = addi %arg4, %2 : i32
- %5 = subi %4, %3 : i32
- linalg.yield %5 : i32
+ %4= linalg.indexed_generic {
+ indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%3 : tensor<?x?xi32>)
+ outs(%2 : tensor<?x?xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
+ %5 = index_cast %arg2 : index to i32
+ %6 = index_cast %arg3 : index to i32
+ %7 = addi %arg4, %5 : i32
+ %8 = subi %7, %6 : i32
+ linalg.yield %8 : i32
} -> tensor<?x?xi32>
- return %1 : tensor<?x?xi32>
+ return %4 : tensor<?x?xi32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @indexed_generic_op_fusion
@@ -374,7 +455,7 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
-// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
@@ -389,25 +470,27 @@ func @scalar_indexed_generic_fusion
{
%c0 = constant 0 : index
%cst = constant dense<1.000000e+00> : tensor<10xf32>
- %0 = linalg.indexed_generic
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.indexed_generic
{indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
iterator_types = []}
- ins(%arg1 : tensor<i32>) {
- ^bb0(%arg2: i32): // no predecessors
+ ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
+ ^bb0(%arg2: i32, %arg3: f32): // no predecessors
%3 = index_cast %arg2 : i32 to index
%4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
linalg.yield %4 : f32
} -> tensor<f32>
- %1 = linalg.generic
+ %2 = linalg.init_tensor [10] : tensor<10xf32>
+ %3 = linalg.generic
{indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
- ins(%0, %cst : tensor<f32>, tensor<10xf32>) {
- ^bb0(%arg2: f32, %arg3: f32): // no predecessors
- %3 = mulf %arg2, %arg3 : f32
- linalg.yield %3 : f32
+ ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+ %4 = mulf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
} -> tensor<10xf32>
- return %1 : tensor<10xf32>
+ return %3 : tensor<10xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
@@ -421,3 +504,35 @@ func @scalar_indexed_generic_fusion
// CHECK: tensor.extract %[[ARG0]]
// CHECK: linalg.yield
// CHECK return %[[T0]]
+
+// -----
+
+func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) {
+ %cst = constant dense<1.0> : tensor<4xf32>
+ %1 = linalg.init_tensor [4] : tensor<4xf32>
+ %2 = linalg.generic
+ {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins (%arg0, %cst : tensor<4xf32>, tensor<4xf32>)
+ outs (%1 : tensor<4xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %3 = addf %arg1, %arg2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<4xf32>
+ return %2 : tensor<4xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @constant_fusion(%[[ARG0:.+]]: tensor<4xf32>)
+// CHECK-DAG: %[[CST:.+]] = constant 1.000000e+00 : f32
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [4] : tensor<4xf32>
+// CHECK: %[[T1:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<4xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<4xf32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: f32, %[[ARG2:[a-zA-Z0-9_]+]]: f32)
+// CHECK: %[[T2:.+]] = addf %[[ARG1]], %[[CST]]
+// CHECK: linalg.yield %[[T2]]
+// CHECK: return %[[T1]]
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 6db48af3b573..c9f24844662f 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -28,7 +28,8 @@ func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32
// -----
func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
- linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) outs(%C: memref<16x32xf32>)
+ linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
+ outs(%C: memref<16x32xf32>)
return
}
@@ -45,7 +46,7 @@ func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C:
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
@@ -56,15 +57,16 @@ func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C:
// -----
func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) init(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK: func @generalize_matmul_tensor
// CHECK: linalg.generic
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
-// CHECK-SAME: init(%{{.+}} : tensor<16x32xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<16x32xf32>)
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 8e98a80e77b1..95a663d19f0d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -77,7 +77,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
// -----
func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>'}}
+ // expected-error @+1 {{expected shaped value rank (1) to match the result rank of indexing_map #0 (2)}}
linalg.generic {
indexing_maps = [ affine_map<() -> (0, 0)> ],
iterator_types = []}
@@ -143,9 +143,9 @@ func @generic_empty_region(%arg0: memref<f32>) {
func @generic_empty_region(%arg0: memref<f32>) {
%f0 = constant 0.0: f32
- // expected-error @+1 {{linalg.generic' op expected region with 1 block}}
+ // expected-error @+1 {{linalg.generic' op expected 1 region with 1 block}}
linalg.generic {
- indexing_maps = [ affine_map<() -> (0)> ],
+ indexing_maps = [ affine_map<() -> ()> , affine_map<() -> ()> ],
iterator_types = []}
ins(%arg0 : memref<f32>)
outs(%arg0 : memref<f32>) {
@@ -155,12 +155,12 @@ func @generic_empty_region(%arg0: memref<f32>) {
// -----
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected number of block arguments to match number of operands}}
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
linalg.generic {
- indexing_maps = [ affine_map<() -> (0)> ],
+ indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
iterator_types = []}
- outs(%arg0 : memref<f32>) {
- ^bb(%f: f32, %g: f32):
+ outs(%arg0, %arg0 : memref<f32>, memref<f32>) {
+ ^bb(%f: f32):
linalg.yield %f: f32
}
}
@@ -168,9 +168,9 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// -----
func @generic_block_arg_type(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref<f32>'}}
+ // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element type of corresponding shaped operand ('f32')}}
linalg.generic {
- indexing_maps = [ affine_map<() -> (0)> ],
+ indexing_maps = [ affine_map<() -> ()> ],
iterator_types = []}
outs(%arg0 : memref<f32>) {
^bb(%i: i1):
@@ -180,12 +180,12 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
// -----
-func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
+func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
linalg.indexed_generic {
- indexing_maps = [ affine_map<(d0) -> (d0)> ],
+ indexing_maps = [ affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]}
- outs(%arg0 : memref<f32>) {
+ outs(%arg0 : memref<?xf32>) {
^bb(%f: f32):
linalg.yield %f : f32
}
@@ -193,12 +193,12 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
// -----
-func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected block argument 1 to be an index}}
+func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{op expected index block argument #0}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
- outs(%arg0 : memref<f32>) {
+ outs(%arg0 : memref<?xf32>) {
^bb(%i: f64, %f: f32):
linalg.yield %f: f32
}
@@ -206,12 +206,12 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
// -----
-func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected block argument 2 of the same type as elemental type of output operand: 'memref<f32>'}}
+func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element type of corresponding shaped operand ('f32')}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
- outs(%arg0 : memref<f32>) {
+ outs(%arg0 : memref<?xf32>) {
^bb(%i: index, %f: i1):
linalg.yield %i: index
}
@@ -220,7 +220,7 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
// -----
func @indexed_generic_arg_count(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
linalg.indexed_generic {
indexing_maps = [ affine_map<()[] -> ()> ],
iterator_types = []}
@@ -233,19 +233,6 @@ func @indexed_generic_arg_count(%arg0: memref<f32>) {
// -----
-func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected block argument 1 to be an index}}
- linalg.indexed_generic {
- iterator_types = ["parallel"],
- indexing_maps = [ affine_map<(i) -> (i)> ]}
- outs(%arg0 : memref<f32>) {
- ^bb(%0: i32, %1: f32):
- linalg.yield %1: f32
- }
-}
-
-// -----
-
func @indexed_generic_result_count(%arg0: memref<?xf32>) {
// expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
linalg.indexed_generic {
@@ -273,19 +260,36 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
// -----
-func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
+func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
+ %arg1: tensor<?xf32>) {
+ // expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('f32')}}
%0 = linalg.generic {
- indexing_maps = [ affine_map<(i) -> (i)> ],
+ indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]}
- ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- ^bb(%i: f32):
+ ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb(%i: f32, %j: f32):
linalg.yield %i: f32
} -> f32
}
// -----
+func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
+ %arg1: tensor<?xf32>) {
+ // expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd0' is function of reduction iterator 'd0'}}
+ %0 = linalg.generic {
+ indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
+ iterator_types = ["reduction"]}
+ ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb(%i: f32, %j: f32):
+ linalg.yield %i: f32
+ } -> tensor<?xf32>
+}
+
+// -----
+
func @generic(%arg0: memref<?x?xi4>) {
// expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}}
@@ -301,12 +305,17 @@ func @generic(%arg0: memref<?x?xi4>) {
// -----
-func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- // expected-error @+1 {{expects memref ranks to be greater than 2}}
- linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
-}
-
-// -----
+// This test is currently disabled: subject to verifier ordering issues.
+// Instead, when the ranks are not greater than 2, an assertion will be triggered
+// in LinalgStructuredOps.td::ConvOp::iterator_types() for now because the
+// verifier inspects the iterator_types. This is slated to become an
+// autogenerated op in the future, alleviating the issue.
+// func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+// // DISABLED_expected -error @+1 {{expects memref ranks to be greater than 2}}
+// linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+// }
+//
+// // -----
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown
@@ -367,7 +376,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
%arg1: memref<2x3xf32>,
%arg2: memref<?x?x?xf32>) {
- // expected-error @+1 {{expects memref ranks to match}}
+ // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
return
@@ -376,7 +385,7 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
// -----
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
- // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref<?x?xf32>'}}
+ // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
outs(%c3 : memref<?x?x?xf32>)
return
@@ -384,18 +393,8 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
// -----
-func @empty_init_expected(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
- // expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}}
- linalg.matmul ins(%m, %m: memref<?x?xf32>, memref<?x?xf32>)
- outs(%m : memref<?x?xf32>)
- init(%t : tensor<?x?xf32>)
- return
-}
-
-// -----
-
func @incorrect_region_arg_count(%m: memref<?x?xf32>) {
- // expected-error @+3 {{region expects 3 args, got 4}}
+ // expected-error @+3 {{region expects 3 args, got 2}}
%res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
-> tensor<?x?xf32>, tensor<?x?xf32>
return
@@ -403,30 +402,10 @@ func @incorrect_region_arg_count(%m: memref<?x?xf32>) {
// -----
-func @single_tensor_result(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
- // expected-error @+1 {{expected single tensor result when reduction present}}
- %res:2 = linalg.matmul ins(%m : memref<?x?xf32>)
- init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
- -> tensor<?x?xf32>, tensor<?x?xf32>
- return
-}
-
-// -----
-
-func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
- // expected-error @+1 {{expected #init tensors to match #results when reduction present}}
- %res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
- init(%t, %t : tensor<?x?xf32>, tensor<?x?xf32>)
- -> tensor<?x?xf32>
- return
-}
-
-// -----
-
func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
- // expected-error @+1 {{expected init tensor #0 of the same type as result #0}}
+ // expected-error @+1 {{expected type of operand #2 ('tensor<?x?xf32>') to match type of corresponding result ('tensor<?xf32>')}}
%res = linalg.matmul ins(%m, %m : memref<?x?xf32>, memref<?x?xf32>)
- init(%t : tensor<?x?xf32>)
+ outs(%t : tensor<?x?xf32>)
-> tensor<?xf32>
return
}
diff --git a/mlir/test/Dialect/Linalg/parallel-loops.mlir b/mlir/test/Dialect/Linalg/parallel-loops.mlir
index 95eb997f4dbd..8d365af6a5a3 100644
--- a/mlir/test/Dialect/Linalg/parallel-loops.mlir
+++ b/mlir/test/Dialect/Linalg/parallel-loops.mlir
@@ -64,7 +64,7 @@ func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
#accesses = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
- affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
]
#trait = {
iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
@@ -94,4 +94,4 @@ func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>)
// CHECK: scf.parallel (%[[IV3:.*]], %[[IV4:.*]]) = (%[[C0]], %[[C0]]) to (%[[D3]], %[[D4]]) step (%[[C1]], %[[C1]])
// CHECK: scf.for %[[IV5:.*]] = %[[C0]] to %[[D5]] step %[[C1]]
// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]]
-// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV2]], %[[IV4]], %[[IV5]]]
+// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV4]], %[[IV3]]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 66e07cc56d65..92805218dde7 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,20 +1,21 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?x?xf32>,
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
%arg1 : tensor<?x?x?xf32>) ->
tensor<?x?x?xf32>
{
%0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
affine_map<(i, j, k, l) -> (j, k)>,
affine_map<(i, j, k, l) -> (l)>] :
- tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
%1 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%0 : tensor<?x?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32>
@@ -22,44 +23,58 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?x?xf32>,
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
// CHECK: func @generic_op_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
+// CHECK-DAG: %[[D0:.+]] = dim %[[T0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = dim %[[T0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = dim %[[T0]], %[[C2]]
+// CHECK: %[[D3:.+]] = divi_unsigned %[[D0]], %[[C4]]
+// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], %[[D2]], %[[D3]], 4]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[ARG0]], %[[T0]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
-// CHECK: return %[[T2]]
+// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x?x4xf32>)
+// CHECK: %[[T4:.+]] = linalg.tensor_reshape %[[T3]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[T4]]
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
%arg1 : tensor<?x?xf32>) ->
- tensor<?x?x4x5xf32>
+ tensor<?x4x?x5xf32>
{
%0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
affine_map<(i, j, k, l) -> (j, k, l)>] :
- tensor<?x?xf32> into tensor<?x?x4x5xf32>
- return %1 : tensor<?x?x4x5xf32>
+ tensor<?x?xf32> into tensor<?x4x?x5xf32>
+ return %1 : tensor<?x4x?x5xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@@ -68,31 +83,40 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C20:.+]] = constant 20 : index
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[D2:.+]] = divi_unsigned %[[D1]], %[[C20]]
+// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D0]], 4, %[[D2]], 5]
+// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
-// CHECK: return %[[T2]] : tensor<?x?x4x5xf32>
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x4x?x5xf32>)
+// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>
// -----
func @reshape_as_consumer_permutation
(%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
- -> tensor<?x?x?x?x?x?xf32> {
+ -> tensor<?x2x?x3x4x?xf32> {
%c = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg0 : f32, %arg1: f32):
+ ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+ outs(%a : tensor<?x?x?xf32>) {
+ ^bb0(%arg0 : f32, %arg1: f32, %s: f32):
%1 = addf %arg0, %arg1 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32>
@@ -100,8 +124,8 @@ func @reshape_as_consumer_permutation
[affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
- : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
- return %d : tensor<?x?x?x?x?x?xf32>
+ : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+ return %d : tensor<?x2x?x3x4x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
@@ -114,17 +138,28 @@ func @reshape_as_consumer_permutation
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
-// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
+// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D1:.+]] = divi_unsigned %[[D0]], %[[C2]]
+// CHECK-DAG: %[[D2:.+]] = dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D4:.+]] = divi_unsigned %[[D3]], %[[C12]]
+// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], 2, %[[D2]], 3, 4, %[[D4]]]
+// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK: return %[[T2]] : tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
+// CHECK: return %[[T3]] : tensor<?x2x?x3x4x?xf32>
// -----
@@ -138,8 +173,9 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
%0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) {
- ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
+ outs(%arg0 : tensor<264x4xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %s: f32): // no predecessors
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
} -> tensor<264x4xf32>
@@ -156,21 +192,27 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = linalg.generic
+// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] : tensor<8x33x4xf32>
+// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
-// CHECK: return %[[T1]] : tensor<8x33x4xf32>
+// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
+// CHECK: return %[[T2]] : tensor<8x33x4xf32>
// -----
-func @scalar_reshape(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>)
- -> tensor<1x10xf32> {
+func @scalar_reshape(
+ %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>, %shape : tensor<10xf32>)
+ -> tensor<1x10xf32>
+{
%0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]} ins(%0 : tensor<f32>) {
- ^bb0(%arg2: f32): // no predecessors
+ iterator_types = ["parallel"]}
+ ins(%0 : tensor<f32>)
+ outs(%shape : tensor<10xf32>) {
+ ^bb0(%arg2: f32, %s: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<10xf32>
%2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>]
@@ -185,11 +227,13 @@ func @scalar_reshape(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>)
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
// CHECK-SAME: tensor<1xf32> into tensor<f32>
-// CHECK: %[[T1:.+]] = linalg.generic
+// CHECK: %[[T1:.+]] = linalg.init_tensor [1, 10] : tensor<1x10xf32>
+// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<f32>)
-// CHECK: return %[[T1]] : tensor<1x10xf32>
+// CHECK-SAME: outs(%[[T1]] : tensor<1x10xf32>)
+// CHECK: return %[[T2]] : tensor<1x10xf32>
// -----
@@ -206,8 +250,9 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
%1 = linalg.indexed_generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) {
- ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32):
+ ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+ outs(%0 : tensor<?x?x?xi32>) {
+ ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32, %s: i32):
%1 = muli %arg6, %arg7 : i32
%2 = index_cast %arg3 : index to i32
%3 = addi %1, %2 : i32
@@ -228,7 +273,8 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
// CHECK: ^{{.*}}(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG2]], %[[ARG3]])
// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
// CHECK: %[[T5:.+]] = index_cast %[[T3]]
@@ -249,8 +295,9 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%0 = linalg.indexed_generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
- ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32): // no predecessors
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%arg0 : tensor<?x?xi32>) {
+ ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32, %s: i32): // no predecessors
%1 = muli %arg5, %arg6 : i32
%2 = index_cast %arg3 : index to i32
%3 = addi %1, %2 : i32
@@ -271,7 +318,8 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK: ^{{.*}}(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG4]], %[[ARG5]])
// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
// CHECK: %[[T5:.+]] = index_cast %[[ARG2]]
@@ -283,15 +331,16 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// -----
func @reshape_as_consumer_permutation
- (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+ (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>, %shape : tensor<6x4x210xi32>)
-> tensor<2x3x4x5x6x7xi32> {
%c = linalg.indexed_generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>) {
- ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32):
+ ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
+ outs(%shape : tensor<6x4x210xi32>) {
+ ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
%1 = addi %arg3, %arg4 : i32
%2 = index_cast %arg0 : index to i32
%3 = addi %1, %2 : i32
@@ -327,36 +376,42 @@ func @reshape_as_consumer_permutation
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
-// CHECK: %[[T2:.+]] = linalg.indexed_generic
+// CHECK: %[[T2:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+// CHECK: %[[T3:.+]] = linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<{{.+}}>, tensor<{{.+}}>)
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32)
-// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
-// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
-// CHECK-DAG: %[[T5:.+]] = addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T6:.+]] = index_cast %[[T3]]
-// CHECK: %[[T7:.+]] = addi %[[T5]], %[[T6]]
-// CHECK: %[[T8:.+]] = index_cast %[[T4]]
-// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
-// CHECK: %[[T10:.+]] = index_cast %[[ARG7]]
-// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-DAG: %[[T6:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T7:.+]] = index_cast %[[T4]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: %[[T9:.+]] = index_cast %[[T5]]
+// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
+// CHECK: %[[T11:.+]] = index_cast %[[ARG7]]
+// CHECK: %[[T12:.+]] = addi %[[T10]], %[[T11]]
// -----
-func @reshape_as_producer_projected_permutation
- (%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> {
+func @reshape_as_producer_projected_permutation(
+ %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
+{
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d2)>]
: tensor<33x8x?xi32> into tensor<264x?xi32>
%1 = linalg.indexed_generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32): // no predecessors
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<264x?xi32>)
+ outs(%shape : tensor<264x?x4xi32>) {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32, %s: i32): // no predecessors
%2 = index_cast %arg1 : index to i32
%3 = addi %arg4, %2 : i32
%4 = index_cast %arg2 : index to i32
@@ -384,7 +439,8 @@ func @reshape_as_producer_projected_permutation
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32)
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: i32)
// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]])
// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
@@ -409,8 +465,9 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
%0 = linalg.generic {
indexing_maps = [#map0, #map0, #map1],
iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
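
A minimal illustrative sketch (not taken from the patch; function name and shapes
are hypothetical) of the elementwise form these tests exercise: the output tensor
is threaded through outs(...), the region takes one extra trailing block argument
for its element, and the op result type matches the outs operand.

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func @elementwise_mul(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> tensor<?x?xf32> {
    // %out carries the current element of the outs tensor; an elementwise body
    // may simply ignore it.
    %0 = linalg.generic {
           indexing_maps = [#map, #map, #map],
           iterator_types = ["parallel", "parallel"]}
         ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
         outs(%a : tensor<?x?xf32>) {
         ^bb0(%lhs : f32, %rhs : f32, %out : f32):
           %1 = mulf %lhs, %rhs : f32
           linalg.yield %1 : f32
         } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
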
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index 468ae80a1288..aff1447a63c7 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,9 +1,5 @@
// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
%arg1 : tensor<?x?x4x?xf32>) ->
@@ -14,37 +10,39 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
affine_map<(i, j, k, l) -> (l)>] :
tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
%1 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>)
+ outs(%0 : tensor<?x?x4x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x4x?xf32>
return %1 : tensor<?x?x4x?xf32>
}
-// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @generic_op_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
-// CHECK-NOT: linalg.generic
-
+// CHECK-SAME: ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x?xf32>)
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
%arg1 : tensor<?x?x4x5xf32>) ->
tensor<?x?xf32>
{
%0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
+ outs(%arg0 : tensor<?x?x4x5xf32>){
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x4x5xf32>
@@ -54,10 +52,21 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.generic
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+// CHECK: func @generic_op_reshape_consumer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C20:.+]] = constant 20 : index
+// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
+// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: outs(%[[T3]] : tensor<?x?xf32>)
// -----
@@ -69,8 +78,9 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
%0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
+ outs(%arg0 : tensor<?x?x?x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?x5xf32>
@@ -81,14 +91,11 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
}
// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-// CHECK: linalg.tensor_reshape
+// CHECK: %[[T0:.+]] = linalg.generic
+// CHECK: linalg.tensor_reshape %[[T0]]
// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-> tensor<?x?x4x?xi32> {
@@ -99,8 +106,9 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
%1 = linalg.indexed_generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
- ins(%0 : tensor<?x?x4x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+ ins(%0 : tensor<?x?x4x?xi32>)
+ outs(%0 : tensor<?x?x4x?xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7 : i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
@@ -108,25 +116,24 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
return %1 : tensor<?x?x4x?xi32>
}
-// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @indexed_generic_op_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-> tensor<?x?xi32> {
%0 = linalg.indexed_generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
- ins(%arg0 : tensor<?x?x4x5xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+ ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
@@ -137,105 +144,124 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
return %1 : tensor<?x?xi32>
}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C20:.+]] = constant 20 : index
+// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
+// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: outs(%[[T3]] : tensor<?x?xi32>)
+// CHECK-NOT: linalg.tensor_reshape
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
#map0 = affine_map<(d0, d1, d2) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) {
+ ^bb0(%arg2: f32, %arg3 : f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<3x7x5xf32>
- return %1 : tensor<3x7x5xf32>
+ return %2 : tensor<3x7x5xf32>
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
#map0 = affine_map<(d0, d1, d2) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5x7x3xf32>
- return %1 : tensor<5x7x3xf32>
+ return %2 : tensor<5x7x3xf32>
}
-// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generic_op_120_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
#map0 = affine_map<(d0, d1, d2) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5x3x7xf32>
- return %1 : tensor<5x3x7xf32>
+ return %2 : tensor<5x3x7xf32>
}
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
-
-// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
- %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
- ^bb0(%arg2: f32): // no predecessors
+ %0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map0, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) {
+ ^bb0(%arg2: f32, %arg3 : f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5x3x7xf32>
- %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
- return %1 : tensor<5x21xf32>
+ %2 = linalg.tensor_reshape %1 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+ return %2 : tensor<5x21xf32>
}
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-NOT: linalg.tensor_reshape
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
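
For linalg.indexed_generic the same convention applies. A small hypothetical sketch
(not from the patch), mirroring the block signatures in the tests above, where the
induction-variable indices come first, then the ins elements, then the trailing
outs element:

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func @indexed_offset(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
    %0 = linalg.indexed_generic {
           indexing_maps = [#map, #map],
           iterator_types = ["parallel", "parallel"]}
         ins(%arg0 : tensor<?x?xi32>)
         outs(%arg0 : tensor<?x?xi32>) {
         // Block arguments: indices (%i, %j), input element (%in), outs element (%out).
         ^bb0(%i : index, %j : index, %in : i32, %out : i32):
           %1 = index_cast %i : index to i32
           %2 = addi %in, %1 : i32
           linalg.yield %2 : i32
         } -> tensor<?x?xi32>
    return %0 : tensor<?x?xi32>
  }
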
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index be785ceb70d6..c4eb8f8eac67 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -300,7 +300,7 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait
- ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
@@ -314,14 +314,14 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait
- ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
+ ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
@@ -358,14 +358,14 @@ func @generic_without_inputs(%arg0 : memref<?x?x?xf32>) {
// -----
-#accesses = [
+#accesses2 = [
affine_map<(i, j, k) -> (j, i)>,
affine_map<(i, j, k) -> (i, k, i + j)>,
affine_map<(i, j, k) -> (i, k, i + j)>
]
#trait2 = {
- indexing_maps = #accesses,
+ indexing_maps = #accesses2,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
}
@@ -374,9 +374,10 @@ func @generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
%0 = linalg.generic #trait2
- ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ outs(%arg1 : tensor<?x?x?xf32>)
attrs = {foo = 1} {
- ^bb(%0: vector<3x4xi4>, %1: f32) :
+ ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} -> tensor<?x?x?xf32>
@@ -386,21 +387,22 @@ func @generic_with_tensor_input_and_output(
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
// -----
-#accesses = [
+#accesses3 = [
affine_map<(i, j, k) -> (j, i)>,
affine_map<(i, j, k) -> (i, k, i + j)>,
affine_map<(i, j, k) -> (i, k, i + j)>
]
-#trait2 = {
- indexing_maps = #accesses,
+#trait3 = {
+ indexing_maps = #accesses3,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
}
@@ -408,10 +410,11 @@ func @generic_with_tensor_input_and_output(
func @indexed_generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
- %0 = linalg.indexed_generic #trait2
- ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ %0 = linalg.indexed_generic #trait3
+ ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ outs(%arg1 : tensor<?x?x?xf32>)
attrs = {foo = 1} {
- ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
+ ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32, %2: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} -> tensor<?x?x?xf32>
@@ -421,7 +424,8 @@ func @indexed_generic_with_tensor_input_and_output(
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
@@ -439,21 +443,23 @@ func @indexed_generic_with_tensor_input_and_output(
library_call = "some_broadcast_external_fn"
}
-func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
+func @generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tensor<3x4xf32>)
{
%0 = linalg.generic #trait_broadcast
- ins(%arg0 : tensor<f32>) {
- ^bb(%a: f32) :
+ ins(%arg0 : tensor<f32>)
+ outs(%arg1 : tensor<3x4xf32>) {
+ ^bb(%a: f32, %b: f32) :
linalg.yield %a : f32
} -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
-func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
+func @indexed_generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tensor<3x4xf32>)
{
%0 = linalg.indexed_generic #trait_broadcast
- ins(%arg0 : tensor<f32>) {
- ^bb(%i: index, %j: index, %a: f32) :
+ ins(%arg0 : tensor<f32>)
+ outs(%arg1 : tensor<3x4xf32>) {
+ ^bb(%i: index, %j: index, %a: f32, %b: f32) :
linalg.yield %a : f32
} -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
@@ -478,7 +484,7 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait3
- ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%a: vector<3x4xi4>, %b: f32) :
@@ -491,7 +497,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
-// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: attrs = {foo = 1 : i64} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
@@ -500,7 +506,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic #trait3
- ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
@@ -564,8 +570,8 @@ func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2:
affine_map<(i, j, k, l, m) -> (l, m)>] :
tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
%rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
- affine_map<(i, j, k, l, m) -> (k)>,
- affine_map<(i, j, k, l, m) -> (l, m)>] :
+ affine_map<(i, j, k, l, m) -> (k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
%t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
affine_map<(i, j, k, l, m) -> (k)>,
@@ -660,11 +666,13 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
outs(%c3: memref<?x?x?xf32>)
linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
- %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- init(%tc3: tensor<?x?x?xf32>)
+ %res1 = linalg.batch_matmul
+ ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
- %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
- init(%tc3: tensor<?x?x?xf32>)
+ %res2 = linalg.batch_matmul
+ ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
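
For named ops on tensors, the result operand goes through outs(...), as the
batch_matmul lines above show. A hypothetical sketch (assuming linalg.matmul
accepts the same outs-based form; names and shapes are illustrative):

  func @matmul_on_tensors(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>,
                          %c : tensor<?x?xf32>) -> tensor<?x?xf32> {
    // %c supplies the shape of the result and, for accumulating ops,
    // its initial values.
    %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                       outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
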
diff --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir
index 4c14b2e89279..4baf1d1c1403 100644
--- a/mlir/test/Dialect/Linalg/sparse_1d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir
@@ -32,8 +32,9 @@
// CHECK: }
func @add_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
%0 = linalg.generic #trait_d
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s : f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -58,8 +59,9 @@ func @add_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
// CHECK: }
func @mul_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
%0 = linalg.generic #trait_d
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s : f32):
%0 = mulf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -124,8 +126,9 @@ func @mul_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
// CHECK: }
func @add_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
%0 = linalg.generic #trait_s
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s : f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -159,8 +162,9 @@ func @add_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
// CHECK: }
func @repeated_add_s(%arga: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_s
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s : f32):
%0 = addf %a, %a : f32 // same tensor
%1 = addf %a, %a : f32 // should yield
%2 = addf %0, %1 : f32 // one guard
@@ -192,8 +196,9 @@ func @repeated_add_s(%arga: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @mul_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
%0 = linalg.generic #trait_s
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s : f32):
%0 = mulf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -235,8 +240,9 @@ func @mul_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
// CHECK: }
func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_dd
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -263,8 +269,9 @@ func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_dd
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -335,8 +342,9 @@ func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_ds
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -368,8 +376,9 @@ func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_ds
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -440,8 +449,9 @@ func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_sd
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -473,8 +483,9 @@ func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_sd
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -569,8 +580,9 @@ func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_ss
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -628,8 +640,9 @@ func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
// CHECK: }
func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
%0 = linalg.generic #trait_ss
- ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+ outs(%arga : tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -730,8 +743,9 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
func @two_way_inv(%arga: tensor<16xf32>,
%argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> {
%0 = linalg.generic #trait_two_way_inv
- ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) {
- ^bb(%a : f32, %b : f32):
+ ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>)
+ outs(%argb : tensor<16xf32>) {
+ ^bb(%a : f32, %b : f32, %c : f32):
%0 = mulf %a, %argc : f32
%1 = mulf %b, %argc : f32
%2 = addf %0, %1 : f32
@@ -819,8 +833,9 @@ func @two_way_inv_alt(%arga: tensor<16xf32>,
%argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> {
// Same kernel, but now expressed as "x(i) = (a(i) + b(i)) * c".
%0 = linalg.generic #trait_two_way_inv
- ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) {
- ^bb(%a : f32, %b : f32):
+ ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>)
+ outs(%argb : tensor<16xf32>) {
+ ^bb(%a : f32, %b : f32, %c : f32):
%0 = addf %a, %b : f32
%1 = mulf %0, %argc : f32
linalg.yield %1: f32
@@ -866,7 +881,7 @@ func @two_way_inv_alt(%arga: tensor<16xf32>,
func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_sum_reduction
ins(%arga : tensor<?xf32>)
- init(%argx : tensor<f32>) {
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %x : f32):
%0 = addf %x, %a : f32
linalg.yield %0: f32
@@ -975,7 +990,7 @@ func @sum_reduction_ss(%arga: tensor<16xf32>,
// as two separate reductions kernels.
%0 = linalg.generic #trait_sum_reduction_ss
ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>)
- init(%argx : tensor<f32>) {
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %b : f32, %x : f32):
%0 = addf %a, %b : f32
%1 = addf %x, %0 : f32
@@ -1094,7 +1109,7 @@ func @sum_reduction_inv(%arga: tensor<16xf32>,
// as two separate reductions kernels.
%0 = linalg.generic #trait_sum_reduction_inv_ss
ins(%arga, %argb, %argc : tensor<16xf32>, tensor<f32>, tensor<16xf32>)
- init(%argx : tensor<f32>) {
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %b : f32, %c : f32, %x : f32):
%0 = mulf %a, %b : f32
%1 = addf %0, %c : f32
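
Reductions follow the same pattern: the accumulator tensor is an outs operand and
its element is the last block argument. A hypothetical sketch (sparse annotations
omitted) mirroring the sum_reduction tests above:

  #trait_sum = {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> ()>],
    iterator_types = ["reduction"]
  }
  func @sum_1d(%arga : tensor<?xf32>, %argx : tensor<f32>) -> tensor<f32> {
    %0 = linalg.generic #trait_sum
         ins(%arga : tensor<?xf32>)
         outs(%argx : tensor<f32>) {
         // %x is the running sum carried by the outs tensor element.
         ^bb(%a : f32, %x : f32):
           %1 = addf %x, %a : f32
           linalg.yield %1 : f32
         } -> tensor<f32>
    return %0 : tensor<f32>
  }
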
diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir
index dea7444cadae..6612a723f23d 100644
--- a/mlir/test/Dialect/Linalg/sparse_2d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir
@@ -39,8 +39,9 @@
// CHECK: }
func @add_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_dd
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga: tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -70,8 +71,9 @@ func @add_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_dd
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -146,8 +148,9 @@ func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ds
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -183,8 +186,9 @@ func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ds
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -264,8 +268,9 @@ func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_sd
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -302,8 +307,9 @@ func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_sd
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -409,8 +415,9 @@ func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -450,8 +457,9 @@ func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -627,8 +635,9 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16
// CHECK: }
func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -721,8 +730,9 @@ func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
// CHECK: }
func @mul_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -898,8 +908,9 @@ func @mul_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
// CHECK: }
func @add_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -992,8 +1003,9 @@ func @add_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
// CHECK: }
func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> {
%0 = linalg.generic #trait_ss_ss
- ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+ outs(%arga : tensor<32x16xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -1048,8 +1060,8 @@ func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
// CHECK: }
func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
%0 = linalg.generic #trait_matvec
- ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
- init(%argx : tensor<16xf32>) {
+ ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
+ outs(%argx : tensor<16xf32>) {
^bb(%A: f32, %b: f32, %x: f32):
%0 = mulf %A, %b : f32
%1 = addf %0, %x : f32
@@ -1099,8 +1111,8 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
// CHECK: }
func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_sum_reduction
- ins(%arga : tensor<10x20xf32>)
- init(%argx : tensor<f32>) {
+ ins(%arga : tensor<10x20xf32>)
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %x : f32):
%0 = addf %x, %a : f32
linalg.yield %0: f32
@@ -1150,8 +1162,9 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
%0 = constant 2.0 : f64
%1 = linalg.generic #trait_scale
- ins(%arga: tensor<?x?xf64>) {
- ^bb(%a: f64):
+ ins(%arga: tensor<?x?xf64>)
+ outs(%arga: tensor<?x?xf64>) {
+ ^bb(%a: f64, %s: f64):
%2 = mulf %a, %0 : f64
linalg.yield %2 : f64
} -> tensor<?x?xf64>
@@ -1224,10 +1237,10 @@ func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
func @sampled_dense_dense(%args: tensor<?x?xf32>,
%arga: tensor<?x?xf32>,
%argb: tensor<?x?xf32>,
- %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic #trait_sampled_dense_dense
- ins(%args, %arga, %argb : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
- init(%argx : tensor<?x?xf32>) {
+ ins(%args, %arga, %argb : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%argx : tensor<?x?xf32>) {
^bb(%s : f32, %a : f32, %b : f32, %x : f32):
%0 = mulf %a, %b : f32
%1 = mulf %s, %0 : f32
@@ -1457,7 +1470,7 @@ func @sum_kernel_with_inv(%arga: tensor<?x?xf32>,
tensor<?x?xf32>,
tensor<?xf32>,
tensor<f32>)
- init(%argx : tensor<?xf32>) {
+ outs(%argx : tensor<?xf32>) {
^bb(%a : f32, %b : f32, %c : f32, %d : f32, %e : f32, %x : f32):
%0 = mulf %a, %b : f32
%1 = mulf %0, %d : f32
diff --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir
index 41818bb982b6..a32770e635e4 100644
--- a/mlir/test/Dialect/Linalg/sparse_3d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir
@@ -42,8 +42,9 @@
// CHECK: }
func @add_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_ddd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s: f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -76,8 +77,9 @@ func @add_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_ddd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -157,8 +159,9 @@ func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dds
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -199,8 +202,9 @@ func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dds
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -284,8 +288,9 @@ func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dsd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -326,8 +331,9 @@ func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dsd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -437,8 +443,9 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dss
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -482,8 +489,9 @@ func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_dss
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -572,8 +580,9 @@ func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sdd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -615,8 +624,9 @@ func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sdd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -731,8 +741,9 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sds
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -777,8 +788,9 @@ func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sds
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -897,8 +909,9 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_ssd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -943,8 +956,9 @@ func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_ssd
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -1089,8 +1103,9 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sss
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -1138,8 +1153,9 @@ func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: }
func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
%0 = linalg.generic #trait_sss
- ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) {
- ^bb(%a: f32, %b: f32):
+ ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+ outs(%arga : tensor<32x16x8xf32>) {
+ ^bb(%a: f32, %b: f32, %s : f32):
%0 = mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<32x16x8xf32>
@@ -1213,8 +1229,8 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
%argc: tensor<?x?xf32>,
%argd: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic #trait_kernel_3d
- ins(%argb, %argc, %argd : tensor<?x?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
- init(%arga : tensor<?x?xf32>) {
+ ins(%argb, %argc, %argd : tensor<?x?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arga : tensor<?x?xf32>) {
^bb(%b: f32, %c: f32, %d : f32, %a : f32):
%0 = mulf %b, %c : f32
%1 = mulf %0, %d : f32
@@ -1275,8 +1291,8 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
// CHECK: }
func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_sum_reduction
- ins(%arga : tensor<10x20x30xf32>)
- init(%argx : tensor<f32>) {
+ ins(%arga : tensor<10x20x30xf32>)
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %x : f32):
%0 = addf %x, %a : f32
linalg.yield %0: f32
@@ -1334,7 +1350,7 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
%argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_sum_reduction_inv
ins(%arga, %argb : tensor<?x?x?xf32>, tensor<?xf32>)
- init(%argx : tensor<f32>) {
+ outs(%argx : tensor<f32>) {
^bb(%a : f32, %b : f32, %x : f32):
%0 = mulf %a, %b : f32
%1 = addf %x, %0 : f32
@@ -1363,7 +1379,8 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
// CHECK-LABEL: func @invariants(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<20xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>) -> tensor<10x20x30xf32> {
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>,
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
// CHECK: %[[VAL_3:.*]] = constant 10 : index
// CHECK: %[[VAL_4:.*]] = constant 20 : index
// CHECK: %[[VAL_5:.*]] = constant 30 : index
@@ -1390,10 +1407,12 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
// CHECK: }
func @invariants(%arga: tensor<10xf32>,
%argb: tensor<20xf32>,
- %argc: tensor<30xf32>) -> tensor<10x20x30xf32> {
+ %argc: tensor<30xf32>,
+ %shape : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
%0 = linalg.generic #trait_invariants
- ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>) {
- ^bb(%a : f32, %b : f32, %c : f32):
+ ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>)
+ outs(%shape : tensor<10x20x30xf32>) {
+ ^bb(%a : f32, %b : f32, %c : f32, %s : f32):
%0 = mulf %a, %b : f32
%1 = mulf %0, %c : f32
linalg.yield %1: f32
diff --git a/mlir/test/Dialect/Linalg/sparse_invalid.mlir b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
index a75ec361a7a1..bb64e80785fa 100644
--- a/mlir/test/Dialect/Linalg/sparse_invalid.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
@@ -12,11 +12,14 @@
iterator_types = ["parallel"]
}
-func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
+func @invalid_memref(%arga: memref<32xf32>, %argb: f32, %shape: tensor<32xf32>)
+ -> tensor<32xf32>
+{
 // expected-error@+1 {{'linalg.generic' op expected sparse annotations on tensors only}}
%0 = linalg.generic #trait_memref
- ins(%arga: memref<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: memref<32xf32>)
+ outs(%shape: tensor<32xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -25,79 +28,6 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
// -----
-#trait_two_out = {
- indexing_maps = [
- affine_map<(i) -> (i)>, // a
- affine_map<(i) -> (i)>, // x (out)
- affine_map<(i) -> (i)> // y (out)
- ],
- sparse = [
- [ "S" ], // a
- [ "D" ], // x
- [ "D" ] // y
- ],
- iterator_types = ["parallel"]
-}
-
-func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> {
- // expected-error@+1 {{'linalg.generic' op expected single output tensor}}
- %0, %1 = linalg.generic #trait_two_out
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
- %0 = addf %a, %a : f32
- linalg.yield %a, %0 : f32, f32
- } -> tensor<32xf32>, tensor<32xf32>
- return %1 : tensor<32xf32>
-}
-
-// -----
-
-#trait_two_blocks = {
- indexing_maps = [
- affine_map<(i) -> (i)>, // a
- affine_map<(i) -> (i)> // x (out)
- ],
- sparse = [
- [ "S" ], // a
- [ "D" ] // x
- ],
- iterator_types = ["parallel"]
-}
-
-func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> {
- // expected-error@+1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}}
- %0 = linalg.generic #trait_two_blocks
- ins(%arga: tensor<32xf32>) {
- ^bb1(%a: f32):
- %0 = addf %a, %a : f32
- ^bb2:
- linalg.yield %0 : f32
- } -> tensor<32xf32>
- return %0 : tensor<32xf32>
-}
-
-// -----
-
-#trait_no_block = {
- indexing_maps = [
- affine_map<(i) -> (i)> // a
- ],
- sparse = [
- [ "S" ] // a
- ],
- iterator_types = ["parallel"]
-}
-
-func @invalid_no_block(%arga: tensor<32xf32>) {
- // expected-error@+1 {{'linalg.generic' op expected region with 1 block}}
- linalg.generic #trait_no_block
- ins(%arga: tensor<32xf32>) {
- }
- return
-}
-
-// -----
-
#trait_too_many = {
indexing_maps = [
affine_map<(i) -> (i)>, // a
@@ -114,8 +44,9 @@ func @invalid_no_block(%arga: tensor<32xf32>) {
func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 // expected-error@+1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
%0 = linalg.generic #trait_too_many
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -136,8 +67,9 @@ func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 // expected-error@+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
%0 = linalg.generic #trait_no_array
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -161,8 +93,9 @@ func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 // expected-error@+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
%0 = linalg.generic #trait_wrong_rank
- ins(%arga: tensor<32xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32xf32>)
+ outs(%arga: tensor<32xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32xf32>
@@ -186,8 +119,9 @@ func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> {
 // expected-error@+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
%0 = linalg.generic #trait_no_string
- ins(%arga: tensor<32x16xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32x16xf32>)
+ outs(%arga: tensor<32x16xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -211,8 +145,9 @@ func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf3
func @invalid_wrong_symbol(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> {
 // expected-error@+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 1}}
%0 = linalg.generic #trait_wrong_symbol
- ins(%arga: tensor<32x16xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32x16xf32>)
+ outs(%arga: tensor<32x16xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
@@ -236,8 +171,9 @@ func @invalid_wrong_symbol(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16
func @invalid_no_sparse_output(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> {
 // expected-error@+1 {{'linalg.generic' op sparse output tensors not supported (yet)}}
%0 = linalg.generic #trait_no_sparse_output
- ins(%arga: tensor<32x16xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<32x16xf32>)
+ outs(%arga: tensor<32x16xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = addf %a, %argb : f32
linalg.yield %0 : f32
} -> tensor<32x16xf32>
diff --git a/mlir/test/Dialect/Linalg/sparse_parallel.mlir b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
index a75406fbab69..3d3d51ae0327 100644
--- a/mlir/test/Dialect/Linalg/sparse_parallel.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
@@ -50,8 +50,9 @@
//
func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic #trait_dd
- ins(%arga: tensor<?x?xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<?x?xf32>)
+ outs(%arga: tensor<?x?xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = mulf %a, %scale : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
@@ -99,8 +100,9 @@ func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
//
func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic #trait_ss
- ins(%arga: tensor<?x?xf32>) {
- ^bb(%a: f32):
+ ins(%arga: tensor<?x?xf32>)
+ outs(%arga: tensor<?x?xf32>) {
+ ^bb(%a: f32, %s: f32):
%0 = mulf %a, %scale : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
@@ -151,7 +153,7 @@ func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
%0 = linalg.generic #trait_matvec
ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
- init(%argx : tensor<16xf32>) {
+ outs(%argx : tensor<16xf32>) {
^bb(%A: f32, %b: f32, %x: f32):
%0 = mulf %A, %b : f32
%1 = addf %0, %x : f32
diff --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir
index c63bdb1e413d..69b8e1903d69 100644
--- a/mlir/test/Dialect/Linalg/sparse_storage.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir
@@ -88,8 +88,9 @@
func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_mul_1d
- ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>) {
- ^bb(%a: f64, %b: f64):
+ ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>)
+ outs(%arga : tensor<32xf64>) {
+ ^bb(%a: f64, %b: f64, %s: f64):
%0 = mulf %a, %b : f64
linalg.yield %0 : f64
} -> tensor<32xf64>
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index 2a6a7ba7b7e3..fcecf896ac5d 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -198,14 +198,14 @@ func @matmul_tensors(
// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: init(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
%0 = linalg.matmul {__internal_linalg_transform__ = "tensors_distribute1"}
ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
- init(%arg2: tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
-> tensor<?x?xf32>
// CHECK: return %[[TD0]] : tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 41adff7d46c3..9e9688088568 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -8,7 +8,7 @@
func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
- init(%arg2: tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
-> tensor<?x?xf32>
%c4 = constant 4 : index
@@ -25,7 +25,7 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
%6 = subtensor %t0[%arg3, %arg7][%c2, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
%7 = subtensor %arg1[%arg7, %arg5][4, %c3][1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
%8 = subtensor %arg8[%arg3, %arg5][%c2, %c3][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %9 = linalg.matmul ins(%6, %7 : tensor<?x4xf32>, tensor<4x?xf32>) init(%8 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %9 = linalg.matmul ins(%6, %7 : tensor<?x4xf32>, tensor<4x?xf32>) outs(%8 : tensor<?x?xf32>) -> tensor<?x?xf32>
%10 = subtensor_insert %9 into %arg8[%arg3, %arg5] [%c2, %c3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %10 : tensor<?x?xf32>
}
@@ -53,6 +53,6 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
// subtensors of the producing matmul.
// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
-// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index b899cb3e0049..787ea8d2b395 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
@@ -14,13 +14,13 @@ func @matmul_tensors(
// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: init(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
- init(%arg2: tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
-> tensor<?x?xf32>
// CHECK: return %[[TD0]] : tensor<?x?xf32>
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index b713ae98b107..db200ba5f90f 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -1101,7 +1101,7 @@ TEST_FUNC(linalg_metadata_ops) {
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: init(%{{[a-z0-9]*}} : tensor<?x?xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : tensor<?x?xf32>)
// CHECK: mulf
// CHECK: addf
// CHECK: } -> tensor<?x?xf32>
@@ -1115,14 +1115,15 @@ TEST_FUNC(linalg_tensors_test) {
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto tensorType = RankedTensorType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type);
- auto f = makeFunction("linalg_tensors", {}, {tensorType, memrefType});
+ auto f =
+ makeFunction("linalg_tensors", {}, {tensorType, memrefType, tensorType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
- Value A(f.getArgument(0)), B(f.getArgument(1));
+ Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
AffineExpr i, j;
bindDims(&globalContext(), i, j);
- StructuredIndexed SA(A), SB(B), SC(tensorType);
+ StructuredIndexed SA(A), SB(B), SC(C);
Value added = linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}))
->getResult(0);
Value maxed = linalg_generic_pointwise_max(
@@ -1223,7 +1224,8 @@ TEST_FUNC(builder_loop_for_yield) {
[&](Value iv, ValueRange args) {
Value sum = args[0] + args[1];
return scf::ValueVector{args[1], sum};
- }).getResults();
+ })
+ .getResults();
results[0] + results[1];
// clang-format off
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 528fae883d19..f81380f02bb3 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -4,7 +4,6 @@
// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [
// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-// ODS-NEXT: NamedStructuredOpTrait
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
@@ -29,7 +28,6 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-// ODS-NEXT: NamedStructuredOpTrait
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
@@ -54,7 +52,6 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-// ODS-NEXT: NamedStructuredOpTrait
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 45dc115e6c1e..0342fab5ab9f 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1453,54 +1453,45 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- NamedStructuredOpTrait,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins Variadic<AnyShaped>:$inputs,
- Variadic<AnyMemRef>:$output_buffers,
- Variadic<AnyRankedTensor>:$init_tensors);
+ Variadic<AnyShaped>:$outputs);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let skipDefaultBuilders = 1;
let builders = [ OpBuilderDAG<
- (ins "ValueRange":$inputs, "ValueRange":$outputBuffers),
+ (ins "ValueRange":$inputs, "ValueRange":$outputs),
[{{
$_state.addOperands(inputs);
- $_state.addOperands(outputBuffers);
+ $_state.addOperands(outputs);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
- static_cast<int32_t>(outputBuffers.size()),
- static_cast<int32_t>(0)}));
+ static_cast<int32_t>(outputs.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder,
$_state,
TypeRange(inputs),
- TypeRange(outputBuffers),
- TypeRange(),
- TypeRange());
+ TypeRange(outputs));
}]>, OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputBuffers, "ValueRange":$initTensors),
+ "ValueRange":$outputs),
[{{
$_state.addOperands(inputs);
- $_state.addOperands(outputBuffers);
- $_state.addOperands(initTensors);
+ $_state.addOperands(outputs);
$_state.addTypes(resultTensorTypes);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
- static_cast<int32_t>(outputBuffers.size()),
- static_cast<int32_t>(initTensors.size())}));
+ static_cast<int32_t>(outputs.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder,
$_state,
TypeRange(inputs),
- TypeRange(outputBuffers),
- TypeRange(initTensors),
- resultTensorTypes);
+ TypeRange(outputs));
}]>, OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
@@ -1513,7 +1504,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
- let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
let hasFolder = 1;
let hasCanonicalizer = 1;