[Mlir-commits] [mlir] af5c471 - [mlir][Vector] Add vector.extract(vector.shuffle) folder (#115105)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 6 18:17:16 PST 2024
Author: Diego Caballero
Date: 2024-11-06T18:17:12-08:00
New Revision: af5c471a4d9a9bff30b381d1fe2fe828672bb812
URL: https://github.com/llvm/llvm-project/commit/af5c471a4d9a9bff30b381d1fe2fe828672bb812
DIFF: https://github.com/llvm/llvm-project/commit/af5c471a4d9a9bff30b381d1fe2fe828672bb812.diff
LOG: [mlir][Vector] Add vector.extract(vector.shuffle) folder (#115105)
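This PR adds a folder for `vector.extract` ops whose source is a `vector.shuffle`: the extract is redirected to read the element directly from the shuffle operand selected by the mask.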
It turns something like:
```
%shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
: vector<8xf32>, vector<8xf32>
%extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
```
into:
```
%extract = vector.extract %b[7] : f32 from vector<8xf32>
```
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d8913251e56e9e..db199a46e1637c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,47 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return extractOp.getResult();
}
+/// Fold an ExtractOp whose source vector is produced by a ShuffleOp.
+///
+/// The extract is rewritten in place to read the element directly from the
+/// shuffle operand selected by the mask. Mask entries smaller than the size of
+/// the first operand index into V1; larger entries index into V2 after
+/// subtracting that size. Returns the (updated) extract result when the fold
+/// applies, and a null Value otherwise.
+///
+/// Example:
+///
+///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
+///     : vector<8xf32>, vector<8xf32>
+///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
+/// ->
+///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
+///
+static Value foldExtractFromShuffle(ExtractOp extractOp) {
+  // Dynamic positions are not folded as the resulting code would be more
+  // complex than the input code.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
+  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+  if (!shuffleOp)
+    return Value();
+
+  // TODO: 0-D or multi-dimensional vectors not supported yet.
+  // Checking the result rank alone is not enough: shuffling two 0-D vectors
+  // also yields a 1-D result, but then the operands have no leading dimension
+  // to size or index, so the operand rank must be checked as well.
+  if (shuffleOp.getResultVectorType().getRank() != 1 ||
+      shuffleOp.getV1().getType().getRank() != 1)
+    return Value();
+
+  // An empty position extracts the whole vector rather than one element;
+  // there is no single mask entry to redirect through, so leave it alone.
+  ArrayRef<int64_t> position = extractOp.getStaticPosition();
+  if (position.empty())
+    return Value();
+
+  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
+  auto shuffleMask = shuffleOp.getMask();
+  int64_t extractIdx = position[0];
+  // Never index the mask out of bounds; negative values are not valid element
+  // indices, so such extracts are not folded.
+  if (extractIdx < 0 ||
+      extractIdx >= static_cast<int64_t>(shuffleMask.size()))
+    return Value();
+  int64_t shuffleIdx = shuffleMask[extractIdx];
+  if (shuffleIdx < 0)
+    return Value();
+
+  // Find the shuffled vector to extract from based on the shuffle index.
+  if (shuffleIdx < inputVecSize) {
+    extractOp.setOperand(0, shuffleOp.getV1());
+    extractOp.setStaticPosition({shuffleIdx});
+  } else {
+    extractOp.setOperand(0, shuffleOp.getV2());
+    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
+  }
+
+  // Returning the op's own result signals a successful in-place fold.
+  return extractOp.getResult();
+}
+
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1994,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return res;
if (auto res = foldExtractFromBroadcast(*this))
return res;
+ if (auto res = foldExtractFromShuffle(*this))
+ return res;
if (auto res = foldExtractFromShapeCast(*this))
return res;
if (auto val = foldExtractFromExtractStrided(*this))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index df87f86765a3a3..5ae769090dac66 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
return %r : vector<8xf32>
}
+// -----
+
+// Each extract of the shuffle is redirected to the operand selected by the
+// mask: entries 8 and 15 index into %b (offset by the 8 elements of %a).
+// CHECK-LABEL: @fold_extract_shuffle
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<8xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[B]][0] : f32 from vector<8xf32>
+// CHECK: %[[E2:.*]] = vector.extract %[[A]][7] : f32 from vector<8xf32>
+// CHECK: %[[E3:.*]] = vector.extract %[[B]][7] : f32 from vector<8xf32>
+// CHECK: return %[[E0]], %[[E1]], %[[E2]], %[[E3]] : f32, f32, f32, f32
+func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
+    -> (f32, f32, f32, f32) {
+  %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
+  %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
+  %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
+  %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
+  %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
+  return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
+}
// -----
More information about the Mlir-commits
mailing list