[Mlir-commits] [mlir] 5812516 - [MLIR] Fix canonicalization pattern for 'shape.shape_of' (#134234)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 4 02:47:02 PDT 2025


Author: Alaa Ali
Date: 2025-04-04T11:46:58+02:00
New Revision: 5812516ae2e034d70b0cca20b95d627e163b4567

URL: https://github.com/llvm/llvm-project/commit/5812516ae2e034d70b0cca20b95d627e163b4567
DIFF: https://github.com/llvm/llvm-project/commit/5812516ae2e034d70b0cca20b95d627e163b4567.diff

LOG: [MLIR] Fix canonicalization pattern for 'shape.shape_of' (#134234)

This PR fixes a bug in a canonicalization pattern (operation
shape.shape_of: shape of reshape).

```
// Before
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}
// This will error out as follows:
error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
       ^
note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex>
```

```
// After
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex>
  return %0 : tensor<3xindex>
}
```
See file canonicalize.mlir in the change list for an example.

For context, this bug was found while running a test on Keras 3: the
canonicalizer errors out due to an invalid tensor.cast operation when
the batch size is dynamic.
The invalid op casts a tensor<3xi32> operand to tensor<3xindex>.
This change is related to a previous PR:
https://github.com/llvm/llvm-project/pull/98531

---------

Co-authored-by: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 10ba808cd26c2..f670614806dbd 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1734,10 +1734,23 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
     // Operand 'shape' of 'tensor.reshape' may now be used as the result of
     // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
     // formed IR, it may not be identical (dynamically vs statically shaped),
-    // in which case it needs to be cast first.
+    // in which case it needs to be cast first using 'tensor.cast'.
+    // Additionally, it may not have identical element type (i32 vs index)
+    // while it has identical shaped type (dynamic vs static), in which case it
+    // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
+    // op result must be shape or extent tensor.
     Value shape = tensorReshapeOp.getShape();
-    if (op.getType() != shape.getType())
-      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+
+    auto opTensorTy = cast<RankedTensorType>(op.getType());
+    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
+
+    if (opTensorTy != shapeTensorTy) {
+      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
+        shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
+      else if (!isExtentTensorType(shapeTensorTy))
+        shape =
+            rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
+    }
 
     rewriter.replaceOp(op, shape);
     return success();

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..b42fa75e4112d 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1389,10 +1389,25 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
 
 // -----
 
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
+// Check statically shaped types, with element types i32 to index.
+// CHECK-LABEL: func @shape_of_from_reshape_int_to_index
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
+func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
+  // CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex>
+  // CHECK: return %[[CAST_SHAPE]] : tensor<3xindex>
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
+    %1 = shape.shape_of %0 : tensor<?x1x1xf32> -> tensor<3xindex>
+    return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Check similar element types, with statically shaped to dynamically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_static_to_dynamic
 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
-func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
   // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
   // CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
   %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
@@ -1402,6 +1417,33 @@ func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: t
 
 // -----
 
+// Check similar element types, with dynamically shaped to statically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_dynamic_to_static
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
+  // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<5xindex>
+  // CHECK: return %[[CAST_SHAPE]] : tensor<5xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
+  return %1 : tensor<5xindex>
+}
+
+// -----
+
+// Check similar element types and similar static shape.
+// CHECK-LABEL: func @shape_of_from_reshape_identical_types
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
+func.func @shape_of_from_reshape_identical_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<5xindex> {
+  // CHECK: return %[[SHAPE]] : tensor<5xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
+  return %1 : tensor<5xindex>
+}
+
+// -----
+
 // CHECK-LABEL: func @shape_of_from_reshape_nofold
 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>


        


More information about the Mlir-commits mailing list