[Mlir-commits] [mlir] [mlir][Vector] Add vector.extract(vector.shuffle) folder (PR #115105)

Diego Caballero llvmlistbot at llvm.org
Wed Nov 6 16:51:52 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/115105

>From a0da17ca5a5da46212aa913ae697fac9e402fcc5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 5 Nov 2024 17:19:38 -0800
Subject: [PATCH 1/2] [mlir][Vector] Add vector.extract(vector.shuffle) folder

This PR adds a folder for extracting an element from a vector shuffle.
It turns something like:

```
   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
     : vector<8xf32>, vector<8xf32>
   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
```

into:

```
   %extract = vector.extract %b[7] : f32 from vector<8xf32>
```
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 42 ++++++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++
 2 files changed, 60 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d8913251e56e9e..723044aa2b66e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,46 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   return extractOp.getResult();
 }
 
+/// Fold an ExtractOp whose source is a 1-D ShuffleOp result by rewiring the
+/// extract, in place, to the shuffle operand that holds the selected element.
+/// Returns the extract's result on success, else a null Value. Example:
+///
+///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
+///     : vector<8xf32>, vector<8xf32>
+///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
+/// ->
+///   %extract = vector.extract %b[7] : f32 from vector<4xf32>
+///
+static Value foldExtractFromShuffle(ExtractOp extractOp) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
+  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+  if (!shuffleOp)
+    return Value();
+
+  // TODO: 0-D inputs (which yield a 1-D result) and multi-dim not supported.
+  if (shuffleOp.getV1().getType().getRank() != 1)
+    return Value();
+
+  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
+  auto shuffleMask = shuffleOp.getMask();
+  int64_t extractIdx = extractOp.getStaticPosition()[0];
+  int64_t shuffleIdx = shuffleMask[extractIdx];
+
+  // Mask indices < inputVecSize select from V1; the rest select from V2.
+  if (shuffleIdx < inputVecSize) {
+    extractOp.setOperand(0, shuffleOp.getV1());
+    extractOp.setStaticPosition({shuffleIdx});
+  } else {
+    extractOp.setOperand(0, shuffleOp.getV2());
+    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
+  }
+
+  return extractOp.getResult();
+}
+
 // Fold extractOp with source coming from ShapeCast op.
 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   // TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1993,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return res;
   if (auto res = foldExtractFromBroadcast(*this))
     return res;
+  if (auto res = foldExtractFromShuffle(*this))
+    return res;
   if (auto res = foldExtractFromShapeCast(*this))
     return res;
   if (auto val = foldExtractFromExtractStrided(*this))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index df87f86765a3a3..5ae769090dac66 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
   %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+// -----
+// Extracting from a 1-D shuffle folds to an extract from the source operand.
+// CHECK-LABEL: @fold_extract_shuffle
+//  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   vector.extract %[[A]][0] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[B]][0] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[A]][7] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[B]][7] : f32 from vector<8xf32>
+func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
+                                -> (f32, f32, f32, f32) {
+  %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
+  %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
+  %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
+  %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
+  %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
+  return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
+}
 
 // -----
 

>From 4c814a2449cd3eb2eec6f97ec23e3c73cf356ee1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Wed, 6 Nov 2024 15:31:03 -0800
Subject: [PATCH 2/2] Feedback

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 723044aa2b66e4..db199a46e1637c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1713,10 +1713,11 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
 ///     : vector<8xf32>, vector<8xf32>
 ///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
 /// ->
-///   %extract = vector.extract %b[7] : f32 from vector<4xf32>
+///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
 ///
 static Value foldExtractFromShuffle(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
+  // Dynamic positions are not folded as the resulting code would be more
+  // complex than the input code.
   if (extractOp.hasDynamicPosition())
     return Value();
 



More information about the Mlir-commits mailing list