[Mlir-commits] [mlir] [MLIR] Fix canonicalization pattern for 'shape.shape_of' (PR #134234)
Alaa Ali
llvmlistbot at llvm.org
Thu Apr 3 18:20:47 PDT 2025
https://github.com/alaa-ali updated https://github.com/llvm/llvm-project/pull/134234
>From 75de3afe3720c7c4f1c2ae4f484dfa9b9467925a Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 07:27:25 -0400
Subject: [PATCH 1/8] Fix canonicalization pattern for shape.shape_of
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 18 ++++++++++---
mlir/test/Dialect/Shape/canonicalize.mlir | 33 +++++++++++++++++++++--
2 files changed, 46 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 10ba808cd26c2..b8eac7c86797b 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1734,10 +1734,22 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
// Operand 'shape' of 'tensor.reshape' may now be used as the result of
// 'shape.shape_of'. While its type is guaranteed to be compatible in well-
// formed IR, it may not be identical (dynamically vs statically shaped),
- // in which case it needs to be cast first.
+ // in which case it needs to be cast first using 'tensor.cast'.
+ // Additionally, it may not have identical element type (i32 vs index)
+ // while it has identical shaped type (dynamic vs static), in which case it needs
+ // to be cast first using 'arith.index_cast'.
+ // Note: 'shape.shape_of' op result must be shape or extent tensor.
Value shape = tensorReshapeOp.getShape();
- if (op.getType() != shape.getType())
- shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+
+ auto opTensorType = llvm::dyn_cast<RankedTensorType>(op.getType());
+ auto shapeTensorType = llvm::dyn_cast<RankedTensorType>(shape.getType());
+
+ if (op.getType() != shape.getType()) {
+ if (opTensorType.getElementType() == shapeTensorType.getElementType())
+ shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+ else if (!isExtentTensorType(shape.getType()))
+ shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), op.getType(), shape);
+ }
rewriter.replaceOp(op, shape);
return success();
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9b25468b3ab1e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1389,10 +1389,25 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
// -----
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
+// Check statically shaped types, with element types i32 to index.
+// CHECK-LABEL: func @shape_of_from_reshape_compatible_types1
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
+func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
+ // CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex>
+ // CHECK: return %[[CAST_SHAPE]] : tensor<3xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
+ %1 = shape.shape_of %0 : tensor<?x1x1xf32> -> tensor<3xindex>
+ return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Check similar element types, with statically shaped to dynamically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_compatible_types2
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
-func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
// CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
@@ -1402,6 +1417,20 @@ func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: t
// -----
+// Check similar element types, with dynamically shaped to statically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_compatible_types3
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_compatible_types3(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
+ // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<5xindex>
+ // CHECK: return %[[CAST_SHAPE]] : tensor<5xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
+ return %1 : tensor<5xindex>
+}
+
+// -----
+
// CHECK-LABEL: func @shape_of_from_reshape_nofold
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
>From 394735f79035ae8586521302b1b89fc99462d26d Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 08:34:15 -0400
Subject: [PATCH 2/8] Use dyn_cast and bail out with failure() when either type is not a RankedTensorType
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b8eac7c86797b..f9302256eefe2 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1741,11 +1741,13 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
// Note: 'shape.shape_of' op result must be shape or extent tensor.
Value shape = tensorReshapeOp.getShape();
- auto opTensorType = llvm::dyn_cast<RankedTensorType>(op.getType());
- auto shapeTensorType = llvm::dyn_cast<RankedTensorType>(shape.getType());
+ auto opTensorTy = llvm::dyn_cast<RankedTensorType>(op.getType());
+ auto shapeTensorTy = llvm::dyn_cast<RankedTensorType>(shape.getType());
+ if (!opTensorTy || !shapeTensorTy)
+ return failure();
if (op.getType() != shape.getType()) {
- if (opTensorType.getElementType() == shapeTensorType.getElementType())
+ if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
else if (!isExtentTensorType(shape.getType()))
shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), op.getType(), shape);
>From e12e2e4534e059f11070b3b5901d37c969031f47 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 17:03:43 -0400
Subject: [PATCH 3/8] Use llvm::cast instead of dyn_cast with a null check
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f9302256eefe2..052b6cdb3eee7 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1741,10 +1741,8 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
// Note: 'shape.shape_of' op result must be shape or extent tensor.
Value shape = tensorReshapeOp.getShape();
- auto opTensorTy = llvm::dyn_cast<RankedTensorType>(op.getType());
- auto shapeTensorTy = llvm::dyn_cast<RankedTensorType>(shape.getType());
- if (!opTensorTy || !shapeTensorTy)
- return failure();
+ auto opTensorTy = llvm::cast<RankedTensorType>(op.getType());
+ auto shapeTensorTy = llvm::cast<RankedTensorType>(shape.getType());
if (op.getType() != shape.getType()) {
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
>From 137dcd06ccb214698bd3f19f19ed3d55bf19fdfc Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaa.leithy at gmail.com>
Date: Thu, 3 Apr 2025 18:11:52 -0400
Subject: [PATCH 4/8] Update mlir/lib/Dialect/Shape/IR/Shape.cpp: drop redundant llvm:: qualifier on cast<>
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 052b6cdb3eee7..d0b064e6fc1df 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1741,8 +1741,8 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
// Note: 'shape.shape_of' op result must be shape or extent tensor.
Value shape = tensorReshapeOp.getShape();
- auto opTensorTy = llvm::cast<RankedTensorType>(op.getType());
- auto shapeTensorTy = llvm::cast<RankedTensorType>(shape.getType());
+ auto opTensorTy = cast<RankedTensorType>(op.getType());
+ auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
if (op.getType() != shape.getType()) {
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
>From 5cf4a388840d55b64ae5fb32eee42f7f02a6603d Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaa.leithy at gmail.com>
Date: Thu, 3 Apr 2025 18:16:45 -0400
Subject: [PATCH 5/8] Update mlir/lib/Dialect/Shape/IR/Shape.cpp: use the opTensorTy/shapeTensorTy locals instead of re-querying the types
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index d0b064e6fc1df..f66a589c72f7e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1744,11 +1744,11 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
auto opTensorTy = cast<RankedTensorType>(op.getType());
auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
- if (op.getType() != shape.getType()) {
+ if (opTensorTy != shapeTensorTy) {
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
- shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
- else if (!isExtentTensorType(shape.getType()))
- shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), op.getType(), shape);
+ shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
+ else if (!isExtentTensorType(shapeTensorTy))
+ shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
}
rewriter.replaceOp(op, shape);
>From d2db005b3497dc1c4c9e51b7a6e42a81edaa70c8 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 18:41:55 -0400
Subject: [PATCH 6/8] Fix code formatting (reflow comment, wrap long line)
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f66a589c72f7e..f670614806dbd 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1735,20 +1735,21 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
// 'shape.shape_of'. While its type is guaranteed to be compatible in well-
// formed IR, it may not be identical (dynamically vs statically shaped),
// in which case it needs to be cast first using 'tensor.cast'.
- // Additionally, it may not have identical element type (i32 vs index)
- // while it has identical shaped type (dynamic vs static), in which case it needs
- // to be cast first using 'arith.index_cast'.
- // Note: 'shape.shape_of' op result must be shape or extent tensor.
+ // Additionally, it may not have identical element type (i32 vs index)
+ // while it has identical shaped type (dynamic vs static), in which case it
+ // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
+ // op result must be shape or extent tensor.
Value shape = tensorReshapeOp.getShape();
auto opTensorTy = cast<RankedTensorType>(op.getType());
auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
if (opTensorTy != shapeTensorTy) {
- if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
- shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
- else if (!isExtentTensorType(shapeTensorTy))
- shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
+ if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
+ shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
+ else if (!isExtentTensorType(shapeTensorTy))
+ shape =
+ rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
}
rewriter.replaceOp(op, shape);
>From 89a8ffad8aa20efd2258292eec30e08745fa2aa4 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 20:56:36 -0400
Subject: [PATCH 7/8] Rename lit test functions to describe the cast being tested
---
mlir/test/Dialect/Shape/canonicalize.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9b25468b3ab1e..4a65edb3bc1bc 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1393,7 +1393,7 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
// CHECK-LABEL: func @shape_of_from_reshape_compatible_types1
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
-func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
+func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
// CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex>
// CHECK: return %[[CAST_SHAPE]] : tensor<3xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
@@ -1407,7 +1407,7 @@ func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor<?x1xf32>, %arg1
// CHECK-LABEL: func @shape_of_from_reshape_compatible_types2
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
-func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
// CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
@@ -1421,7 +1421,7 @@ func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1:
// CHECK-LABEL: func @shape_of_from_reshape_compatible_types3
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
-func.func @shape_of_from_reshape_compatible_types3(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
+func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<5xindex>
// CHECK: return %[[CAST_SHAPE]] : tensor<5xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
>From 245263ca2d85a4e906fe2ce27cf4366f99a4d4c5 Mon Sep 17 00:00:00 2001
From: Alaa Ali <alaaali at ah-alaaali-l.dhcp.mathworks.com>
Date: Thu, 3 Apr 2025 21:04:06 -0400
Subject: [PATCH 8/8] Update CHECK-LABEL lines to match the renamed test functions
---
mlir/test/Dialect/Shape/canonicalize.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 4a65edb3bc1bc..71a80de8adfb9 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1390,7 +1390,7 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
// -----
// Check statically shaped types, with element types i32 to index.
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types1
+// CHECK-LABEL: func @shape_of_from_reshape_int_to_index
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
@@ -1404,7 +1404,7 @@ func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: ten
// -----
// Check similar element types, with statically shaped to dynamically shaped.
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types2
+// CHECK-LABEL: func @shape_of_from_reshape_static_to_dynamic
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
@@ -1418,7 +1418,7 @@ func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1:
// -----
// Check similar element types, with dynamically shaped to statically shaped.
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types3
+// CHECK-LABEL: func @shape_of_from_reshape_dynamic_to_static
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
More information about the Mlir-commits mailing list