[Mlir-commits] [mlir] 7becf0f - [mlir][vector] Fold extract(broadcast) of same rank
Lei Zhang
llvmlistbot at llvm.org
Thu Apr 7 10:03:38 PDT 2022
Author: Lei Zhang
Date: 2022-04-07T12:59:54-04:00
New Revision: 7becf0f6cd31ea7462c5e18a88cb2f7a2c508886
URL: https://github.com/llvm/llvm-project/commit/7becf0f6cd31ea7462c5e18a88cb2f7a2c508886
DIFF: https://github.com/llvm/llvm-project/commit/7becf0f6cd31ea7462c5e18a88cb2f7a2c508886.diff
LOG: [mlir][vector] Fold extract(broadcast) of same rank
This case is handled by neither the folding nor the canonicalization
patterns. The folding pattern cannot generate new broadcast ops,
so it should be handled by the canonicalization pattern.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D123307
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eda77392041ba..07546c0fd51ff 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1496,6 +1496,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return failure();
+
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return failure();
@@ -1504,10 +1505,10 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
- // We only consider the case where the rank of the source is smaller than
- // the rank of the extract dst. The other cases are handled in the folding
- // patterns.
- if (extractResultRank <= broadcastSrcRank)
+ // We only consider the case where the rank of the source is less than or
+ // equal to the rank of the extract dst. The other cases are handled in the
+ // folding patterns.
+ if (extractResultRank < broadcastSrcRank)
return failure();
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a083851a4bb98..8b6640bb06784 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
+// CHECK: return %[[R]] : vector<8xf32>
+func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+ %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
+ %r = vector.extract %b[0] : vector<1x8xf32>
+ return %r : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_shapecast
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
More information about the Mlir-commits
mailing list