[Mlir-commits] [mlir] [mlir][Mesh] Fix invalid IR in rewrite pattern (PR #78094)
Matthias Springer
llvmlistbot at llvm.org
Sun Jan 14 02:58:37 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/78094
This commit fixes a failure in `test/Dialect/Mesh/folding.mlir` that occurs when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.
```
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: error: Unexpected number of results 0. Expected 2.
%0:2 = mesh.cluster_shape @mesh1 : index, index
^
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: note: see current operation: "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
mlir-asm-printer: Verifying operation: builtin.module
Unexpected number of results 0. Expected 2.
mlir-asm-printer: 'builtin.module' failed to verify and will be printed in generic form
"builtin.module"() ({
"mesh.cluster"() <{dim_sizes = array<i64: 2, 3>, rank = 2 : i64, sym_name = "mesh1"}> : () -> ()
"func.func"() <{function_type = () -> (index, index), sym_name = "cluster_shape_op_folding_all_axes_static_mesh"}> ({
%0 = "arith.constant"() <{value = 2 : index}> : () -> index
%1 = "arith.constant"() <{value = 3 : index}> : () -> index
"mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
%2:2 = "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> (index, index)
"func.return"(%0, %1) : (index, index) -> ()
}) : () -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```
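If `axes` is empty, the op verifier assumes that all mesh dimensions are queried, so it expects one result per dimension (2 for `@mesh1`). When every queried dimension has a static size, the `ClusterShapeFolder` pattern folded all results into constants but still created a new `mesh.cluster_shape` op with an empty `axes` list and no results. That is the invalid op shown above. The fix creates the new op only if some dynamic axes remain, and it replaces the original op with `replaceOp` instead of `replaceAllUsesWith`.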
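The verifier rule described above can be modeled in a few lines of standalone C++. This is a hypothetical sketch, not the actual `ClusterShapeOp` verifier. The names `expectedNumResults` and `hasValidResultCount` are invented for illustration, and axes are modeled as `int16_t` to match the `array<i16>` attribute in the dump.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model of the result-count rule for
// `mesh.cluster_shape` (not the real MLIR verifier code).
// An empty `axes` list means "query every mesh axis".
std::size_t expectedNumResults(std::int64_t meshRank,
                               const std::vector<std::int16_t> &axes) {
  assert(meshRank >= 0 && "mesh rank must be non-negative");
  return axes.empty() ? static_cast<std::size_t>(meshRank) : axes.size();
}

// Returns true if an op with `numResults` results satisfies the rule.
// The op left behind by the old pattern (empty `axes`, zero results) on a
// rank-2 mesh fails this check, matching
// "Unexpected number of results 0. Expected 2."
bool hasValidResultCount(std::int64_t meshRank,
                         const std::vector<std::int16_t> &axes,
                         std::size_t numResults) {
  return numResults == expectedNumResults(meshRank, axes);
}
```

For the rank-2 `@mesh1` from the test, `hasValidResultCount(2, {}, 0)` is false. This is the situation the old pattern produced. `hasValidResultCount(2, {}, 2)` is true.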
From 1926e42b2991a553d99c5d3a1690cf94bc0f5584 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sun, 14 Jan 2024 10:57:06 +0000
Subject: [PATCH] [mlir][Mesh] Fix invalid IR in rewrite pattern
---
.../lib/Dialect/Mesh/Transforms/Simplifications.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c9275ad5ad4551..67e1bf6320dbf3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -103,13 +103,14 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
}
// Leave only the dynamic mesh axes to be queried.
- ClusterShapeOp newShapeOp =
- builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
- for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
- newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+ if (!newShapeOpMeshAxes.empty()) {
+ ClusterShapeOp newShapeOp =
+ builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+ for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+ newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+ }
}
-
- rewriter.replaceAllUsesWith(op.getResults(), newResults);
+ rewriter.replaceOp(op, newResults);
return success();
}
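The folding decision in the patch can also be sketched outside of MLIR. This is a hypothetical model, not the code of `ClusterShapeFolder`. The names `kDynamicSize`, `FoldedResult`, `FoldPlan`, and `planClusterShapeFold` are invented. The `-1` dynamic-size marker is an assumption. The caller is assumed to have already expanded an empty `axes` list into the full list of mesh axes. The patch keeps a `newToOldResultsIndexMap`; this sketch instead records, for each original result, where its replacement comes from.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical marker for a mesh dimension whose size is not known statically.
constexpr std::int64_t kDynamicSize = -1;

// One entry per result of the original op: either a folded constant, or the
// index of the corresponding result of the new (smaller) query op.
struct FoldedResult {
  std::optional<std::int64_t> constant;
  std::size_t newResultIndex = 0;
};

struct FoldPlan {
  std::vector<FoldedResult> results;        // one per original result
  std::vector<std::int16_t> remainingAxes;  // axes the new op must query
  bool createNewQueryOp = false;            // false when every axis folded
};

// Static mesh sizes become constants, and only dynamic axes stay queried.
// No new query op is planned when nothing dynamic remains, because an
// empty `axes` list would mean "query all axes".
FoldPlan planClusterShapeFold(const std::vector<std::int64_t> &meshShape,
                              const std::vector<std::int16_t> &queriedAxes) {
  FoldPlan plan;
  for (std::int16_t axis : queriedAxes) {
    assert(axis >= 0 && static_cast<std::size_t>(axis) < meshShape.size());
    FoldedResult r;
    if (meshShape[axis] != kDynamicSize) {
      r.constant = meshShape[axis];
    } else {
      r.newResultIndex = plan.remainingAxes.size();
      plan.remainingAxes.push_back(axis);
    }
    plan.results.push_back(r);
  }
  plan.createNewQueryOp = !plan.remainingAxes.empty();
  return plan;
}
```

Consider `meshShape = {2, 3}` with both axes queried, which is the static `@mesh1` case from the test. Every result folds to a constant and `createNewQueryOp` is false. This corresponds to the `if (!newShapeOpMeshAxes.empty())` guard added by the patch.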