[Mlir-commits] [mlir] e79b7f5 - [mlir][Vector] Fold extractelement splat.

llvmlistbot at llvm.org
Fri Apr 8 00:55:06 PDT 2022


Author: jacquesguan
Date: 2022-04-08T07:54:37Z
New Revision: e79b7f501c19784d6160b105a8b84e7fdf28e113

URL: https://github.com/llvm/llvm-project/commit/e79b7f501c19784d6160b105a8b84e7fdf28e113
DIFF: https://github.com/llvm/llvm-project/commit/e79b7f501c19784d6160b105a8b84e7fdf28e113.diff

LOG: [mlir][Vector] Fold extractelement splat.

This revision adds a fold for vector.extractelement (splat X) -> X.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122960

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 07546c0fd51ff..7d9febec632ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -950,7 +950,12 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 
   Attribute src = operands[0];
   Attribute pos = operands[1];
-  if (!src || !pos)
+
+  // Fold extractelement (splat X) -> X.
+  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
+    return splat.getInput();
+
+  if (!pos || !src)
     return {};
 
   auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b6640bb06784..033f17ae2fe12 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1409,3 +1409,13 @@ func @extract_element_fold() -> i32 {
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// CHECK-LABEL: func @extract_element_splat_fold
+//  CHECK-SAME: (%[[ARG:.+]]: i32)
+//       CHECK:   return %[[ARG]]
+func @extract_element_splat_fold(%a : i32) -> i32 {
+  %v = vector.splat %a : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
