[Mlir-commits] [mlir] 3139cc7 - [mlir][Linalg] Add a pattern to decompose `linalg.generic` ops.

Mahesh Ravishankar llvmlistbot at llvm.org
Fri Jul 15 16:01:28 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-07-15T23:01:18Z
New Revision: 3139cc766c86b09426893a7349763c347639cbdc

URL: https://github.com/llvm/llvm-project/commit/3139cc766c86b09426893a7349763c347639cbdc
DIFF: https://github.com/llvm/llvm-project/commit/3139cc766c86b09426893a7349763c347639cbdc.diff

LOG: [mlir][Linalg] Add a pattern to decompose `linalg.generic` ops.

This patch adds a pattern to decompose a `linalg.generic` operation
that
- has only parallel iterator types
- has more than 2 statements (including the yield)

into multiple `linalg.generic` operations such that each operation has
a single statement and a yield.
The pattern added here just splits the matching `linalg.generic` into
two `linalg.generic`s: one containing the first statement, and the
other containing the remaining statements. The same pattern can be
applied repeatedly to the second op to ultimately fully decompose the
generic op.

Differential Revision: https://reviews.llvm.org/D129704
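
For reference, the test pass added below drives this pattern with the
greedy rewriter; a minimal sketch of doing the same (the helper name
`decomposeAllGenericOps` is illustrative, not part of this patch):

```c++
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Decompose every `linalg.generic` nested under `op`. Greedy application
// re-matches each residual op, so the rewrite runs to a fixed point where
// every remaining generic op has a single statement plus its yield.
mlir::LogicalResult decomposeAllGenericOps(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::linalg::populateDecomposeLinalgOpsPattern(patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```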

Added: 
    mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
    mlir/test/Dialect/Linalg/decompose-ops.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6e039f5c4a614..2ce8f66c51172 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -699,6 +699,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getBlock()->getArgument(opOperand->getOperandNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the operand for a `blockArgument`.
+      }],
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getTiedOpOperand",
+      /*args=*/(ins "BlockArgument":$blockArgument),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(blockArgument.getOwner() == getBlock());
+        return &this->getOperation()->getOpOperand(
+            blockArgument.getArgNumber());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the input or output indexing map for `opOperand`.

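The new `getTiedOpOperand` is the inverse of the interface method just
above it (which maps an `OpOperand` to its payload block argument). A
hypothetical usage sketch, assuming a `LinalgOp linalgOp` and a payload
`BlockArgument bbArg` are in scope:

```c++
// Map a payload block argument back to the operand bound to it.
OpOperand *opOperand = linalgOp.getTiedOpOperand(bbArg);
Value boundValue = opOperand->get();                    // The SSA operand.
AffineMap map = linalgOp.getTiedIndexingMap(opOperand); // Its indexing map.
```
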
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d34719acbb523..79c40057e992e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -45,6 +45,10 @@ using LinalgLoops = SmallVector<Operation *, 4>;
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
+/// Populate patterns for splitting a `LinalgOp` with multiple statements
+/// within its payload into multiple `GenericOp`s that each have a single
+/// statement.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops; it assumes high-D convolution ops
 /// were decomposed previously.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 8015edeb59a92..5bc2740afbe07 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
   ConstantFold.cpp
+  DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
new file mode 100644
index 0000000000000..9b6218474ed04
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -0,0 +1,391 @@
+//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Pattern to decompose a GenericOp that has more than two statements
+/// into one GenericOp with the first statement (i.e. peeled operation), and
+/// a second GenericOp with the remaining statements (i.e. residual operations).
+///
+/// - The result of the first GenericOp has the same shape as the iteration
+///   space of the GenericOp. The body of the op yields as many values as the
+///   original op plus all the results of the peeled operation.
+/// - The second GenericOp has as many operands as the original operation plus
+///   all the results of the first GenericOp. It has the same number of yields
+///   as the original op.
+/// - If the result of the peeled operation was yielded by the original
+///   GenericOp, the uses of the corresponding results will be replaced with
+///   the result of the first GenericOp created.
+///
+///  Example
+///
+/// ```mlir
+///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
+///      outs(%init0, %init1 : ...) {
+///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
+///      %0 = <s0> %b0, %b1 : ...
+///      %1 = <s1> %0, %b2 : ...
+///      linalg.yield %0, %1 : ...
+///  } -> (..., ...)
+///  return %result#0, %result#1
+/// ```
+///
+/// gets split into
+///
+/// ```mlir
+/// %init = linalg.init_tensor ...
+/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
+///      outs(%init0, %init1, %init : ...) {
+///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
+///      %0 = <s0> %b0, %b1 : ...
+///      linalg.yield %0, %..., %0 : ...
+///  } -> (..., ..., ...)
+/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
+///      outs(%init0, %init1 : ...) {
+///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
+///      %1 = <s1> %b3, %b2 : ...
+///      linalg.yield %..., %1 : ...
+///  } -> (..., ...)
+///  return %op0#0, %op1#1
+/// ```
+///
+/// After canonicalization this is expected to be
+///
+/// ```mlir
+/// %init = linalg.init_tensor ...
+/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
+///      outs(%init : ...) {
+///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
+///      %0 = <s0> %b0, %b1 : ...
+///      linalg.yield %0 : ...
+///  } -> ...
+/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
+///      outs(%init1 : ...) {
+///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
+///      %1 = <s1> %b1, %b0 : ...
+///      linalg.yield %1 : ...
+///  } -> ...
+///  return %op0, %op1
+/// ```
+struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Helper method to create a generic op for the peeled scalar operation.
+  /// The created op has an empty body.
+  GenericOp createPeeledGenericOp(GenericOp genericOp,
+                                  PatternRewriter &rewriter) const;
+
+  /// Helper method to create a generic op for the residual scalar operations.
+  /// The created op also has an empty body; the remaining statements of the
+  /// original op are spliced into it by the caller.
+  GenericOp createResidualGenericOp(GenericOp genericOp,
+                                    GenericOp peeledGenericOp,
+                                    PatternRewriter &rewriter) const;
+};
+} // namespace
+
+/// Helper method to compute the loop ranges of a generic op.
+static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
+                                                       GenericOp op) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+  auto allShapesSizes =
+      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
+  AffineMap map = op.getShapesToLoopsMap();
+  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
+}
+
+/// Helper method to permute the list of `values` based on the `map`.
+static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
+                                               AffineMap map) {
+  assert(map.isPermutation());
+  SmallVector<OpFoldResult> permutedValues(values.size());
+  for (auto position :
+       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
+         return expr.cast<AffineDimExpr>().getPosition();
+       })))
+    permutedValues[position.value()] = values[position.index()];
+  return permutedValues;
+}
+
+/// Get zero value for an element type.
+static Value getZero(OpBuilder &b, Location loc, Type elementType) {
+  assert(elementType.isIntOrIndexOrFloat() &&
+         "expected scalar type while computing zero value");
+  if (elementType.isa<IntegerType>())
+    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+  if (elementType.isIndex())
+    return b.create<arith::ConstantIndexOp>(loc, 0);
+  // Assume float.
+  auto floatType = elementType.cast<FloatType>();
+  return b.create<arith::ConstantFloatOp>(
+      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+}
+
+GenericOp
+DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
+                                         PatternRewriter &rewriter) const {
+  Block *body = genericOp.getBody();
+  Operation *peeledScalarOperation = &(*body->begin());
+  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
+      genericOp.getIndexingMaps();
+
+  /// Compute the loop ranges for the operation. This is the shape of the
+  /// result of the generic op for the peeled operation.
+  Location loc = genericOp.getLoc();
+  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
+  SmallVector<Value> newInitValues;
+  SmallVector<Type> newResultTypes;
+
+  /// The indexing map to use for each new result is obtained as follows:
+  /// - If the result is yielded, use the same indexing map as the
+  ///   corresponding output.
+  /// - Otherwise, use the identity indexing map.
+  Operation *yieldOp = body->getTerminator();
+  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
+    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
+    for (OpOperand &use : scalarOpResult.getUses()) {
+      if (use.getOwner() != yieldOp)
+        continue;
+      if (!firstUseInYield)
+        firstUseInYield = &use;
+      OpResult genericOpResult =
+          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMapForResult(genericOpResult);
+      if (indexingMap.isIdentity())
+        identityUseInYield = &use;
+    }
+    if (identityUseInYield || !firstUseInYield)
+      return rewriter.getMultiDimIdentityMap(domain.size());
+    OpResult genericOpResult =
+        genericOp.getResult(firstUseInYield->getOperandNumber())
+            .cast<OpResult>();
+    return genericOp.getTiedIndexingMapForResult(genericOpResult);
+  };
+
+  for (auto scalarResult : peeledScalarOperation->getResults()) {
+    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
+    SmallVector<OpFoldResult> initSize =
+        permuteValues(domain, resultIndexingMap);
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, initSize, scalarResult.getType());
+    newInitValues.push_back(initTensor);
+    newResultTypes.push_back(initTensor.getType());
+    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
+  }
+
+  /// Create the peeled generic op with an empty body.
+  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
+  outsOperands.append(newInitValues.begin(), newInitValues.end());
+  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
+  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
+  auto indexingMapAttr =
+      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
+  return rewriter.create<GenericOp>(
+      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+GenericOp
+DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
+                                           GenericOp peeledGenericOp,
+                                           PatternRewriter &rewriter) const {
+  /// Append the newly added results of the peeledGenericOp as `ins` operands
+  /// of the residual generic op.
+  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(),
+                      [](OpOperand *operand) { return operand->get(); }));
+  unsigned origNumResults = genericOp.getNumResults();
+  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
+  SmallVector<Value> extraIns;
+  for (auto resultNum :
+       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
+    extraIns.push_back(peeledGenericOp->getResult(resultNum));
+  residualGenericOpOperands.append(extraIns);
+
+  /// Add indexing maps for the newly added operands. Use the same maps as
+  /// those used for the new results of the peeledGenericOp.
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
+        return genericOp.getTiedIndexingMap(operand);
+      }));
+  for (auto resultNum :
+       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
+    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
+    indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
+  }
+  for (OpOperand *outOperand : genericOp.getOutputOperands())
+    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+
+  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
+  return rewriter.create<GenericOp>(
+      genericOp->getLoc(), genericOp->getResultTypes(),
+      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+LogicalResult
+DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
+                                   PatternRewriter &rewriter) const {
+  /// For now, only match on operations where all iterator types are parallel.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "unhandled decomposition of operation "
+                                       "with non-parallel iterator types");
+  }
+  // TODO: this could be generalized to handle `linalg.generic` with buffer
+  // operands too but requires allocation for intermediates. Punt on this for
+  // now.
+  if (!genericOp.hasTensorSemantics()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "only operations with tensor semantics are handled");
+  }
+
+  // TODO: For now only decompose operations where the `outs` operands values
+  // are not accessed within the payload. This might be relaxed in future, but
+  // needs a bit more reasoning to ensure that it is safe.
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return genericOp.payloadUsesValueFromOperand(outOperand);
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with use of out "
+                   "operand value in payload");
+  }
+
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with out operand not "
+                   "accessed using a permutation");
+  }
+
+  /// If the op has at most one statement (apart from the yield), do nothing.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() <= 2) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "operation has less than 3 statements");
+  }
+
+  /// Check that the peeled statement has a scalar element type.
+  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
+                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
+    return rewriter.notifyMatchFailure(
+        &(*body->getOperations().begin()),
+        "expected return type to be only int, index or float");
+  }
+
+  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
+  GenericOp residualGenericOp =
+      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
+
+  /// Move the first statement of the original operation into the body of the
+  /// generic op for the peeled operation.
+  Block *peeledGenericOpBody = peeledGenericOp.getBody();
+  Block *residualGenericOpBody = residualGenericOp.getBody();
+  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
+         "expected split generic ops to have empty region");
+  peeledGenericOpBody->getOperations().splice(
+      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
+  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
+                                                body->getOperations());
+
+  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
+  auto yieldOp = residualGenericOpBody->getTerminator();
+  {
+    // Yield all the results of the peeled scalar operation.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
+    SmallVector<Value> yieldedVals;
+    for (auto origYield : yieldOp->getOperands()) {
+      if (origYield.getDefiningOp() == peeledScalarOperation) {
+        yieldedVals.push_back(origYield);
+      } else {
+        yieldedVals.push_back(
+            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
+      }
+    }
+    yieldedVals.append(llvm::to_vector(
+        llvm::map_range(peeledScalarOperation->getResults(),
+                        [](OpResult opr) -> Value { return opr; })));
+    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
+  }
+
+  /// In the split operations, replace uses of the original operation's block
+  /// arguments with the block arguments of the newly created operations.
+  unsigned origNumInputs = genericOp.getNumInputs();
+  for (auto inputBlockArg :
+       llvm::enumerate(genericOp.getBody()->getArguments())) {
+    Value residualOpReplacementArg =
+        residualGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        residualOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == residualGenericOpBody;
+        });
+
+    Value peeledOpReplacementArg =
+        peeledGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        peeledOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == peeledGenericOpBody;
+        });
+  }
+
+  /// Before fixing up the residual operation, track which values are yielded.
+  /// If any of those come from the peeled scalar operation, the uses of the
+  /// corresponding results have to be remapped to results of the generic op
+  /// for the peeled operation.
+  SmallVector<Value> replacements;
+  for (auto yieldValue : llvm::enumerate(yieldOp->getOperands())) {
+    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+    if (!opr || opr.getOwner() != peeledScalarOperation)
+      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
+    else
+      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
+  }
+
+  /// Update all uses of the peeled scalar operation's results in the residual
+  /// op to the newly added block arguments.
+  {
+    SmallVector<Value> scalarReplacements;
+    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
+    scalarReplacements.reserve(peeledScalarOpNumResults);
+    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
+      scalarReplacements.push_back(
+          residualGenericOpBody->getArgument(num + origNumInputs));
+    bool allUsesReplaced = false;
+    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
+                                  residualGenericOpBody, &allUsesReplaced);
+    assert(!allUsesReplaced &&
+           "peeled scalar operation is erased when it wasnt expected to be");
+  }
+
+  // Replace the original operation.
+  rewriter.replaceOp(genericOp, replacements);
+  return success();
+}
+
+void mlir::linalg::populateDecomposeLinalgOpsPattern(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
+}

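A note on the `permuteValues` helper above: it scatters each value to the
dimension position read by the corresponding map result, so a transposed
indexing map yields the loop-domain sizes in permuted order. A hypothetical
illustration (the helper is file-local; `ctx`, `d0Size`, and `d1Size` are
assumed to be in scope):

```c++
// With the transposed map (d0, d1) -> (d1, d0), the init-tensor sizes are
// the loop-domain sizes in swapped order.
AffineMap transposed = AffineMap::getPermutationMap({1, 0}, ctx);
SmallVector<OpFoldResult> initSize =
    permuteValues({d0Size, d1Size}, transposed);
// initSize is now {d1Size, d0Size}.
```
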
diff  --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
new file mode 100644
index 0000000000000..648a58eb87b30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -0,0 +1,326 @@
+// RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
+
+func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
+  %init2 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result:2 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, 
+                     affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
+    outs(%init1, %init2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %0 = arith.addf %b0, %b1 : f32
+      %1 = arith.mulf %0, %b2 : f32
+      linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %result#0, %result#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func @simple_op(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+//  CHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CHECK-DAG:   %[[GENERIC1:.+]]:3 = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP3]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT1]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B3:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B5:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CHECK-NEXT:     linalg.yield %[[S0]], %{{[a-zA-Z0-9]+}}, %[[S0]]
+//      CHECK:   %[[GENERIC2:.+]]:2 = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP3]], #[[MAP0]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]]#2 :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B6:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B7:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B8:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B9:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B10:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B11:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S1:.+]] = arith.mulf %[[B9]], %[[B8]]
+// CHECK-NEXT:     linalg.yield %[[B9]], %[[S1]]
+//      CHECK:   return %[[GENERIC1]]#0, %[[GENERIC2]]#1
+
+// With cse + canonicalization
+
+//  CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1)>
+//      CANONICALIZECHECK: func @simple_op(
+// CANONICALIZECHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CANONICALIZECHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CANONICALIZECHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CANONICALIZECHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CANONICALIZECHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CANONICALIZECHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+//  CANONICALIZECHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CANONICALIZECHECK-DAG:   %[[GENERIC1:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CANONICALIZECHECK-SAME:       ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT1]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: f32):
+// CANONICALIZECHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[S0]]
+//      CANONICALIZECHECK:   %[[GENERIC2:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       [#[[MAP3]], #[[MAP2]], #[[MAP0]]]
+// CANONICALIZECHECK-SAME:       ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG2]], %[[GENERIC1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT2]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[B3:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B5:[a-zA-Z0-9]+]]: f32):
+// CANONICALIZECHECK-NEXT:     %[[S1:.+]] = arith.mulf %[[B4]], %[[B3]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[S1]]
+//      CANONICALIZECHECK:   return %[[GENERIC1]], %[[GENERIC2]]
+
+
+// -----
+
+func.func @simple_op_permuted_outputs(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
+  %init2 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result:3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, 
+                     affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                     affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
+    outs(%init1, %init2, %init2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32, %b5 : f32) :
+      %0 = arith.addf %b0, %b1 : f32
+      %1 = arith.mulf %0, %b2 : f32
+      linalg.yield %0, %1, %0 : f32, f32, f32
+    } -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+  return %result#0, %result#1, %result#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func @simple_op_permuted_outputs(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+//  CHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CHECK-DAG:   %[[GENERIC1:.+]]:4 = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT2]], %[[INIT2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B3:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B5:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B6:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CHECK-NEXT:     linalg.yield %[[S0]], %{{[a-zA-Z0-9]+}}, %[[S0]]
+//      CHECK:   %[[GENERIC2:.+]]:3 = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]]#3 :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B7:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B8:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B9:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B10:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B11:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B12:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S1:.+]] = arith.mulf %[[B10]], %[[B9]]
+// CHECK-NEXT:     linalg.yield %[[B10]], %[[S1]], %[[B10]]
+//      CHECK:   return %[[GENERIC1]]#0, %[[GENERIC2]]#1, %[[GENERIC1]]#2
+
+//  CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1)>
+//      CANONICALIZECHECK: func @simple_op_permuted_outputs(
+// CANONICALIZECHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CANONICALIZECHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CANONICALIZECHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CANONICALIZECHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CANONICALIZECHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CANONICALIZECHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+//  CANONICALIZECHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+//  CANONICALIZECHECK-DAG:   %[[GENERIC1:.+]]:2 = linalg.generic
+// CANONICALIZECHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CANONICALIZECHECK-SAME:       ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT1]], %[[INIT2]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: f32):
+// CANONICALIZECHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[S0]], %[[S0]]
+//      CANONICALIZECHECK:   %[[GENERIC2:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       [#[[MAP3]], #[[MAP0]], #[[MAP0]]]
+// CANONICALIZECHECK-SAME:       ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG2]], %[[GENERIC1]]#1 :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT2]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B5:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[B6:[a-zA-Z0-9]+]]: f32):
+// CANONICALIZECHECK-NEXT:     %[[S1:.+]] = arith.mulf %[[B5]], %[[B4]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[S1]]
+//      CANONICALIZECHECK:   return %[[GENERIC1]]#0, %[[GENERIC2]], %[[GENERIC1]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) -> tensor<20x10xf64> {
+  %init = linalg.init_tensor [20, 10] : tensor<20x10xf64>
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<10x20xf32>, tensor<10xi32>)
+    outs(%init : tensor<20x10xf64>) {
+    ^bb0(%b0 : f32, %b1 : i32, %b2 : f64):
+      %1 = arith.sitofp %b1 : i32 to f64
+      %2 = arith.extf %b0 : f32 to f64
+      %3 = arith.addf %1, %2 : f64
+      linalg.yield %3 : f64
+    } -> tensor<20x10xf64>
+  return %0 : tensor<20x10xf64>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func @multi_statement(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<10x20xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<10xi32>)
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64>
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64>
+//      CHECK:   %[[GENERIC0:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT:     ^bb0(
+// CHECK-SAME:         %[[B0:.+]]: f32
+// CHECK-SAME:         %[[B1:.+]]: i32
+// CHECK-SAME:         %[[B2:[a-zA-Z0-9]+]]: f64
+// CHECK-SAME:         %[[B3:.+]]: f64
+// CHECK-NEXT:       %[[S0:.+]] = arith.sitofp %[[B1]] : i32 to f64
+// CHECK-NEXT:       linalg.yield %{{.+}}, %[[S0]]
+//      CHECK:   %[[GENERIC1:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[GENERIC0]]#1 :
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT:     ^bb0(
+// CHECK-SAME:         %[[B4:.+]]: f32
+// CHECK-SAME:         %[[B5:.+]]: i32
+// CHECK-SAME:         %[[B6:[a-zA-Z0-9]+]]: f64
+// CHECK-SAME:         %[[B7:[a-zA-Z0-9]+]]: f64
+// CHECK-SAME:         %[[B8:.+]]: f64
+// CHECK-NEXT:       %[[S1:.+]] = arith.extf %[[B4]] : f32 to f64
+// CHECK-NEXT:       linalg.yield %{{.+}}, %[[S1]]
+//      CHECK:   %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]], #[[MAP0]], #[[MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[GENERIC0]]#1, %[[GENERIC1]]#1 :
+// CHECK-SAME:       outs(%[[INIT0]] :
+// CHECK-NEXT:     ^bb0(
+// CHECK-SAME:         %[[B9:.+]]: f32
+// CHECK-SAME:         %[[B10:.+]]: i32
+// CHECK-SAME:         %[[B11:[a-zA-Z0-9]+]]: f64
+// CHECK-SAME:         %[[B12:[a-zA-Z0-9]+]]: f64
+// CHECK-SAME:         %[[B13:.+]]: f64
+// CHECK-NEXT:       %[[S2:.+]] = arith.addf %[[B11]], %[[B12]] : f64
+// CHECK-NEXT:       linalg.yield %[[S2]]
+//      CHECK:   return %[[GENERIC2]]
+
+//  CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CANONICALIZECHECK: func @multi_statement(
+// CANONICALIZECHECK-SAME:     %[[ARG0:.+]]: tensor<10x20xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG1:.+]]: tensor<10xi32>)
+//  CANONICALIZECHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64>
+//  CANONICALIZECHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64>
+//      CANONICALIZECHECK:   %[[GENERIC0:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CANONICALIZECHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT1]] :
+// CANONICALIZECHECK-NEXT:     ^bb0(
+// CANONICALIZECHECK-SAME:         %[[B0:.+]]: i32
+// CANONICALIZECHECK-SAME:         %[[B1:.+]]: f64
+// CANONICALIZECHECK-NEXT:       %[[S0:.+]] = arith.sitofp %[[B0]] : i32 to f64
+// CANONICALIZECHECK-NEXT:       linalg.yield %[[S0]]
+//      CANONICALIZECHECK:   %[[GENERIC1:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       indexing_maps = [#[[MAP1]], #[[MAP1]]]
+// CANONICALIZECHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG0]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT1]] :
+// CANONICALIZECHECK-NEXT:     ^bb0(
+// CANONICALIZECHECK-SAME:         %[[B2:.+]]: f32
+// CANONICALIZECHECK-SAME:         %[[B3:.+]]: f64
+// CANONICALIZECHECK-NEXT:       %[[S1:.+]] = arith.extf %[[B2]] : f32 to f64
+// CANONICALIZECHECK-NEXT:       linalg.yield %[[S1]]
+//      CANONICALIZECHECK:   %[[GENERIC2:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP2]]]
+// CANONICALIZECHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[INIT0]] :
+// CANONICALIZECHECK-NEXT:     ^bb0(
+// CANONICALIZECHECK-SAME:         %[[B4:[a-zA-Z0-9]+]]: f64
+// CANONICALIZECHECK-SAME:         %[[B5:[a-zA-Z0-9]+]]: f64
+// CANONICALIZECHECK-SAME:         %[[B6:.+]]: f64
+// CANONICALIZECHECK-NEXT:       %[[S2:.+]] = arith.addf %[[B4]], %[[B5]] : f64
+// CANONICALIZECHECK-NEXT:       linalg.yield %[[S2]]
+//      CANONICALIZECHECK:   return %[[GENERIC2]]

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index e11c49fe6d696..20862892f1068 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
+  TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
new file mode 100644
index 0000000000000..a64387eaf5294
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -0,0 +1,54 @@
+//===- TestLinalgDecomposeOps.cpp - Test Linalg decomposition  ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing decomposition of Linalg ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestLinalgDecomposeOps
+    : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
+
+  TestLinalgDecomposeOps() = default;
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect>();
+  }
+  StringRef getArgument() const final { return "test-linalg-decompose-ops"; }
+  StringRef getDescription() const final {
+    return "Test Linalg decomposition patterns";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    RewritePatternSet decompositionPatterns(context);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            getOperation(), std::move(decompositionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDecomposeOps() {
+  PassRegistration<TestLinalgDecomposeOps>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 778c569c5ce16..f8fa2459e667f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@ void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
 void registerTestLastModifiedPass();
+void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
@@ -184,6 +185,7 @@ void registerTestPasses() {
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLastModifiedPass();
+  mlir::test::registerTestLinalgDecomposeOps();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgFusionTransforms();
   mlir::test::registerTestLinalgTensorFusionTransforms();

