[Mlir-commits] [mlir] 1e18815 - [MLIR] fix shape.broadcast canonicalize with all empty shape operands (#118941)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 20 02:32:14 PST 2024


Author: Chenhui Huang
Date: 2024-12-20T19:32:11+09:00
New Revision: 1e18815fdc13bb1f8b0b87acd8abf62b5cf70d53

URL: https://github.com/llvm/llvm-project/commit/1e18815fdc13bb1f8b0b87acd8abf62b5cf70d53
DIFF: https://github.com/llvm/llvm-project/commit/1e18815fdc13bb1f8b0b87acd8abf62b5cf70d53.diff

LOG: [MLIR] fix shape.broadcast canonicalize with all empty shape operands (#118941)

Example where all operands of `shape.broadcast` are empty shapes:
```
func.func @all_empty(%arg0: tensor<f32>) -> tensor<0xindex> {
  %1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
  %2 = shape.const_shape [] : tensor<0xindex>
  %3 = shape.broadcast %1, %2, %1 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
  return %3 : tensor<0xindex>
}
```

The crash can be reproduced by canonicalizing in bottom-up order, with a
command such as:
`mlir-opt -split-input-file -allow-unregistered-dialect
-canonicalize="test-convergence top-down=0" %s`

The root cause is that, when all operands are empty shapes,
`RemoveEmptyShapeOperandsPattern` filters out every operand and rebuilds the op with no operands at all.

Co-authored-by: Kai Sasaki <lewuathe at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index bebfaa8c1ea822..65efc88e9c4033 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -698,6 +698,14 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
     auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
                                                  isPotentiallyNonEmptyShape);
 
+    // Replace the op with an empty shape constant if all operands were
+    // filtered out as empty.
+    if (newOperands.empty()) {
+      rewriter.replaceOpWithNewOp<ConstShapeOp>(
+          op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
+      return success();
+    }
+
     // Reduce op to equivalent without empty shape operands.
     if (newOperands.size() < op.getNumOperands()) {
       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf2..cf439c9c1b8545 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=0" %s | FileCheck %s
 
 // CHECK-LABEL: func @f
 func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
@@ -134,6 +135,21 @@ func.func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
 
 // -----
 
+// All operands are known empty shapes.
+// CHECK-LABEL: @all_empty
+// CHECK-SAME:  (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<i1>)
+func.func @all_empty(%arg0: tensor<f32>, %arg1: tensor<i1>) -> tensor<0xindex> {
+  // CHECK: %[[CST:.*]] = shape.const_shape [] : tensor<0xindex>
+  // CHECK: return %[[CST]] : tensor<0xindex>
+  %1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
+  %2 = shape.shape_of %arg1 : tensor<i1> -> tensor<0xindex>
+  %3 = shape.const_shape [] : tensor<0xindex>
+  %4 = shape.broadcast %1, %2, %3 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
+  return %4 : tensor<0xindex>
+}
+
+// -----
+
 // Partial folding.
 // CHECK-LABEL: @partial_folding
 // CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)


        


More information about the Mlir-commits mailing list