[Mlir-commits] [mlir] [mlir][vector] shape_cast(constant) -> constant fold for non-splats (PR #145539)
llvmlistbot at llvm.org
Tue Jun 24 09:01:34 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: James Newling (newling)
<details>
<summary>Changes</summary>
The folder `shape_cast(splat constant) -> splat constant` was first introduced [here](https://github.com/llvm/llvm-project/commit/36480657d8ce97836f76bf5fa8c36677b9cdc19a#diff-484cea976e0c96459027c951733bf2d22d34c5a0c0de6f577069870ef4588983R2600) (Nov 2020). That commit has a comment saying _Only handle splat for now_, so I assume the intention was to add a general `shape_cast(constant) -> constant` folder later. That is what this PR does.
Potential downside: with this folder it is possible to end up with 2 large constants instead of 1 large constant and 1 shape_cast:
```mlir
func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
  %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
  %0 = vector.shape_cast %cst : vector<4xi32> to vector<2x2xi32>
  return %cst, %0 : vector<4xi32>, vector<2x2xi32>
}
```
is folded by the new folder to:
```mlir
func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
  %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
  %cst_0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32> // 'large' constant 2
  return %cst, %cst_0 : vector<4xi32>, vector<2x2xi32>
}
```
Notes on the above case:
1) This only affects the textual IR: the actual values share the same context storage (I've verified this by checking pointer values in the `DenseIntOrFPElementsAttrStorage` [constructor](https://github.com/llvm/llvm-project/blob/da5c442550a3823fff05c14300c1664d0fbf68c8/mlir/lib/IR/AttributeDetail.h#L59)), so this folding has no compile-time memory overhead. I think the constant is shared at the LLVM IR level, too.
2) This only happens when the pre-folded constant cannot be dead-code eliminated (i.e. when it has 2+ uses), which I don't think is common.
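Why the fold is valid for non-splats: `vector.shape_cast` keeps the row-major order of the elements, so the folded constant is the same flat sequence of values with a different shape. The standalone sketch below illustrates this outside MLIR; `DenseConst`, `numElements`, `reshape`, `renderRec` and `render` are hypothetical names for illustration, not MLIR APIs, and the element-count check only mirrors in spirit the precondition of `DenseElementsAttr::reshape`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dense vector constant (not an MLIR type):
// a shape plus a flat buffer holding the elements in row-major order.
struct DenseConst {
  std::vector<int64_t> shape;
  std::vector<int32_t> data;
};

// Total element count implied by a shape (product of its dimensions).
int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Folding shape_cast(constant): row-major order is unchanged, so the flat
// buffer is reused as-is and only the shape is replaced.
DenseConst reshape(const DenseConst &c, std::vector<int64_t> newShape) {
  assert(numElements(newShape) == numElements(c.shape) &&
         "shape_cast must preserve the element count");
  return DenseConst{std::move(newShape), c.data};
}

// Appends the sub-array starting at dimension `dim` to `out`, consuming
// elements from the flat buffer in order.
void renderRec(const DenseConst &c, std::size_t dim, std::size_t &pos,
               std::string &out) {
  out += '[';
  for (int64_t i = 0; i < c.shape[dim]; ++i) {
    if (i > 0)
      out += ", ";
    if (dim + 1 == c.shape.size())
      out += std::to_string(c.data[pos++]);  // innermost dim: emit an element
    else
      renderRec(c, dim + 1, pos, out);       // outer dim: recurse
  }
  out += ']';
}

// Renders the constant in a nested form similar to dense<...>, e.g.
// [[1, 2], [3, 4]] for shape {2, 2}. 0-D constants print as a bare value.
std::string render(const DenseConst &c) {
  if (c.shape.empty())
    return std::to_string(c.data.at(0));
  std::string out;
  std::size_t pos = 0;
  renderRec(c, 0, pos, out);
  return out;
}
```

For the example above, `render(reshape(DenseConst{{4}, {1, 2, 3, 4}}, {2, 2}))` gives `[[1, 2], [3, 4]]`, matching the folded `dense<...>` attribute, while the `data` buffer is identical to the original one.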
---
Full diff: https://github.com/llvm/llvm-project/pull/145539.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4-5)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+34-4)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..ddc80063fd340 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5881,14 +5881,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
}
// shape_cast(constant) -> constant
- if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ if (auto denseAttr =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+ return denseAttr.reshape(getType());
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..a06a98ee1b93b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1219,11 +1219,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1233,6 +1233,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast fold's method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -1549,7 +1579,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
- ) -> vector<4x2xf32>
+ ) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1636,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
- -> vector<4x2x6xf32>
+ -> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
``````````
</details>
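A rough model of what the diff changes in `ShapeCastOp::fold`, covering only the constant and poison cases (the earlier folds in that function are not modeled). The types `Splat`, `Dense`, `Poison` and `Unknown` are hypothetical stand-ins for `SplatElementsAttr`, a non-splat `DenseElementsAttr`, `ub::PoisonAttr` and a non-constant operand; `foldBefore` and `foldAfter` are illustrative names, not MLIR functions, and `std::nullopt` plays the role of `return {};` (no fold). In MLIR a splat is itself a `DenseElementsAttr`, so the single new `dyn_cast` covers both cases; the model keeps them separate for clarity.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical models (not MLIR types) of what the folder may see as the
// shape_cast source operand.
struct Splat { std::vector<int64_t> shape; int32_t value; };
struct Dense { std::vector<int64_t> shape; std::vector<int32_t> values; };
struct Poison {};
struct Unknown {};  // not a constant: nothing to fold

using Attr = std::variant<Splat, Dense, Poison, Unknown>;

// Element count implied by a shape. shape_cast's verifier guarantees that
// source and result counts match, which the Dense branch relies on.
int64_t count(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Before the PR: only splat constants and poison are folded.
std::optional<Attr> foldBefore(const Attr &src,
                               const std::vector<int64_t> &resultShape) {
  if (const auto *s = std::get_if<Splat>(&src))
    return Splat{resultShape, s->value};
  if (std::holds_alternative<Poison>(src))
    return Poison{};
  return std::nullopt;  // models `return {};`
}

// After the PR: any dense constant is folded by reusing its flat values
// under the result shape.
std::optional<Attr> foldAfter(const Attr &src,
                              const std::vector<int64_t> &resultShape) {
  if (const auto *s = std::get_if<Splat>(&src))
    return Splat{resultShape, s->value};
  if (const auto *d = std::get_if<Dense>(&src)) {
    assert(count(resultShape) == static_cast<int64_t>(d->values.size()));
    return Dense{resultShape, d->values};
  }
  if (std::holds_alternative<Poison>(src))
    return Poison{};
  return std::nullopt;
}
```

With the source of the new `@shape_cast_dense_int_constant` test modeled as `Dense{{6}, {2, 3, 5, 7, 11, 13}}`, `foldBefore` returns no fold, while `foldAfter` returns the same six values with shape `{2, 3}`.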
https://github.com/llvm/llvm-project/pull/145539