[Mlir-commits] [mlir] d0766c0 - [mlir][vector] Fold vector.extractelement(vector.broadcast)
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 22 01:35:07 PST 2022
Author: Matthias Springer
Date: 2022-12-22T10:34:58+01:00
New Revision: d0766c0861c6f9ab4ec286a695ae8e161f418b2f
URL: https://github.com/llvm/llvm-project/commit/d0766c0861c6f9ab4ec286a695ae8e161f418b2f
DIFF: https://github.com/llvm/llvm-project/commit/d0766c0861c6f9ab4ec286a695ae8e161f418b2f.diff
LOG: [mlir][vector] Fold vector.extractelement(vector.broadcast)
Differential Revision: https://reviews.llvm.org/D140394
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 67d0f7566677c..345e6b0fd672c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1047,6 +1047,11 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
return splat.getInput();
+ // Fold extractelement(broadcast(X)) -> X.
+ if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
+ if (!broadcast.getSource().getType().isa<VectorType>())
+ return broadcast.getSource();
+
if (!pos || !src)
return {};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ebadecd11e64c..2ebe2d7f42952 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2095,3 +2095,15 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
%res = vector.extract_strided_slice %mask {offsets = [3], sizes = [5], strides = [1]} : vector<12x7xi1> to vector<5x7xi1>
return %res : vector<5x7xi1>
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_extractelement_of_broadcast(
+// CHECK-SAME: %[[SCALAR:.*]]: f32
+// CHECK: return %[[SCALAR]]
+func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
+ %idx = arith.constant 5 : index
+ %splat = vector.broadcast %f : f32 to vector<15xf32>
+ %elem = vector.extractelement %splat [%idx : index] : vector<15xf32>
+ return %elem : f32
+}
More information about the Mlir-commits mailing list