[Mlir-commits] [mlir] [mlir][transform] Emit error message with `emitSilenceableFailure` (PR #86146)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 21 14:30:24 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/86146
>From a535c0e239f6d704a4300d2f668f5e6cbe240532 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 21 Mar 2024 11:12:10 -0500
Subject: [PATCH 1/3] [mlir][transform] Emit error message with
`emitSilenceableFailure`
---
.../Linalg/TransformOps/LinalgTransformOps.cpp | 9 ++++-----
.../test/Dialect/Linalg/flatten-unsupported.mlir | 16 ++++++++++++++++
2 files changed, 20 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/flatten-unsupported.mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ecf9983124821a..06acade06d771b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,11 +3269,10 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (!isElementwise(target)) {
- failed(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
- return emitDefaultSilenceableFailure(target);
- }
+ if (!isElementwise(target))
+
+ return mlir::emitSilenceableFailure(target->getLoc())
+ << "only elementwise flattening is supported";
// If rank <= 1, do nothing
if (target.getNumLoops() <= 1) {
results.push_back(target);
diff --git a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
new file mode 100644
index 00000000000000..476733fc74f5cf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
+
+func.func @non_elementwise(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+ // expected-error @+1 {{only elementwise flattening is supported}}
+ linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>) outs(%arg2: memref<2x4xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
>From e9de397e39dc4193509ce287cf846d723e9fce2a Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 21 Mar 2024 11:32:37 -0500
Subject: [PATCH 2/3] Address review comment: use `@below` in expected-error directive
---
mlir/test/Dialect/Linalg/flatten-unsupported.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
index 476733fc74f5cf..e86de6c447d7be 100644
--- a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
func.func @non_elementwise(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
- // expected-error @+1 {{only elementwise flattening is supported}}
+ // expected-error @below {{only elementwise flattening is supported}}
linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>) outs(%arg2: memref<2x4xf32>)
return
}
>From 8af26a78bc89472f1d72197cadd42fd12ab98caa Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 21 Mar 2024 16:29:47 -0500
Subject: [PATCH 3/3] Improve remaining silenceable failure message and test
---
.../Linalg/TransformOps/LinalgTransformOps.cpp | 7 +++++--
.../Dialect/Linalg/flatten-unsupported.mlir | 17 +++++++++++++++++
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 06acade06d771b..c2187cc3eaf35f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3270,20 +3270,23 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
if (!isElementwise(target))
-
return mlir::emitSilenceableFailure(target->getLoc())
<< "only elementwise flattening is supported";
+
// If rank <= 1, do nothing
if (target.getNumLoops() <= 1) {
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
+
+ // Attempt to flatten all dims to one
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
- return emitDefaultSilenceableFailure(target);
+ return mlir::emitSilenceableFailure(target->getLoc())
+ << "Attempted to flatten, but failed";
results.push_back(maybeFlattened->collapsedOp);
rewriter.replaceOp(target, maybeFlattened->results);
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
index e86de6c447d7be..8fa3494ea8bc68 100644
--- a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir
@@ -14,3 +14,20 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @unsupported_memref(%arg0: memref<32x7xf32, strided<[7, 2]>>, %arg1: memref<32x7xf32, strided<[7, 2]>>, %arg2: memref<32x7xf32, strided<[7, 2]>>) {
+ // expected-error @below {{Attempted to flatten, but failed}}
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32, strided<[7, 2]>>, memref<32x7xf32, strided<[7, 2]>>) outs(%arg2: memref<32x7xf32, strided<[7, 2]>>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list