[Mlir-commits] [mlir] d0766c0 - [mlir][vector] Fold vector.extractelement(vector.broadcast)

Matthias Springer llvmlistbot at llvm.org
Thu Dec 22 01:35:07 PST 2022


Author: Matthias Springer
Date: 2022-12-22T10:34:58+01:00
New Revision: d0766c0861c6f9ab4ec286a695ae8e161f418b2f

URL: https://github.com/llvm/llvm-project/commit/d0766c0861c6f9ab4ec286a695ae8e161f418b2f
DIFF: https://github.com/llvm/llvm-project/commit/d0766c0861c6f9ab4ec286a695ae8e161f418b2f.diff

LOG: [mlir][vector] Fold vector.extractelement(vector.broadcast)

Differential Revision: https://reviews.llvm.org/D140394

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 67d0f7566677c..345e6b0fd672c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1047,6 +1047,11 @@ OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
     return splat.getInput();
 
+  // Fold extractelement(broadcast(X)) -> X.
+  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
+    if (!broadcast.getSource().getType().isa<VectorType>())
+      return broadcast.getSource();
+
   if (!pos || !src)
     return {};
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ebadecd11e64c..2ebe2d7f42952 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2095,3 +2095,15 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
   %res = vector.extract_strided_slice %mask {offsets = [3], sizes = [5], strides = [1]} : vector<12x7xi1> to vector<5x7xi1>
   return %res : vector<5x7xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_extractelement_of_broadcast(
+//  CHECK-SAME:     %[[f:.*]]: f32
+//       CHECK:   return %[[f]]
+func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
+  %0 = vector.broadcast %f : f32 to vector<15xf32>
+  %c5 = arith.constant 5 : index
+  %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
+  return %1 : f32
+}


        


More information about the Mlir-commits mailing list