[Mlir-commits] [llvm] [mlir] [mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (PR #81895)
Kojo Acquah
llvmlistbot at llvm.org
Tue Mar 5 10:37:32 PST 2024
https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/81895
>From 2341c57b18eda6edb5f143bc2eb010aa98464dbf Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 15 Feb 2024 17:59:46 +0000
Subject: [PATCH] Implement LowerVectorToArmNeon
---
.../include/mlir/Dialect/ArmNeon/Transforms.h | 21 +++
.../Conversion/VectorToLLVM/CMakeLists.txt | 1 +
mlir/lib/Dialect/ArmNeon/CMakeLists.txt | 15 +-
mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt | 13 ++
.../Dialect/ArmNeon/Transforms/CMakeLists.txt | 14 ++
.../Transforms/LowerVectorToArmNeon.cpp | 153 ++++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 21 +++
7 files changed, 225 insertions(+), 13 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/ArmNeon/Transforms.h
create mode 100644 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
new file mode 100644
index 00000000000000..41dbc2633d52c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -0,0 +1,21 @@
+//===- Transforms.h - ArmNeon Dialect Transformation Entrypoints-*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+// Forward declaration so that this header is self-contained and does not rely
+// on the includer having pulled in mlir/IR/PatternMatch.h beforehand.
+class RewritePatternSet;
+
+namespace arm_neon {
+/// Adds to `patterns` the rewrite pattern that lowers `vector.contract`
+/// operations to the `arm_neon.intr.smmla` intrinsic.
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSMEDialect
MLIRArmSMETransforms
MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
- IR/ArmNeonDialect.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
- DEPENDS
- MLIRArmNeonIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
- )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+ ArmNeonDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+ DEPENDS
+ MLIRArmNeonIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..dcd806e981479d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Library holding the Vector -> ArmNeon lowering patterns.
+# LowerVectorToArmNeon.cpp matches and creates arith.extsi ops, so the Arith
+# dialect is linked directly instead of relying on a transitive dependency.
+add_mlir_dialect_library(MLIRArmNeonTransforms
+  LowerVectorToArmNeon.cpp
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArmNeonDialect
+  MLIRFuncDialect
+  MLIRVectorDialect
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..01806f43a10056
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,153 @@
+//===- LowerVectorToArmNeon.cpp - Lower vector.contract to SMMLA ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-vector-lowering"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+/// Produces the type that carries `element` as its element type inside the
+/// same container as `container`. When `container` is not a shaped type,
+/// `element` itself is returned.
+static Type matchContainerType(Type element, Type container) {
+  auto shaped = dyn_cast<ShapedType>(container);
+  if (!shaped)
+    return element;
+
+  return shaped.clone(element);
+}
+
+/// Lowers a `vector.contract` directly to the `arm_neon.intr.smmla` intrinsic.
+///
+/// The pattern applies only when all of the following hold:
+///   * the contraction combines with `add` and its iterator types are
+///     [parallel, parallel, reduction];
+///   * every indexing map has 3 dims and 2 results;
+///   * LHS and RHS are both 2x8 (M = N = 2, K = 8; the RHS is not transposed);
+///   * LHS and RHS are each produced by an `arith.extsi` whose input is a
+///     vector with elements of at most 8 bits;
+///   * the accumulator is a vector of i32.
+///
+/// On a match the extsi inputs are sign-extended to i8 where necessary,
+/// flattened to 1-D vectors (16 x i8 inputs, 4 x i32 accumulator), passed to
+/// the intrinsic, and the result is shape-cast back to the contraction's
+/// result type. No IR is created unless every check has passed.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // smmla is a multiply-accumulate, so only `add` contractions can be
+    // expressed with it.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Every indexing map must be defined over three dims (M, N, K) and yield
+    // a 2-D operand. A separate isPermutation() test is unnecessary: a
+    // permutation map has as many results as dims, so it is already rejected
+    // by the two count checks below.
+    // NOTE(review): this only checks dim/result counts; it does not verify
+    // which dims each operand actually indexes. Confirm that only
+    // (m, k) x (n, k) -> (m, n) contractions reach this pattern.
+    auto indexingMaps = op.getIndexingMapsArray();
+    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
+          return affineMap.getNumDims() != 3 || affineMap.getNumResults() != 2;
+        })) {
+      return failure();
+    }
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction) {
+      return failure();
+    }
+
+    // Check the tile size by mapping the dimensions of the contract.
+    //   -- Tile size: [2, 2, 8]
+    // Tile sizes are inferred from the operands. Note: RHS is not transposed.
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+    auto dimM = lhsType.getDimSize(0);
+    auto dimN = rhsType.getDimSize(0);
+    auto dimK = lhsType.getDimSize(1);
+    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+      return failure();
+    }
+
+    // The intrinsic accumulates into i32 lanes; reject any other accumulator
+    // element type instead of building an ill-typed smmla op.
+    auto accType = dyn_cast<VectorType>(res.getType());
+    if (!accType || !accType.getElementType().isInteger(32))
+      return failure();
+
+    // Both multiplicands must come from a sign extension.
+    auto origLhsExtOp = lhs.getDefiningOp<arith::ExtSIOp>();
+    auto origRhsExtOp = rhs.getDefiningOp<arith::ExtSIOp>();
+    if (!origLhsExtOp || !origRhsExtOp)
+      return failure();
+
+    // Match any iX -> wider extension with X <= 8. Both inputs are validated
+    // before any new op is created, so that a failed match never leaves
+    // partially rewritten IR behind.
+    Value lhsIn = origLhsExtOp.getIn();
+    Value rhsIn = origRhsExtOp.getIn();
+    auto lhsInType = dyn_cast<VectorType>(lhsIn.getType());
+    auto rhsInType = dyn_cast<VectorType>(rhsIn.getType());
+    if (!lhsInType || !rhsInType ||
+        lhsInType.getElementTypeBitWidth() > 8 ||
+        rhsInType.getElementTypeBitWidth() > 8) {
+      return failure();
+    }
+
+    // Sign-extend an input to i8 elements. An input that already has i8
+    // elements is used as-is, because extsi to the same width is not a valid
+    // extension.
+    auto extendToI8 = [&](Value input) -> Value {
+      Type targetTy = matchContainerType(rewriter.getI8Type(), input.getType());
+      if (input.getType() == targetTy)
+        return input;
+      return rewriter.create<arith::ExtSIOp>(loc, targetTy, input).getResult();
+    };
+    Value extsiLhs = extendToI8(lhsIn);
+    Value extsiRhs = extendToI8(rhsIn);
+
+    // Collapse to the 1-D vectors required by the smmla intrinsic.
+    auto collapsedInputType = VectorType::get({16}, rewriter.getI8Type());
+    auto collapsedOutputType = VectorType::get({4}, accType.getElementType());
+    auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
+        loc, collapsedInputType, extsiLhs);
+    auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
+        loc, collapsedInputType, extsiRhs);
+    auto collapsedRes =
+        rewriter.create<vector::ShapeCastOp>(loc, collapsedOutputType, res);
+
+    // Replace the contract with a neon op.
+    auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
+        loc, collapsedRes.getType(), collapsedRes, collapsedLhs, collapsedRhs);
+
+    // Reshape the output back to 2-D.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     smmlaOp);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers LowerVectorToArmNeonPattern in `patterns` with a benefit of 1,
+/// using the MLIR context owned by the pattern set.
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LowerVectorToArmNeonPattern>(patterns.getContext(),
+                                            /*benefit=*/1);
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f33f165992213..59dc1100ea8df5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1929,6 +1929,27 @@ cc_library(
],
)
+# Vector -> ArmNeon lowering patterns (LowerVectorToArmNeon.cpp).
+cc_library(
+    name = "ArmNeonTransforms",
+    srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
+    hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArmNeonDialect",
+        ":ArmNeonIncGen",
+        ":FuncDialect",
+        ":IR",
+        ":LLVMDialect",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":Transforms",
+        ":VectorDialect",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
gentbl_cc_library(
name = "ArmNeonConversionIncGen",
tbl_outs = [
More information about the Mlir-commits
mailing list