[Mlir-commits] [mlir] c214cee - [mlir] improve error handling in Linalg op splitting
Alex Zinenko
llvmlistbot at llvm.org
Fri Jan 6 09:35:15 PST 2023
Author: Alex Zinenko
Date: 2023-01-06T18:35:08+01:00
New Revision: c214cee772f7ce9c9128384b2dd1640f12c9feac
URL: https://github.com/llvm/llvm-project/commit/c214cee772f7ce9c9128384b2dd1640f12c9feac
DIFF: https://github.com/llvm/llvm-project/commit/c214cee772f7ce9c9128384b2dd1640f12c9feac.diff
LOG: [mlir] improve error handling in Linalg op splitting
In several cases, the splitting may be known to be a no-op, i.e., it produces
no second part. Thread this information through the transform utilities
to the transform dialect, and differentiate it from the error state.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D141138
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Split.cpp
mlir/test/Dialect/Linalg/transform-op-split.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index cc2b63d2786ee..3be4ee72cd5a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -162,6 +162,9 @@ FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
+/// If the second part is statically known to be empty, do not create it
+/// and return nullptr instead. Error state is signalled by returning
+/// a pair of nullptrs.
///
/// For example, the following op:
///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dddee38b32710..f170d0bd1199b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1043,6 +1043,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
// Split each target operation.
SmallVector<Operation *> first, second;
+ Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, splitPoints)) {
Operation *target = std::get<0>(pair);
auto linalgOp = dyn_cast<LinalgOp>(target);
@@ -1067,6 +1068,32 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
getDimension(), std::get<1>(pair));
+
+ // Propagate errors.
+ if (!first.back() && !second.back()) {
+ auto diag = emitDefiniteFailure() << "internal failure in splitting";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Do not add null second parts.
+ if (!second.back()) {
+ noSecondPart = target;
+ second.pop_back();
+ }
+ }
+
+ if (second.size() != first.size() && !second.empty()) {
+ results.set(getFirst().cast<OpResult>(), {});
+ results.set(getSecond().cast<OpResult>(), {});
+ auto diag =
+ emitSilenceableError()
+ << "splitting does not produce the second part for a subset of targets";
+ diag.attachNote() << "expected splitting to produce the second part of all "
+ "or none of the targets";
+ diag.attachNote(noSecondPart->getLoc())
+ << "first target with no second part";
+ return diag;
}
results.set(getFirst().cast<OpResult>(), first);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index f58182e71c96e..c8c9c0bd4af89 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -128,6 +128,10 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
dimension, remainingSize, totalOffset, secondResults);
+ // Propagate any errors in part creation.
+ if (!firstPart || !secondPart)
+ return {TilingInterface(), TilingInterface()};
+
// Replace the original op with the results of the two newly created ops.
rewriter.replaceOp(op, secondResults);
return {firstPart, secondPart};
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index ba5ff6ec6bfec..1d7f15efe73cb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -46,6 +46,16 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
return %0 : tensor<100xf32>
}
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
// CHECK-LABEL: @one_d_static_overflow
// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
@@ -268,3 +278,45 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
return %0 : tensor<100xf32>
}
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ // expected-error @below {{splitting does not produce the second part for a subset of targets}}
+ // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
+ %1:2 = transform.structured.split %0 after 142 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+func.func @split_one_but_not_other(
+ %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
+ %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
+ -> (tensor<100xf32>, tensor<200xf32>) {
+ // expected-note @below {{first target with no second part}}
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ %i = linalg.index 0 : index
+ %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+ linalg.yield %call_res : f32
+ } -> tensor<100xf32>
+
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ %i = linalg.index 0 : index
+ %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+ linalg.yield %call_res : f32
+ } -> tensor<200xf32>
+
+ return %0, %1 : tensor<100xf32>, tensor<200xf32>
+}
+
More information about the Mlir-commits
mailing list