[Mlir-commits] [mlir] [MLIR][Vector] Add constraints to vector.shape_cast(constant) -> constant (PR #147691)
Mengmeng Sun
llvmlistbot at llvm.org
Tue Jul 15 03:12:34 PDT 2025
https://github.com/MengmSun updated https://github.com/llvm/llvm-project/pull/147691
>From 5e9aa6b6b1fe332f5e0c92b14835484558c00732 Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Wed, 9 Jul 2025 02:53:36 -0700
Subject: [PATCH 1/2] [mlir][Vector] Add constraints to
vector.shape_cast(constant) -> constant
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++++++---
mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++
2 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..5bbe6704aac48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- // shape_cast(constant) -> constant
+ // shape_cast(constant) -> constant,
+ // if element type of the source and result are the same
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ if (splatAttr.getElementType() == resultType.getElementType())
+ return splatAttr.reshape(getType());
+ }
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..69da8a31d2c9b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type
+func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> {
+ %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
+ // CHECK-NOT: vector.shape_cast
+ %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8>
+ // CHECK-NOT: vector.extract
+ %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8>
+ return %2 : vector<12xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
>From 5604b262bae4dabbcba9803abf52f57eb95edaf1 Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Tue, 15 Jul 2025 03:11:26 -0700
Subject: [PATCH 2/2] [mlir][Vector] Address review comments: keep the splat attribute's element type when folding shape_cast(constant)
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26 ++++++++++++++++++----
mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++--
2 files changed, 24 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5bbe6704aac48..4cc4eed08f2da 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5922,12 +5922,30 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- // shape_cast(constant) -> constant,
- // if element type of the source and result are the same
+ // shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
- if (splatAttr.getElementType() == resultType.getElementType())
- return splatAttr.reshape(getType());
+
+ // The shape and 'scalable dims' of the new attribute must match the result
+ // of the shape_cast:
+ auto newShape = resultType.getShape();
+ auto newScalableDims = resultType.getScalableDims();
+
+ // The element type must be retained. Note that this is to handle currently
+ // valid IR like
+ //
+ // ```
+ // %0 = llvm.mlir.constant(dense<0.> : vector<1xf8E4M3FN>) : vector<1xi8>
+ // %1 = vector.shape_cast %0 : vector<1xi8> to vector<1x1xi8>
+ // ```
+ //
+ // where the element types of the attribute and result do not match.
+ auto newElementType = splatAttr.getElementType();
+
+ auto newAttr = VectorType::get(newShape, newElementType, newScalableDims);
+
+ return DenseElementsAttr::get(newAttr,
+ splatAttr.getSplatValue<Attribute>());
}
// shape_cast(poison) -> poison
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 69da8a31d2c9b..b0114905db742 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1003,11 +1003,11 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type
+// CHECK: %[[CONST:.*]] = llvm.mlir.constant
+// CHECK-NEXT: return %[[CONST]]
func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> {
%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
- // CHECK-NOT: vector.shape_cast
%1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8>
- // CHECK-NOT: vector.extract
%2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8>
return %2 : vector<12xi8>
}
More information about the Mlir-commits
mailing list