[Mlir-commits] [mlir] 9ecc817 - [mlir] Add support for fusion into TiledLoopOp.
Alexander Belyaev
llvmlistbot at llvm.org
Fri May 21 09:14:18 PDT 2021
Author: Alexander Belyaev
Date: 2021-05-21T18:13:45+02:00
New Revision: 9ecc8178d72097c8f9e31ea7c50085748d187aff
URL: https://github.com/llvm/llvm-project/commit/9ecc8178d72097c8f9e31ea7c50085748d187aff
DIFF: https://github.com/llvm/llvm-project/commit/9ecc8178d72097c8f9e31ea7c50085748d187aff.diff
LOG: [mlir] Add support for fusion into TiledLoopOp.
Differential Revision: https://reviews.llvm.org/D102722
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index dd5c2d5a51d98..334cecb3eb477 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,15 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
];
let extraClassDeclaration = [{
+ /// Number of loops
+ unsigned getNumLoops() { return step().size(); }
+
+ /// Number of input operands
+ unsigned getNumInputs() { return inputs().size(); }
+
+ /// Number of output operands
+ unsigned getNumOutputs() { return outputs().size(); }
+
/// Number of operands controlling the loop: lbs, ubs, steps
unsigned getNumControlOperands() { return 3 * getNumLoops(); }
@@ -597,7 +606,6 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
return getBody()->getArguments().take_back(outputs().size());
}
-
void setLowerBounds(ValueRange lowerBounds) {
unsigned numLoops = getNumLoops();
assert(lowerBounds.size() == numLoops &&
@@ -622,6 +630,16 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
setOperand(pos, steps[i]);
}
+ /// Block argument that corresponds to the `input` or `output` operand.
+ BlockArgument getTiedBlockArgument(OpOperand& operand) {
+ auto operandIndex = operand.getOperandNumber();
+ assert(
+ operandIndex >= getNumControlOperands() &&
+ operandIndex < getNumOperands() &&
+ "tied block arg is defined only for `input` and `output` arguments");
+ return getBody()->getArgument(operandIndex - 2 * getNumLoops());
+ }
+
/// Result that corresponds to the `outputs` argument of tensor type.
OpResult getTiedOpResult(OpOperand& opOperand) {
// No result can correspond to a memref argument.
@@ -642,7 +660,76 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
return getOperation()->getResult(tensorId);
}
- unsigned getNumLoops() { return step().size(); }
+ /// Append `operand` to the `input` arguments.
+ OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
+ int numLoops = getNumLoops();
+ int numInputs = getNumInputs();
+ int numOutputs = getNumOutputs();
+
+ getOperation()->insertOperands(getNumControlOperands() + numInputs,
+ operand);
+ getBody()->insertArgument(numLoops + numInputs, operand.getType());
+ getOperation()->setAttr(
+ TiledLoopOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr(
+ {numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
+ return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
+ }
+
+ /// Append `operand` to the `output` arguments.
+ OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
+ int numLoops = getNumLoops();
+ int numInputs = getNumInputs();
+ int numOutputs = getNumOutputs();
+
+ getOperation()->insertOperands(
+ getNumControlOperands() + numInputs + numOutputs, operand);
+ getBody()->insertArgument(numLoops + numInputs + numOutputs,
+ operand.getType());
+ getOperation()->setAttr(
+ TiledLoopOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr(
+ {numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
+ return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
+ numOutputs);
+ }
+
+ /// Erase `operand` from the `input` or `output` arguments.
+ void eraseOperand(OpBuilder& builder, OpOperand& operand) {
+ int numInputs = getNumInputs();
+ int numLoops = getNumLoops();
+ int numOutputs = getNumOutputs();
+ int numControlOperands = getNumControlOperands();
+
+ auto operandIndex = operand.getOperandNumber();
+ assert(operandIndex >= numControlOperands &&
+ operandIndex < getNumOperands() &&
+ "Can erase only `input` or `output` operand");
+
+ if (operandIndex >= numControlOperands + numInputs)
+ --numOutputs;
+ else
+ --numInputs;
+
+ getOperation()->eraseOperand(operandIndex);
+ getBody()->eraseArgument(operandIndex - 2 * numLoops);
+ getOperation()->setAttr(
+ TiledLoopOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr(
+ {numLoops, numLoops, numLoops, numInputs, numOutputs}));
+ }
+
+ OpOperand* findInputOperand(Value value) {
+ OperandRange::iterator it = llvm::find(inputs(), value);
+ if (it == inputs().end()) return nullptr;
+ return it.getBase();
+ }
+
+ OpOperand* findOutputOperand(Value value) {
+ OperandRange::iterator it = llvm::find(outputs(), value);
+ if (it == outputs().end()) return nullptr;
+ return it.getBase();
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 108d0414c34bf..465f933f862cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -107,6 +107,66 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
+// Return tiled operands for the fused producer op. When fusing into
+// `linalg.tiled_loop` one has to update `input` and `output` arguments of the
+// loop correspondingly.
+// Each input tensor of the producer op has to be added to `inputs` of the
+// `tiled_loop` if it is not present there already. Each output tensor has to
+// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
+// on whether the corresponding result is an input or an output to the loop.
+//
+// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
+// intermediate result is not used by any other operation but the consumer. A
+// more generic way is to append all missing output tensors of the producer to
+// the tiled loop outputs and hence modify the number of the results, since we
+// would need to add the intermediate results to `linalg.yield`. After that a
+// canonicalization pass would move the unused output args of the `tiled_loop`
+// to the `input` section.
+static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+ auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
+ if (!tiledLoop)
+ return llvm::to_vector<4>(producer.getShapedOperands());
+
+ SmallVector<Value, 4> tiledOperands;
+ assert(producer.hasTensorSemantics() &&
+ "only fusion on tensors is currently supported for TiledLinalgOp");
+
+ for (auto producerInput : producer.getInputTensors()) {
+ OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+ if (addedInput == nullptr)
+ addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+ BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
+ tiledOperands.push_back(addedBlockArg);
+ }
+ for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
+ Value producerOutput = en.value();
+
+ Value result = producer->getResult(en.index());
+ OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
+ OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
+ assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
+ "The result should be present in `input` or `output` args of "
+ "`tiled_loop`");
+
+ bool isInput = resultInputOperand;
+ int opNumber = isInput ? resultInputOperand->getOperandNumber()
+ : resultOutputOperand->getOperandNumber();
+
+ OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+ if (addedOutput == nullptr)
+ addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
+ : &tiledLoop.appendOutputOperand(b, producerOutput);
+
+ OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
+ auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
+ auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
+ resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
+ tiledLoop.eraseOperand(b, resultOperand);
+ tiledOperands.push_back(addedBlockArg);
+ }
+ return tiledOperands;
+}
+
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
@@ -143,8 +203,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
clonedShapes.reserve(producer.getNumShapedOperands());
// Compute subranges for all tensor input/output operands.
- auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
- clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
+ clonedShapes.append(makeTiledShapes(b, loc, producer,
+ getTiledOperands(b, producer), ivs,
tileSizes, sizeBounds));
// Append the other operands.
@@ -808,7 +868,7 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
origOpToFusedOp[origOp.getOperation()] = fusedOp;
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
- // Prepare the b for the next insertion point.
+ // Prepare the builder for the next insertion point.
auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
if (!origOp.hasTensorSemantics())
continue;
@@ -844,16 +904,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
// 2. encode destructive updates that may be inplaceable by bufferization.
// To keep the second type of information while letting the unfused op die
// unused, we need to forward the producer output operand.
- for (auto &operand :
- cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
- if (auto opResult = operand.get().dyn_cast<OpResult>())
- if (opResult.getOwner() == origOp)
- operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+ if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
+ for (auto &operand : forOp.getIterOpOperands())
+ if (auto opResult = operand.get().dyn_cast<OpResult>())
+ if (opResult.getOwner() == origOp)
+ operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+ }
}
return fusedOps;
}
-template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
@@ -928,11 +988,9 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
const LinalgTilingOptions &tilingOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
- return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
- tilingOptions);
case LinalgTilingLoopType::ParallelLoops:
- return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
- tilingOptions);
+ case LinalgTilingLoopType::TiledLoops:
+ return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
default:;
}
return llvm::None;
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 3775b67c43548..dc8a4acb4b4f9 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,15 +1,16 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
- func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
- %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
- %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
- ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
- return %1 : tensor<?x?xf32>
+ func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
+ %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
+ %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+ ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
+ return %ABC : tensor<?x?xf32>
}
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
@@ -90,6 +91,64 @@ module {
// CHECK: }
// CHECK: return %[[RESULT]]
+// TLOOP-LABEL: func @matmul_fusion(
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C16:.*]] = constant 16 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+
+// TLOOP: %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]])
+// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
+// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
+
+// TLOOP: %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
+
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C64]], %[[C16]])
+// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
+// TLOOP-SAME: %[[C__:.*]] = %[[C_]]: [[TY]])
+// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["parallel", "reduction"] {
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
+// TLOOP: %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
+// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+
+// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+
+// TLOOP: %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
+// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP: linalg.yield %[[RES0]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: linalg.yield %[[RES1]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[ABC]] : [[TY]]
+
// -----
module {
@@ -144,6 +203,48 @@ module {
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]
+// TLOOP-LABEL: func @matmul_plus_matmul
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP: %[[C32:.*]] = constant 32 : index
+// TLOOP: %[[C64:.*]] = constant 64 : index
+// TLOOP: %[[C0:.*]] = constant 0 : index
+// TLOOP: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
+
+// TLOOP: %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME: %[[AB_:.*]] = %[[AB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
+
+// TLOOP: %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
+// TLOOP: %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[AB_SUB_INIT]] : [[TY]])
+
+// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
+// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
+
+// TLOOP: %[[RESULT_SUB:.*]] = subtensor_insert
+// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[RESULT]] : [[TY]]
+
// -----
module {
@@ -174,3 +275,53 @@ module {
// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
// CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
+
+
+// TLOOP-LABEL: func @matmul_out_fusion(
+// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]])
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[AB]] : [[TY]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 3e67774ba13a4..4413faca5dc61 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -278,6 +278,13 @@ void registerTestLinalgTensorFusionTransforms() {
"Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.");
}
+void registerTestLinalgTiledLoopFusionTransforms() {
+ PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
+ testTiledLoopFusionTransformsPass(
+ "test-linalg-tiled-loop-fusion-transform-patterns",
"Test Linalg tiled-loop fusion transformation "
"patterns by applying them greedily.");
+}
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
"test-linalg-greedy-fusion",
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1ee3b6a0f2ace..23bfe775cae2c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,7 @@ void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
+void registerTestLinalgTiledLoopFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTileAndFuseSequencePass();
@@ -159,6 +160,7 @@ void registerTestPasses() {
test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgTensorFusionTransforms();
+ test::registerTestLinalgTiledLoopFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
test::registerTestLinalgTileAndFuseSequencePass();
More information about the Mlir-commits
mailing list