[Mlir-commits] [mlir] [mlir][vector] shape_cast(constant) -> constant fold for non-splats (PR #145539)

James Newling llvmlistbot at llvm.org
Thu Jul 31 08:13:48 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/145539

>From 3533e62b7f52add07aebf023b2868fc46bb45025 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 24 Jun 2025 08:24:40 -0700
Subject: [PATCH] generalize constant folder

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  9 +++---
 mlir/test/Dialect/Vector/canonicalize.mlir | 34 ++++++++++++++++++++--
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8789f55707267..ee9819f84f74c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   }
 
   // shape_cast(constant) -> constant
-  if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+  if (auto denseAttr =
+          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+    return denseAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
     return ub::PoisonAttr::get(getContext());
-  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9cfebd545400e..56996b5f364a5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
 
 // -----
 
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
 //       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
   %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
   %cst_1 = arith.constant dense<1> : vector<12x2xi32>
   %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+//               CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+//               CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+  %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+  %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+  return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+//  CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+//  CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+//      CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+  %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+  %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+  return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_poison
 //       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>



More information about the Mlir-commits mailing list