[Mlir-commits] [mlir] af5c471 - [mlir][Vector] Add vector.extract(vector.shuffle) folder (#115105)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 6 18:17:16 PST 2024


Author: Diego Caballero
Date: 2024-11-06T18:17:12-08:00
New Revision: af5c471a4d9a9bff30b381d1fe2fe828672bb812

URL: https://github.com/llvm/llvm-project/commit/af5c471a4d9a9bff30b381d1fe2fe828672bb812
DIFF: https://github.com/llvm/llvm-project/commit/af5c471a4d9a9bff30b381d1fe2fe828672bb812.diff

LOG: [mlir][Vector] Add vector.extract(vector.shuffle) folder (#115105)

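This PR adds a folder for `vector.extract` ops that read a single element
from the result of a `vector.shuffle`: the extract is redirected to the
shuffle operand that supplies that element. It turns something like: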

```
   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
     : vector<8xf32>, vector<8xf32>
   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
```

into:

```
   %extract = vector.extract %b[7] : f32 from vector<8xf32>
```

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d8913251e56e9e..db199a46e1637c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,47 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   return extractOp.getResult();
 }
 
+/// Fold an ExtractOp whose source vector is produced by a ShuffleOp.
+///
+/// The extraction is redirected to the shuffle operand (V1 or V2) that
+/// supplies the selected lane, and the position is rebased into that operand.
+/// The ExtractOp is updated in place and its own result is returned; a null
+/// Value is returned (and the op left untouched) when the pattern does not
+/// apply.
+///
+/// Example:
+///
+///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
+///     : vector<8xf32>, vector<8xf32>
+///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
+/// ->
+///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
+///
+static Value foldExtractFromShuffle(ExtractOp extractOp) {
+  // Dynamic positions are not folded as the resulting code would be more
+  // complex than the input code.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
+  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+  if (!shuffleOp)
+    return Value();
+
+  // TODO: 0-D or multi-dimensional vectors not supported yet.
+  // Check the rank of the shuffle *operands*, not of the result: a shuffle of
+  // 0-D vectors yields a 1-D result, and `getShape()[0]` below would then
+  // index the empty shape of a 0-D operand.
+  auto inputVecType = shuffleOp.getV1().getType();
+  if (inputVecType.getRank() != 1)
+    return Value();
+
+  // A single element of the 1-D shuffle result is addressed by exactly one
+  // position; any other form is left to the other folders.
+  ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+  if (extractPos.size() != 1)
+    return Value();
+
+  int64_t inputVecSize = inputVecType.getShape()[0];
+  ArrayRef<int64_t> shuffleMask = shuffleOp.getMask();
+  int64_t extractIdx = extractPos[0];
+  // Defensive: the op verifiers are expected to keep these indices in range,
+  // but never index past the mask or emit a negative extract position.
+  if (extractIdx < 0 || extractIdx >= static_cast<int64_t>(shuffleMask.size()))
+    return Value();
+  int64_t shuffleIdx = shuffleMask[extractIdx];
+  if (shuffleIdx < 0)
+    return Value();
+
+  // Mask indices in [0, inputVecSize) select from V1; the remaining indices
+  // select from V2 and are offset by V1's size.
+  if (shuffleIdx < inputVecSize) {
+    extractOp.setOperand(0, shuffleOp.getV1());
+    extractOp.setStaticPosition({shuffleIdx});
+  } else {
+    extractOp.setOperand(0, shuffleOp.getV2());
+    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
+  }
+
+  return extractOp.getResult();
+}
+
 // Fold extractOp with source coming from ShapeCast op.
 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   // TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1994,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return res;
   if (auto res = foldExtractFromBroadcast(*this))
     return res;
+  if (auto res = foldExtractFromShuffle(*this))
+    return res;
   if (auto res = foldExtractFromShapeCast(*this))
     return res;
   if (auto val = foldExtractFromExtractStrided(*this))

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index df87f86765a3a3..5ae769090dac66 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
   %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
+// -----
+
+// Each lane of the shuffle result is read back individually. With mask
+// [0, 8, 7, 15], lanes 0..3 come from a[0], b[0], a[7] and b[7], so every
+// extract should read straight from %a or %b and the shuffle becomes dead.
+// CHECK-LABEL: @fold_extract_shuffle
+//  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   vector.extract %[[A]][0] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[B]][0] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[A]][7] : f32 from vector<8xf32>
+//       CHECK:   vector.extract %[[B]][7] : f32 from vector<8xf32>
+func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
+    -> (f32, f32, f32, f32) {
+  %shuf = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
+  %lane0 = vector.extract %shuf[0] : f32 from vector<4xf32>
+  %lane1 = vector.extract %shuf[1] : f32 from vector<4xf32>
+  %lane2 = vector.extract %shuf[2] : f32 from vector<4xf32>
+  %lane3 = vector.extract %shuf[3] : f32 from vector<4xf32>
+  return %lane0, %lane1, %lane2, %lane3 : f32, f32, f32, f32
+}
 
 // -----
 


        


More information about the Mlir-commits mailing list