[llvm] [mlir] [DO NOT SUBMIT] Implement LowerVectorToArmNeon Pattern (PR #81895)

Kojo Acquah via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 12:41:28 PST 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/81895

>From 8ce5a145c7d1fb53cc691a0cdebf5de75e89253b Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 15 Feb 2024 17:59:46 +0000
Subject: [PATCH] Implement LowerVectorToArmNeon

---
 .../include/mlir/Dialect/ArmNeon/Transforms.h |  21 +++
 .../Conversion/VectorToLLVM/CMakeLists.txt    |   1 +
 mlir/lib/Dialect/ArmNeon/CMakeLists.txt       |  15 +-
 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt    |  13 ++
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |  14 ++
 .../Transforms/LowerVectorToArmNeon.cpp       | 153 ++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  21 +++
 7 files changed, 225 insertions(+), 13 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmNeon/Transforms.h
 create mode 100644 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp

diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
new file mode 100644
index 00000000000000..41dbc2633d52c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -0,0 +1,25 @@
+//===- Transforms.h - ArmNeon Transformation Entrypoints --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace arm_neon {
+
+/// Appends to `patterns` the rewrites that lower suitable `vector.contract`
+/// ops to the `arm_neon.intr.smmla` intrinsic.
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSMEDialect
   MLIRArmSMETransforms
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
-  IR/ArmNeonDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
-  DEPENDS
-  MLIRArmNeonIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSideEffectInterfaces
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+  ArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..dcd806e981479d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRArmNeonTransforms
+  LowerVectorToArmNeon.cpp
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArmNeonDialect
+  MLIRFuncDialect
+  MLIRVectorDialect
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..01806f43a10056
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,163 @@
+//===- LowerVectorToArmNeon.cpp - Lower vector ops to Arm Neon ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-vector-lowering"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+/// Returns `element` wrapped in the shape of `container` when `container` is
+/// a shaped type; otherwise returns `element` unchanged.
+static Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+/// If `operand` is produced by `arith.extsi` from a vector whose elements are
+/// at most 8 bits wide, returns that extension's source; otherwise returns a
+/// null Value. Creates no IR, so it is safe to call while matching.
+static Value getNarrowExtSISource(Value operand) {
+  auto extOp = operand.getDefiningOp<arith::ExtSIOp>();
+  if (!extOp)
+    return Value();
+  Value source = extOp.getIn();
+  auto sourceType = dyn_cast<VectorType>(source.getType());
+  if (!sourceType || sourceType.getElementTypeBitWidth() > 8)
+    return Value();
+  return source;
+}
+
+/// Returns `source`, a vector of integers at most 8 bits wide, as a vector of
+/// i8. `arith.extsi` requires a strictly wider result type, so an extension
+/// is only created for element types narrower than i8; i8 is used as-is.
+static Value extendToI8(PatternRewriter &rewriter, Location loc,
+                        Value source) {
+  Type i8Type = rewriter.getI8Type();
+  auto sourceType = cast<VectorType>(source.getType());
+  if (sourceType.getElementType() == i8Type)
+    return source;
+  return rewriter.create<arith::ExtSIOp>(
+      loc, matchContainerType(i8Type, sourceType), source);
+}
+
+namespace {
+
+/// Lowers a `vector.contract` computing
+///
+///   acc[m][n] += sum_k lhs[m][k] * rhs[n][k]    (M = 2, N = 2, K = 8)
+///
+/// to `arm_neon.intr.smmla`. Both multiplicands must be `arith.extsi` results
+/// whose sources are vectors of at most 8-bit integers, and the accumulator
+/// must be a vector<2x2xi32>. The RHS is read in N x K layout, which is the
+/// layout smmla expects for its second operand.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Every check precedes IR creation: a pattern must not modify the IR
+    // and then return failure().
+
+    // Only an unmasked multiply-accumulate maps onto smmla.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not add");
+    if (cast<vector::MaskableOpInterface>(op.getOperation()).isMasked())
+      return rewriter.notifyMatchFailure(op, "contraction is masked");
+
+    // Require exactly (m, k) x (n, k) -> (m, n). Checking only the number of
+    // dims and results would also accept transposed operands.
+    MLIRContext *ctx = op.getContext();
+    AffineExpr m, n, k;
+    bindDims(ctx, m, n, k);
+    SmallVector<AffineMap, 4> expectedMaps = {
+        AffineMap::get(3, 0, {m, k}, ctx), AffineMap::get(3, 0, {n, k}, ctx),
+        AffineMap::get(3, 0, {m, n}, ctx)};
+    if (op.getIndexingMapsArray() != expectedMaps)
+      return rewriter.notifyMatchFailure(op, "unsupported indexing maps");
+
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(op, "unsupported iterator types");
+
+    // Tile M x N x K = 2 x 2 x 8 with an i32 accumulator. smmla is a
+    // fixed-width instruction, so scalable vectors are rejected.
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    auto accType = dyn_cast<VectorType>(op.getAccType());
+    if (!accType || lhsType.isScalable() || rhsType.isScalable() ||
+        accType.isScalable() ||
+        lhsType.getShape() != ArrayRef<int64_t>({2, 8}) ||
+        rhsType.getShape() != ArrayRef<int64_t>({2, 8}) ||
+        accType.getShape() != ArrayRef<int64_t>({2, 2}) ||
+        !accType.getElementType().isSignlessInteger(32))
+      return rewriter.notifyMatchFailure(op, "unsupported operand types");
+
+    // Both multiplicands must be sign extensions of <= 8-bit integers.
+    Value lhsSource = getNarrowExtSISource(op.getLhs());
+    Value rhsSource = getNarrowExtSISource(op.getRhs());
+    if (!lhsSource || !rhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "operands are not sign-extended from <= 8-bit integers");
+
+    // Rewrite: nothing below can fail.
+    Location loc = op.getLoc();
+    Value lhsI8 = extendToI8(rewriter, loc, lhsSource);
+    Value rhsI8 = extendToI8(rewriter, loc, rhsSource);
+
+    // smmla takes flat operands: two 16 x i8 tiles and a 4 x i32
+    // accumulator. shape_cast is row-major, so lhs rows (M x K) and rhs rows
+    // (N x K) stay contiguous.
+    auto flatInputType = VectorType::get({16}, rewriter.getI8Type());
+    auto flatAccType = VectorType::get({4}, accType.getElementType());
+    Value flatLhs =
+        rewriter.create<vector::ShapeCastOp>(loc, flatInputType, lhsI8);
+    Value flatRhs =
+        rewriter.create<vector::ShapeCastOp>(loc, flatInputType, rhsI8);
+    Value flatAcc =
+        rewriter.create<vector::ShapeCastOp>(loc, flatAccType, op.getAcc());
+    Value smmla = rewriter.create<arm_neon::SmmlaOp>(loc, flatAccType, flatAcc,
+                                                     flatLhs, flatRhs);
+
+    // Restore the 2x2 result shape expected by users of the contraction.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     smmla);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7860ccd0406a13..9b79c8d3c615ca 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1928,6 +1928,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArmNeonTransforms",
+    srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
+    hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArmNeonDialect",
+        ":ArmNeonIncGen",
+        ":FuncDialect",
+        ":IR",
+        ":LLVMDialect",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":Transforms",
+        ":VectorDialect",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "ArmNeonConversionIncGen",
     tbl_outs = [



More information about the llvm-commits mailing list