[Mlir-commits] [mlir] [mlir][VectorOps] Add unrolling for n-D vector.interleave ops (3/4) (PR #80967)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Feb 13 02:51:24 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80967

>From c8a0b8eb9fa008a994f5b5031b213075fbd27c95 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:00:45 +0000
Subject: [PATCH] [mlir][VectorOps] Add unrolling for n-D vector.interleave ops

This unrolls n-D vector.interleave ops like:

```mlir
vector.interleave %i, %j : vector<6x3xf32>
```

to a sequence of 1-D operations, which can then be directly lowered to
LLVM.
---
 .../Vector/Transforms/LoweringPatterns.h      |  8 +++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/LowerVectorInterleave.cpp      | 64 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 48 ++++++++++++++
 5 files changed, 122 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d3..1cd3bab46396e3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of n-D InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until 1-D. Ops with a scalable leading dim are not unrolled.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f1..e3a436c4a94009 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
+    populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..f221b7462dfd7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 00000000000000..0ca38eba942a5d
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Unrolls the leading dimension of an n-D (n > 1) vector.interleave into
+/// vector.extract + lower-D vector.interleave + vector.insert ops.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Unrolling needs a fixed-size leading dimension; bail out if scalable.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: build the result slice by slice on a zero-filled vector.
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int64_t idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value lhsSlice = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+      Value rhsSlice = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+      Value zipped = rewriter.create<InterleaveOp>(loc, lhsSlice, rhsSlice);
+      result = rewriter.create<InsertOp>(loc, zipped, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds InterleaveOpLowering to `patterns` with the given `benefit`.
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a46f2e101f3c35..3cbca65472fb69 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2497,3 +2497,51 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
   %0 = vector.interleave %a, %b : vector<[4]xi32>
   return %0 : vector<[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIP_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIP_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIP_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}



More information about the Mlir-commits mailing list