[Mlir-commits] [mlir] [MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant (PR #147691)
llvmlistbot at llvm.org
Wed Jul 9 03:44:38 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Mengmeng Sun (MengmSun)
<details>
<summary>Changes</summary>
We have a case where, after `ConvertToLLVMPass`, the IR looks like:
```
...
%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>
%10 = vector.extract %8[0] : vector<192xi8> from vector<1x192xi8>
...
```
Our next pass is `Canonicalizer`. A few months ago this pipeline ran without problems, but recently we hit the following assertion failure:
```
mlir::DenseElementsAttr mlir::DenseElementsAttr::reshape(mlir::ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element type"' failed.
```
We traced this to the `shape_cast(constant) -> constant` fold in `ShapeCastOp::fold`, which now calls `reshape` on the source splat attribute. `reshape` asserts when the element type of the source attribute differs from the element type of the result type. That is exactly our case: the attribute is `dense<0.000000e+00> : vector<192xf8E4M3FN>`, while the `shape_cast` result is `vector<1x192xi8>`.
This patch therefore adds a constraint: the fold returns the reshaped attribute only when **the element type of the source attribute and the result type are the same**. This makes our case work as before and does not affect other cases.
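To make the failure mode concrete, here is a minimal, self-contained C++ sketch of the guarded fold. It deliberately does not use MLIR: `SplatModel`, `VectorTypeModel`, `reshape`, and `foldShapeCastOfSplat` are hypothetical stand-ins for `SplatElementsAttr`, `VectorType`, `DenseElementsAttr::reshape`, and the relevant branch of `ShapeCastOp::fold`, and element types are modeled as plain strings. The `assert` in `reshape` mirrors the assertion quoted above, and the early return in `foldShapeCastOfSplat` mirrors the element-type check added in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model types; these are NOT the MLIR classes.
// SplatModel stands in for a splat DenseElementsAttr: one value repeated
// over `shape`, stored with element type `elementType`.
struct SplatModel {
  std::string elementType;
  std::vector<int64_t> shape;
  double value;
};

// VectorTypeModel stands in for mlir::VectorType.
struct VectorTypeModel {
  std::string elementType;
  std::vector<int64_t> shape;
};

// Mirrors the precondition of DenseElementsAttr::reshape: only the shape
// may change, so the element types must already match.
SplatModel reshape(const SplatModel &attr, const VectorTypeModel &newType) {
  assert(newType.elementType == attr.elementType &&
         "expected the same element type");
  return SplatModel{attr.elementType, newType.shape, attr.value};
}

// Guarded fold: shape_cast(splat constant) -> splat constant, applied only
// when the element types agree. std::nullopt means "do not fold here".
std::optional<SplatModel> foldShapeCastOfSplat(
    const SplatModel &src, const VectorTypeModel &resultType) {
  if (src.elementType != resultType.elementType)
    return std::nullopt;  // skip the fold instead of tripping the assert
  return reshape(src, resultType);
}
```

With the IR from the description, an `f8E4M3FN` splat cast to `vector<1x192xi8>` is not folded (the guard returns `std::nullopt` before `reshape` is reached), while an `i8` splat of the same size folds to a `1x192` splat. In the real fold, the check only skips this particular rewrite; the cases that follow it in `ShapeCastOp::fold` still run.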
---
Full diff: https://github.com/llvm/llvm-project/pull/147691.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-3)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..5bbe6704aac48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- // shape_cast(constant) -> constant
+ // shape_cast(constant) -> constant,
+ // if element type of the source and result are the same
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ if (splatAttr.getElementType() == resultType.getElementType())
+ return splatAttr.reshape(getType());
+ }
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..69da8a31d2c9b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type
+func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> {
+ %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
+ // CHECK-NOT: vector.shape_cast
+ %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8>
+ // CHECK-NOT: vector.extract
+ %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8>
+ return %2 : vector<12xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
``````````
</details>
https://github.com/llvm/llvm-project/pull/147691