[Mlir-commits] [mlir] 2c5eea0 - [mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts (#113828)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 29 15:47:48 PDT 2024
Author: Kunwar Grover
Date: 2024-10-29T22:47:44Z
New Revision: 2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340
URL: https://github.com/llvm/llvm-project/commit/2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340
DIFF: https://github.com/llvm/llvm-project/commit/2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340.diff
LOG: [mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts (#113828)
The current vector.insert folder tries to replace a 0-rank vector result with
the inserted scalar, which crashes because the types differ. This patch fixes
the crash by not folding unless the types of the result and replacement are the same.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..1853ae04f45d90 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertOpConstantFolder>(context);
}
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
- if (getNumIndices() == 0)
+ // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
+ // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+ // (type mismatch).
+ if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d6bc199e601c0..c963460e7259fb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -800,6 +800,43 @@ func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vecto
// -----
+// CHECK-LABEL: func @extract_no_fold_scalar_to_0d(
+// CHECK-SAME: %[[v:.*]]: vector<f32>)
+// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
+// CHECK: return %[[extract]]
+func.func @extract_no_fold_scalar_to_0d(%v: vector<f32>) -> f32 {
+ %0 = vector.extract %v[] : f32 from vector<f32>
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_fold_same_rank(
+// CHECK-SAME: %[[v:.*]]: vector<2x2xf32>)
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: : vector<2x2xf32>
+// CHECK-NOT: vector.insert
+// CHECK: return %[[CST]]
+func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> {
+ %cst = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+ %0 = vector.insert %cst, %v [] : vector<2x2xf32> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_no_fold_scalar_to_0d(
+// CHECK-SAME: %[[v:.*]]: vector<f32>)
+// CHECK: %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
+// CHECK: return %[[extract]]
+func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.insert %cst, %v [] : f32 into vector<f32>
+ return %0 : vector<f32>
+}
+
+// -----
+
// CHECK-LABEL: dont_fold_expand_collapse
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2606,17 +2643,6 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
// -----
-// CHECK-LABEL: func @extract_from_0d_regression(
-// CHECK-SAME: %[[v:.*]]: vector<f32>)
-// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
-// CHECK: return %[[extract]]
-func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
- %0 = vector.extract %v[] : f32 from vector<f32>
- return %0 : f32
-}
-
-// -----
-
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
More information about the Mlir-commits
mailing list