[Mlir-commits] [mlir] b015fcc - [mlir][vector] Fix extract op canonicalization for 0d vector

Thomas Raoux llvmlistbot at llvm.org
Tue Jan 17 09:28:17 PST 2023


Author: Thomas Raoux
Date: 2023-01-17T17:25:51Z
New Revision: b015fccbe503fd7109405decb4f3eb6269e7706b

URL: https://github.com/llvm/llvm-project/commit/b015fccbe503fd7109405decb4f3eb6269e7706b
DIFF: https://github.com/llvm/llvm-project/commit/b015fccbe503fd7109405decb4f3eb6269e7706b.diff

LOG: [mlir][vector] Fix extract op canonicalization for 0d vector

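Fix the ExtractOpFromBroadcast canonicalization pattern when the broadcast
source is a 0-D vector: a scalar extract is now rewritten to
vector.extractelement on the source instead of to vector.broadcast.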

Differential Revision: https://reviews.llvm.org/D141735

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db2cd82790009..387f003e8c210 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1632,6 +1632,13 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+
+    // Special case if broadcast src is a 0D vector.
+    if (extractResultRank == 0) {
+      assert(broadcastSrcRank == 0 && source.getType().isa<VectorType>());
+      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      return success();
+    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1990b893f02f9..081d94e5805fc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -519,6 +519,18 @@ func.func @fold_extract_broadcast(%a : f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_0dvec
+//  CHECK-SAME:   %[[A:.*]]: vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   return %[[B]] : f32
+func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+  %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_negative
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 //       CHECK:   vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>


        


More information about the Mlir-commits mailing list