[Mlir-commits] [mlir] [mlir][vector] shape_cast(constant) -> constant fold for non-splats (PR #145539)
James Newling
llvmlistbot at llvm.org
Thu Jul 31 08:12:14 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/145539
>From 19548decc330b771bbc31cf28cc7970ca6e88580 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 24 Jun 2025 08:24:40 -0700
Subject: [PATCH] generalize constant folder
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++---
mlir/test/Dialect/Vector/canonicalize.mlir | 34 ++++++++++++++++++++--
2 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8789f55707267..ee9819f84f74c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
}
// shape_cast(constant) -> constant
- if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ if (auto denseAttr =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+ return denseAttr.reshape(getType());
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9cfebd545400e..56996b5f364a5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
More information about the Mlir-commits
mailing list