[Mlir-commits] [mlir] 19b9c74 - [mlir] Return new scf.forall handle in fuse_into_containing_op
Harsh Menon
llvmlistbot at llvm.org
Thu May 25 09:47:04 PDT 2023
Author: Harsh Menon
Date: 2023-05-25T09:34:46-07:00
New Revision: 19b9c74b42824c5f7d73481d9c1fe8e385a4426c
URL: https://github.com/llvm/llvm-project/commit/19b9c74b42824c5f7d73481d9c1fe8e385a4426c
DIFF: https://github.com/llvm/llvm-project/commit/19b9c74b42824c5f7d73481d9c1fe8e385a4426c.diff
LOG: [mlir] Return new scf.forall handle in fuse_into_containing_op
Since the scf.forall is now consumed by `fuse_into_containing_op`, we need
to return a handle to the new scf.forall that replaces it.
This patch does that and also ensures that the new bbArg added to the
scf.forall is used in its body, so that further producers can be fused
into the new scf.forall.
Differential Revision: https://reviews.llvm.org/D151418
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 179cafdbf274f..06ef84b43f04b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -183,7 +183,8 @@ def FuseIntoContainingOp :
let arguments = (ins TransformHandleTypeInterface:$producer_op,
TransformHandleTypeInterface:$containing_op);
- let results = (outs TransformHandleTypeInterface:$fused_op);
+ let results = (outs TransformHandleTypeInterface:$fused_op,
+ TransformHandleTypeInterface:$new_containing_op);
let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
" `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4f476d1053827..a63f6647cf7b3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -344,7 +344,8 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
Value producerOp,
Value containingOp) {
result.addOperands({producerOp, containingOp});
- result.addTypes(transform::AnyOpType::get(builder.getContext()));
+ auto resultType = transform::AnyOpType::get(builder.getContext());
+ result.addTypes({resultType, resultType});
}
/// Add new operands to the forall op for users of the producerOp
@@ -388,8 +389,16 @@ static Operation *replaceForAllWithNewSignature(
newforallOp.getRegion().takeBody(forallOp.getRegion());
// Add additional block argument for new value being returned
+ // and replaces all uses of the new output with corresponding bbArg
+ // inside the scf.forall to enable fusion into this new scf.forall.
newforallOp.getBody()->addArgument(newOuts.back().getType(),
newOuts.back().getLoc());
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
+ [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return newforallOp->isProperAncestor(op);
+ });
// Fix terminator
scf::InParallelOp terminatorOp = newforallOp.getTerminator();
@@ -749,14 +758,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
}
results.set(cast<OpResult>(getFusedOp()), fusedOps);
+ results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
return DiagnosedSilenceableFailure::success();
}
void transform::FuseIntoContainingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getProducerOp(), effects);
- onlyReadsHandle(getContainingOp(), effects);
- producesHandle(getFusedOp(), effects);
+ consumesHandle(getContainingOp(), effects);
+ producesHandle(getResults(), effects);
modifiesPayload(effects);
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index f3f480247f7ff..d67b4802e772a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -48,7 +48,7 @@ module {
// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -92,7 +92,7 @@ module {
// tensor.empty is not tileable. The op is cloned and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -139,7 +139,7 @@ module {
// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -188,7 +188,7 @@ module {
// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
@@ -249,7 +249,7 @@ module {
// linalg.generic is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -285,7 +285,7 @@ module {
%2 = transform.merge_handles %0, %0 : !transform.any_op
// It shouldn't be a problem to fuse this handle.
- transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
@@ -351,7 +351,7 @@ module {
// linalg.generic is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -417,7 +417,7 @@ module {
// linalg.generic is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
}
}
@@ -482,6 +482,81 @@ module {
// linalg.generic is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
- : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0) -> (d0)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ %0 = linalg.generic {
+ indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.addf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+
+ %1 = linalg.generic {
+ indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+ } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.mulf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %2 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %4 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %5 = affine.min #map2(%i)[%d0, %idx]
+ %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+ // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+ %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1
+ func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
+ %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
index e824a55b92a64..7dd31835ce84f 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -52,7 +52,7 @@ module {
// Fuse all producers.
transform.structured.fuse_into_containing_op %producers into %forall_op
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
@@ -112,6 +112,6 @@ module {
// Fuse all producers.
transform.structured.fuse_into_containing_op %reversed_producers into %forall_op
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
More information about the Mlir-commits
mailing list