[Mlir-commits] [mlir] [mlir][vector] Add n-d deinterleave lowering (PR #94237)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jun 6 04:53:41 PDT 2024


================
@@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                  ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>   |
+/// %1, %2 = vector.deinterleave %0 :                 |
+///        : vector<8xi64> -> vector<4xi64>           | -- Initial deinterleave
+/// %3 = vector.insert %1, %result [0, 0, 0]          |    operation unrolled.
+///        : vector<4xi64> into vector<1x2x3x4xi64>   |
+/// %4 = vector.insert %2, %result [0, 0, 0]          |
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ┘
+/// %5 = vector.extract %a[0, 0, 1]                  ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>   |
+/// %6, %7 = vector.deinterleave %5 :                 |
+///        : vector<8xi64> -> vector<4xi64>           | -- Recursive pattern for
+/// %8 = vector.insert %6, %3 [0, 0, 1]               |    subsequent unrolled
+///        : vector<4xi64> into vector<1x2x3x4xi64>   |    deinterleave
+/// %9 = vector.insert %7, %3 [0, 0, 1]               |    operations. Repeated
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ┘    5x in this case.
----------------
MacDue wrote:

Looks like there's a typo here: both results are being inserted into `%3`.
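For reference, here is a minimal C++ model of what the unrolled IR is presumably meant to compute. It is not the MLIR rewrite pattern and not code from the PR; it is a plain-array sketch under stated assumptions. Each 1-D deinterleave's two results are threaded through two separate accumulators (even lanes and odd lanes), instead of both being inserted into the same value. The names `Src`, `Dst`, `deinterleave1D`, and `unrolledDeinterleave` are hypothetical, and the shapes are hard-coded to the example's `vector<1x2x3x8xi64>`.

```cpp
#include <array>
#include <cstdint>

// Hypothetical fixed shapes matching the doc example:
// vector<1x2x3x8xi64> -> two vector<1x2x3x4xi64> results.
constexpr int D0 = 1, D1 = 2, D2 = 3, N = 8;
static_assert(N % 2 == 0, "deinterleave needs an even trailing dimension");

using Row = std::array<int64_t, N>;
using HalfRow = std::array<int64_t, N / 2>;
using Src = std::array<std::array<std::array<Row, D2>, D1>, D0>;
using Dst = std::array<std::array<std::array<HalfRow, D2>, D1>, D0>;

// Model of a 1-D vector.deinterleave: even lanes go to `even`,
// odd lanes go to `odd`.
void deinterleave1D(const Row& in, HalfRow& even, HalfRow& odd) {
  for (int i = 0; i < N / 2; ++i) {
    even[i] = in[2 * i];
    odd[i] = in[2 * i + 1];
  }
}

// Model of the unrolled form: iterate over every leading index
// [i, j, k] (1 * 2 * 3 = 6 positions), extract the trailing 1-D row,
// deinterleave it, and insert each half into its OWN accumulator.
void unrolledDeinterleave(const Src& a, Dst& evenResult, Dst& oddResult) {
  evenResult = Dst{};  // analogous to arith.constant dense<0>
  oddResult = Dst{};
  for (int i = 0; i < D0; ++i)
    for (int j = 0; j < D1; ++j)
      for (int k = 0; k < D2; ++k)
        deinterleave1D(a[i][j][k], evenResult[i][j][k], oddResult[i][j][k]);
}
```

If the source is filled with 0, 1, ..., 47 in row-major order, the row at `[0, 1, 2]` holds 40..47, so its even half is {40, 42, 44, 46} and its odd half is {41, 43, 45, 47}. Keeping the two accumulators separate is what prevents the second insert at each position from overwriting the first.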

https://github.com/llvm/llvm-project/pull/94237

