[Mlir-commits] [mlir] [mlir][VectorOps] Add unrolling for n-D vector.interleave ops (3/4) (PR #80967)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Feb 13 09:37:21 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80967
>From 661e3b6840a9336359cc28f0d5033c3130d555ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:00:45 +0000
Subject: [PATCH] [mlir][VectorOps] Add unrolling for n-D vector.interleave ops
This unrolls n-D vector.interleave ops like:
```mlir
vector.interleave %i, %j : vector<6x3xf32>
```
To a sequence of 1-D operations, which can then be directly lowered to
LLVM.
---
.../Vector/TransformOps/VectorTransformOps.td | 14 +++
.../Vector/Transforms/LoweringPatterns.h | 8 ++
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../TransformOps/VectorTransformOps.cpp | 5 ++
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Transforms/LowerVectorInterleave.cpp | 88 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 24 +++++
...vector-interleave-lowering-transforms.mlir | 49 +++++++++++
8 files changed, 190 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
create mode 100644 mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index ce88360aa52e9d..83df5fe27d7a4a 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -292,6 +292,20 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_interleave",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that n-D vector interleave operations should be lowered to
+    finer-grained vector primitives (lower-rank interleaves, extracts, inserts).
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d3..1cd3bab46396e3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of n-D InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp, until the result is 1-D or its leading dim is scalable.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f1..e3a436c4a94009 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
+ populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns,
VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 19922c4295fe03..6c2cfd8833dddc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
}
}
+/// Registers the patterns that unroll n-D `vector.interleave` ops.
+void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::vector::populateVectorInterleaveLoweringPatterns(patterns);
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..f221b7462dfd7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
+ LowerVectorInterleave.cpp
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 00000000000000..fffa63d13820a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,88 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Progressive lowering of InterleaveOp.
+///
+/// Unrolls the leading (outermost) dimension of a `vector.interleave` whose
+/// result has rank > 1. The pattern does not fire when the result is already
+/// 1-D, or when the leading dimension is scalable (it has no compile-time
+/// size, so it cannot be unrolled).
+///
+/// Example:
+///
+/// ```
+/// %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
+/// ```
+/// Becomes:
+/// ```
+/// %zero = arith.constant dense<0> : vector<2x...16xty>
+/// %lhs_0 = vector.extract %lhs[0]
+/// %rhs_0 = vector.extract %rhs[0]
+/// %zip_0 = vector.interleave %lhs_0, %rhs_0
+/// %res_0 = vector.insert %zip_0, %zero[0]
+/// %lhs_1 = vector.extract %lhs[1]
+/// %rhs_1 = vector.extract %rhs[1]
+/// %zip_1 = vector.interleave %lhs_1, %rhs_1
+/// %0 = vector.insert %zip_1, %res_0[1]
+/// ```
+///
+/// If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
+/// following the same pattern.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Only the leading (front) dimension is unrolled here. A scalable
+    // dimension has no fixed trip count, so it can't be unrolled.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: unroll the leading dimension. Each slice is interleaved
+    // separately and inserted into a zero-initialized result vector.
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    // getDimSize() returns int64_t; keep the induction variable that wide.
+    for (int64_t idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, lhs, idx);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, rhs, idx);
+      // The new interleave has rank one lower than `op`; if it is still n-D,
+      // this pattern will be applied to it again.
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the n-D `vector.interleave` unrolling pattern to `patterns`, using the
+/// supplied pattern `benefit`.
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<InterleaveOpLowering>(ctx, benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a46f2e101f3c35..e94e51d49a98b7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2497,3 +2497,27 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
%0 = vector.interleave %a, %b : vector<[4]xi32>
return %0 : vector<[8]xi32>
}
+
+// -----
+
+// A 2-D fixed-size interleave is unrolled to 1-D interleaves, which lower to
+// llvm.shufflevector; no 2-D vector.interleave should remain.
+// CHECK-LABEL: @vector_interleave_2d
+// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+ // CHECK: llvm.shufflevector
+ // CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8>
+ return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// A 2-D interleave with a scalable trailing dim is unrolled along its fixed
+// leading dim; each 1-D slice lowers to llvm.intr.experimental.vector.interleave2.
+// CHECK-LABEL: @vector_interleave_2d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+ // CHECK: llvm.intr.experimental.vector.interleave2
+ // CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ return %0 : vector<2x[16]xi16>
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
new file mode 100644
index 00000000000000..05be19c3c81f9a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// The fixed-size leading dim is unrolled: each row is extracted from both
+// operands, interleaved as 1-D, and inserted into a zero-initialized result.
+// CHECK-LABEL: @vector_interleave_2d
+// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+ // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+ // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+ // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+ // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+ // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8>
+ return %0 : vector<2x6xi8>
+}
+
+// A scalable trailing dim does not prevent unrolling the fixed leading dim.
+// CHECK-LABEL: @vector_interleave_2d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+ // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+ // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+ // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+ // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+ // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ return %0 : vector<2x[16]xi16>
+}
+
+// The leading dim is unrolled repeatedly until only 1-D interleaves remain
+// (2 x 2 = 4 of them).
+// CHECK-LABEL: @vector_interleave_3d
+func.func @vector_interleave_3d(%a: vector<2x2x4xf32>, %b: vector<2x2x4xf32>) -> vector<2x2x8xf32>
+{
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x
+  // CHECK-COUNT-4: vector.interleave {{.*}} : vector<4xf32>
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x
+  %0 = vector.interleave %a, %b : vector<2x2x4xf32>
+  return %0 : vector<2x2x8xf32>
+}
+
+// A scalable leading dim cannot be unrolled, so the op is left unchanged.
+// CHECK-LABEL: @vector_interleave_scalable_leading_dim
+func.func @vector_interleave_scalable_leading_dim(%a: vector<[2]x4xf32>, %b: vector<[2]x4xf32>) -> vector<[2]x8xf32>
+{
+  // CHECK-NOT: vector.extract
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x4xf32>
+  // CHECK-NOT: vector.insert
+  %0 = vector.interleave %a, %b : vector<[2]x4xf32>
+  return %0 : vector<[2]x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    // Run the interleave lowering patterns over every func.func in the module.
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list