[Mlir-commits] [mlir] 7becf0f - [mlir][vector] Fold extract(broadcast) of same rank

Lei Zhang llvmlistbot at llvm.org
Thu Apr 7 10:03:38 PDT 2022


Author: Lei Zhang
Date: 2022-04-07T12:59:54-04:00
New Revision: 7becf0f6cd31ea7462c5e18a88cb2f7a2c508886

URL: https://github.com/llvm/llvm-project/commit/7becf0f6cd31ea7462c5e18a88cb2f7a2c508886
DIFF: https://github.com/llvm/llvm-project/commit/7becf0f6cd31ea7462c5e18a88cb2f7a2c508886.diff

LOG: [mlir][vector] Fold extract(broadcast) of same rank

This case is handled by neither the folding nor the canonicalization
patterns. The folding pattern cannot generate new broadcast ops,
so it should be handled by the canonicalization pattern.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123307

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eda77392041ba..07546c0fd51ff 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1496,6 +1496,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     Operation *defOp = extractOp.getVector().getDefiningOp();
     if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
       return failure();
+
     Value source = defOp->getOperand(0);
     if (extractOp.getType() == source.getType())
       return failure();
@@ -1504,10 +1505,10 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     };
     unsigned broadcastSrcRank = getRank(source.getType());
     unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is smaller than
-    // the rank of the extract dst. The other cases are handled in the folding
-    // patterns.
-    if (extractResultRank <= broadcastSrcRank)
+    // We only consider the case where the rank of the source is less than or
+    // equal to the rank of the extract dst. The other cases are handled in the
+    // folding patterns.
+    if (extractResultRank < broadcastSrcRank)
       return failure();
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a083851a4bb98..8b6640bb06784 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
+//       CHECK:   return %[[R]] : vector<8xf32>
+func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+  %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
+  %r = vector.extract %b[0] : vector<1x8xf32>
+  return %r : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_shapecast
 //  CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
 //       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>


        


More information about the Mlir-commits mailing list