[Mlir-commits] [mlir] [MLIR] fix shape.broadcast canonicalize with all empty shape operands (PR #118941)
llvmlistbot at llvm.org
Fri Dec 6 01:06:10 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Chenhui Huang (YellowHCH)
<details>
<summary>Changes</summary>
Example: all operands of `shape.broadcast` are empty shapes (`tensor<0xindex>`).
```
func.func @all_empty(%arg0: tensor<f32>) -> tensor<0xindex> {
  %1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
  %2 = shape.const_shape [] : tensor<0xindex>
  %3 = shape.broadcast %1, %2, %1 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
  return %3 : tensor<0xindex>
}
```
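Why is an empty `shape.const_shape []` the right result here? Under broadcasting rules the empty shape acts as the identity: broadcasting it with any shape yields that shape, so broadcasting only empty shapes yields the empty shape. The following is a minimal C++ sketch of that property. It assumes static extents only and NumPy-style trailing-dimension alignment. `Shape`, `broadcastPair`, and `broadcastAll` are hypothetical names for illustration and are not MLIR API.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: a shape is a list of static extents; [] is the empty shape.
using Shape = std::vector<int64_t>;

// Broadcast two shapes, aligning trailing dimensions (NumPy-style assumption).
// Returns std::nullopt when two aligned extents differ and neither is 1.
std::optional<Shape> broadcastPair(const Shape &a, const Shape &b) {
  const Shape &longer = a.size() >= b.size() ? a : b;
  const Shape &shorter = a.size() >= b.size() ? b : a;
  Shape result = longer;
  std::size_t offset = longer.size() - shorter.size();
  for (std::size_t i = 0; i < shorter.size(); ++i) {
    int64_t x = longer[offset + i];
    int64_t y = shorter[i];
    if (x == y || y == 1)
      result[offset + i] = x;
    else if (x == 1)
      result[offset + i] = y;
    else
      return std::nullopt;
  }
  return result;
}

// Fold any number of shapes, starting from the empty shape, which is the
// identity element: broadcastPair({}, s) == s for every s.
std::optional<Shape> broadcastAll(const std::vector<Shape> &shapes) {
  Shape acc; // the empty shape []
  for (const Shape &s : shapes) {
    std::optional<Shape> next = broadcastPair(acc, s);
    if (!next)
      return std::nullopt;
    acc = *next;
  }
  return acc;
}
```

For example, `broadcastAll({{}, {}, {}})` yields the empty shape, which matches the `shape.const_shape []` that the fix produces. In contrast, `broadcastAll({{3, 1}, {4}})` yields `[3, 4]`.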
The crash can be reproduced by running the canonicalizer in *bottom-up* order (`top-down=0`), for example:
`mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=0" %s`
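The root cause: when all operands are empty shapes, `RemoveEmptyShapeOperandsPattern` filters out every operand and rebuilds the op with no operands at all. The fix replaces the op with an empty `shape.const_shape []` in that case.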
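The control flow of the pattern, with the new guard, can be modeled outside MLIR. The following plain C++ sketch covers only the `shape.broadcast` case. It does not use MLIR's `OpRewritePattern` API. `ShapeOperand`, `Rewrite`, and `removeEmptyShapeOperands` are hypothetical names. The `knownEmpty` flag stands in for the negation of the real `isPotentiallyNonEmptyShape` predicate.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for an operand of shape.broadcast (not MLIR's API).
struct ShapeOperand {
  bool knownEmpty; // e.g. shape_of a rank-0 tensor, or const_shape []
};

// Hypothetical description of the rewrite: which op replaces the original
// ("" means the pattern does not apply) and with which operands.
struct Rewrite {
  std::string newOpName;
  std::vector<ShapeOperand> operands;
};

Rewrite removeEmptyShapeOperands(const std::vector<ShapeOperand> &operands) {
  // Keep only operands that may be non-empty.
  std::vector<ShapeOperand> kept;
  for (const ShapeOperand &operand : operands)
    if (!operand.knownEmpty)
      kept.push_back(operand);

  // The added guard: nothing left, so fold to the empty constant shape
  // instead of building a broadcast with zero operands.
  if (kept.empty())
    return {"shape.const_shape", {}};

  // Some operands were dropped: rebuild the broadcast without them.
  if (kept.size() < operands.size())
    return {"shape.broadcast", kept};

  // No empty operands were found, so the pattern does not apply.
  return {"", {}};
}
```

With three known-empty operands, as in the `@all_empty` test, the model returns `shape.const_shape` with no operands. Without the `kept.empty()` guard, it would instead fall through to a `shape.broadcast` with zero operands, which is the situation the PR fixes.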
---
Full diff: https://github.com/llvm/llvm-project/pull/118941.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+6)
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+16)
``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index bebfaa8c1ea822..8779fe837c7ae7 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -699,6 +699,12 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
isPotentiallyNonEmptyShape);
// Reduce op to equivalent without empty shape operands.
+ if (newOperands.empty()) {
+ rewriter.replaceOpWithNewOp<ConstShapeOp>(
+ op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
+ return success();
+ }
+
if (newOperands.size() < op.getNumOperands()) {
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
op->getAttrs());
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf2..d55e0d8291cfef 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=1" %s | FileCheck %s
// CHECK-LABEL: func @f
func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
@@ -134,6 +135,21 @@ func.func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
// -----
+// All operands are known empty shapes.
+// CHECK-LABEL: @all_empty
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<i1>)
+func.func @all_empty(%arg0: tensor<f32>, %arg1: tensor<i1>) -> tensor<0xindex> {
+ // CHECK: %[[CST:.*]] = shape.const_shape [] : tensor<0xindex>
+ // CHECK: return %[[CST]] : tensor<0xindex>
+ %1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
+ %2 = shape.shape_of %arg1 : tensor<i1> -> tensor<0xindex>
+ %3 = shape.const_shape [] : tensor<0xindex>
+ %4 = shape.broadcast %1, %2, %3 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
+ return %4 : tensor<0xindex>
+}
+
+// -----
+
// Partial folding.
// CHECK-LABEL: @partial_folding
// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
``````````
</details>
https://github.com/llvm/llvm-project/pull/118941