[Mlir-commits] [mlir] e79b7f5 - [mlir][Vector] Fold extractelement splat.
llvmlistbot at llvm.org
Fri Apr 8 00:55:06 PDT 2022
Author: jacquesguan
Date: 2022-04-08T07:54:37Z
New Revision: e79b7f501c19784d6160b105a8b84e7fdf28e113
URL: https://github.com/llvm/llvm-project/commit/e79b7f501c19784d6160b105a8b84e7fdf28e113
DIFF: https://github.com/llvm/llvm-project/commit/e79b7f501c19784d6160b105a8b84e7fdf28e113.diff
LOG: [mlir][Vector] Fold extractelement splat.
This revision adds support for folding vector.extractelement (splat X) -> X.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D122960
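The fold is valid because every lane of a splat holds the same scalar, so the result does not depend on the extraction position. The patch checks for a splat before it requires constant operands. As a result, it also fires when the position is not a constant. The following is a minimal standalone sketch of that check order, not MLIR code. `FoldResult`, `VectorSource`, and `foldExtractElement` are hypothetical stand-ins for `OpFoldResult`, the defining op of the vector operand, and `ExtractElementOp::fold`. The out-of-range guard is an extra assumption that does not appear in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: either an existing SSA value
// (identified here by its name) or a folded integer constant.
using FoldResult = std::variant<std::string, int32_t>;

// Hypothetical description of what defines the source vector operand.
struct VectorSource {
  std::optional<std::string> splatInput;             // set if defined by a splat
  std::optional<std::vector<int32_t>> constElements; // set if a known constant
};

// Sketch of the check order in the patch: the splat case runs first, so the
// position does not have to be a constant for the fold to apply.
std::optional<FoldResult> foldExtractElement(const VectorSource &src,
                                             std::optional<int64_t> pos) {
  // extractelement (splat X) -> X: every lane holds X.
  if (src.splatInput)
    return FoldResult{*src.splatInput};
  // Otherwise both the vector and the position must be constants.
  if (!pos || !src.constElements)
    return std::nullopt;
  const std::vector<int32_t> &elems = *src.constElements;
  // Extra guard (assumption, not in the patch): do not fold out-of-range reads.
  if (*pos < 0 || static_cast<std::size_t>(*pos) >= elems.size())
    return std::nullopt;
  return FoldResult{elems[static_cast<std::size_t>(*pos)]};
}
```

For example, `foldExtractElement({std::string("%a"), std::nullopt}, std::nullopt)` yields the value `%a`, which mirrors the new `@extract_element_splat_fold` test. A constant vector `{10, 20, 30, 40}` read at position 2 yields the constant 30. The same constant vector read at an unknown position does not fold.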
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 07546c0fd51ff..7d9febec632ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
Attribute src = operands[0];
Attribute pos = operands[1];
- if (!src || !pos)
+
+ // Fold extractelement (splat X) -> X.
+ if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
+ return splat.getInput();
+
+ if (!pos || !src)
return {};
auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b6640bb06784..033f17ae2fe12 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 {
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
return %1 : i32
}
+
+// CHECK-LABEL: func @extract_element_splat_fold
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+// CHECK: return %[[ARG]]
+func @extract_element_splat_fold(%a : i32) -> i32 {
+ %v = vector.splat %a : vector<4xi32>
+ %i = arith.constant 2 : i32
+ %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+ return %1 : i32
+}