[Mlir-commits] [mlir] [mlir][Vector] Add vector.extract(vector.shuffle) folder (PR #115105)
Diego Caballero
llvmlistbot at llvm.org
Wed Nov 6 16:51:52 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/115105
>From a0da17ca5a5da46212aa913ae697fac9e402fcc5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 5 Nov 2024 17:19:38 -0800
Subject: [PATCH 1/2] [mlir][Vector] Add vector.extract(vector.shuffle) folder
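This PR adds a folder that rewrites an extract of a single element from a
`vector.shuffle` result into an extract from whichever shuffle operand the
mask selects for that element. It turns something like: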
```
%shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
: vector<8xf32>, vector<8xf32>
%extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
```
into:
```
%extract = vector.extract %b[7] : f32 from vector<8xf32>
```
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 42 ++++++++++++++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++
2 files changed, 60 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d8913251e56e9e..723044aa2b66e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,46 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return extractOp.getResult();
}
+/// Fold extractOp coming from ShuffleOp.
+///
+/// Example:
+///
+/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
+/// : vector<8xf32>, vector<8xf32>
+/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
+/// ->
+/// %extract = vector.extract %b[7] : f32 from vector<4xf32>
+///
+static Value foldExtractFromShuffle(ExtractOp extractOp) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
+ // The mask lookup below needs one static index into the 1-D shuffle result;
+ // bail out instead of indexing an empty position list.
+ ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+ if (extractPos.empty())
+ return Value();
+
+ auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+ if (!shuffleOp)
+ return Value();
+
+ // TODO: 0-D or multi-dimensional vectors not supported yet.
+ // Checking the result rank alone is not enough: a shuffle of two 0-D
+ // vectors yields a 1-D result while its operands have an empty shape, so
+ // `getShape()[0]` below would index out of bounds and the rewritten extract
+ // would carry a position into a 0-D vector. Require 1-D operands as well.
+ auto v1Type = shuffleOp.getV1().getType();
+ if (shuffleOp.getResultVectorType().getRank() != 1 || v1Type.getRank() != 1)
+ return Value();
+
+ int64_t inputVecSize = v1Type.getShape()[0];
+ auto shuffleMask = shuffleOp.getMask();
+ int64_t extractIdx = extractPos[0];
+ int64_t shuffleIdx = shuffleMask[extractIdx];
+
+ // Find the shuffled vector to extract from based on the shuffle index:
+ // indices below inputVecSize select from V1; the others select from V2,
+ // rebased by subtracting inputVecSize.
+ if (shuffleIdx < inputVecSize) {
+ extractOp.setOperand(0, shuffleOp.getV1());
+ extractOp.setStaticPosition({shuffleIdx});
+ } else {
+ extractOp.setOperand(0, shuffleOp.getV2());
+ extractOp.setStaticPosition({shuffleIdx - inputVecSize});
+ }
+
+ // extractOp was modified in place above; returning its own result reports
+ // an in-place fold (MLIR folding convention).
+ return extractOp.getResult();
+}
+
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1993,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return res;
if (auto res = foldExtractFromBroadcast(*this))
return res;
+ if (auto res = foldExtractFromShuffle(*this))
+ return res;
if (auto res = foldExtractFromShapeCast(*this))
return res;
if (auto val = foldExtractFromExtractStrided(*this))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index df87f86765a3a3..5ae769090dac66 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
return %r : vector<8xf32>
}
+// -----
+
+// Each element extracted from the shuffle should be re-targeted at the
+// shuffle operand selected by its mask entry. With 8-element operands, mask
+// entries 0 and 7 (< 8) read %a at positions 0 and 7, and entries 8 and 15
+// read %b at positions 8 - 8 = 0 and 15 - 8 = 7. The CHECK-NOT expects no
+// vector.shuffle to remain ahead of the rewritten extracts.
+// CHECK-LABEL: @fold_extract_shuffle
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.extract %[[A]][0] : f32 from vector<8xf32>
+// CHECK: vector.extract %[[B]][0] : f32 from vector<8xf32>
+// CHECK: vector.extract %[[A]][7] : f32 from vector<8xf32>
+// CHECK: vector.extract %[[B]][7] : f32 from vector<8xf32>
+func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
+ -> (f32, f32, f32, f32) {
+ %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
+ %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
+ %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
+ %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
+ %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
+ return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
+}
// -----
>From 4c814a2449cd3eb2eec6f97ec23e3c73cf356ee1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Wed, 6 Nov 2024 15:31:03 -0800
Subject: [PATCH 2/2] Address review feedback: fix example type, explain dynamic-position bail-out
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 723044aa2b66e4..db199a46e1637c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1713,10 +1713,11 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
/// : vector<8xf32>, vector<8xf32>
/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
/// ->
-/// %extract = vector.extract %b[7] : f32 from vector<4xf32>
+/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
///
static Value foldExtractFromShuffle(ExtractOp extractOp) {
- // TODO: Canonicalization for dynamic position not implemented yet.
+ // Dynamic positions are not folded as the resulting code would be more
+ // complex than the input code.
if (extractOp.hasDynamicPosition())
return Value();
More information about the Mlir-commits
mailing list