[Mlir-commits] [mlir] [DO NOT SUBMIT] Implement LowerVectorToArmNeon Pattern (PR #81895)

Kojo Acquah llvmlistbot at llvm.org
Thu Feb 15 14:59:36 PST 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/81895

>From 24021eee9213c2df0bae62243d374b28c9493e25 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 15 Feb 2024 22:57:50 +0000
Subject: [PATCH] Implement LowerVectorToArmNeon

---
 .../mlir/Dialect/ArmNeon/CMakeLists.txt       |   1 +
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |   0
 .../Dialect/ArmNeon/Transforms/Transforms.h   |  21 +++
 .../Conversion/VectorToLLVM/CMakeLists.txt    |   1 +
 mlir/lib/Dialect/ArmNeon/CMakeLists.txt       |  15 +-
 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt    |  13 ++
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |   0
 .../Transforms/LowerVectorToArmNeon.cpp       | 154 ++++++++++++++++++
 8 files changed, 192 insertions(+), 13 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmNeon/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/ArmNeon/Transforms/Transforms.h
 create mode 100644 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp

diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 1c679bcd049b85..adcd09d78bab59 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -4,3 +4,4 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
 set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
 mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms/Transforms.h
new file mode 100644
index 00000000000000..0dcb80be758cb8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms/Transforms.h
@@ -0,0 +1,24 @@
+//===- Transforms.h - ArmNeon Transformation Entrypoints --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace arm_neon {
+/// Populates `patterns` with a rewrite that lowers 2x2x8 (M x N x K)
+/// `vector.contract` ops on sign-extended integers of at most 8 bits to
+/// `arm_neon.intr.smmla`.
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSMEDialect
   MLIRArmSMETransforms
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
-  IR/ArmNeonDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
-  DEPENDS
-  MLIRArmNeonIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSideEffectInterfaces
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+  ArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..48a56e9ce26d8a
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,183 @@
+//===- LowerVectorToArmNeon.cpp - Lower contractions to SMMLA -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a lowering pattern from vector.contract to
+// arm_neon.intr.smmla.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-vector-lowering"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+/// Tile handled by a single smmla, in terms of the contraction dimensions:
+/// acc[M][N] += sum over K of lhs[M][K] * rhs[N][K].
+static constexpr int64_t kTileM = 2;
+static constexpr int64_t kTileN = 2;
+static constexpr int64_t kTileK = 8;
+
+/// Returns `container` with its element type replaced by `element` if
+/// `container` is a shaped type; otherwise returns `element` unchanged.
+static Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+/// If `value` is produced by `arith.extsi` from a vector whose elements are
+/// at most 8 bits wide, returns that narrow source vector; otherwise returns
+/// a null Value. Does not modify the IR, so it is safe to call while matching.
+static Value getNarrowExtSISource(Value value) {
+  auto extOp = value.getDefiningOp<arith::ExtSIOp>();
+  if (!extOp)
+    return Value();
+
+  Value source = extOp.getIn();
+  auto sourceType = dyn_cast<VectorType>(source.getType());
+  if (!sourceType || sourceType.getElementTypeBitWidth() > 8)
+    return Value();
+  return source;
+}
+
+/// Sign-extends the vector `narrow` (elements at most 8 bits wide) to i8
+/// elements. A vector that already has i8 elements is returned as-is, because
+/// `arith.extsi` requires a strictly wider result type and an i8 -> i8
+/// extension would be invalid IR.
+static Value signExtendToI8(PatternRewriter &rewriter, Location loc,
+                            Value narrow) {
+  Type i8Type = rewriter.getI8Type();
+  auto narrowType = cast<VectorType>(narrow.getType());
+  if (narrowType.getElementType() == i8Type)
+    return narrow;
+  return rewriter.create<arith::ExtSIOp>(
+      loc, matchContainerType(i8Type, narrowType), narrow);
+}
+
+namespace {
+
+/// Rewrites a 2x2x8 (M x N x K) `vector.contract` on i32 values that were
+/// sign-extended from integers of at most 8 bits into `arm_neon.intr.smmla`.
+/// Schematically, it matches:
+///   %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32>
+///   %rhs = arith.extsi %b : vector<2x8xi8> to vector<2x8xi32>
+///   %res = vector.contract {(m, k), (n, k) -> (m, n)} %lhs, %rhs, %acc
+///
+/// The operands are flattened to 1-D with vector.shape_cast, fed to smmla,
+/// and the result is shape_cast back to the contraction's result type.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Every check below runs before any IR is created: a pattern that returns
+    // failure() must leave the IR untouched.
+    MLIRContext *ctx = op.getContext();
+    Value acc = op.getAcc();
+
+    // Require exactly lhs: (m, k), rhs: (n, k), acc: (m, n). Other mappings
+    // (swapped operands, transposed accumulator, ...) compute something
+    // different and must not be matched.
+    AffineExpr m, n, k;
+    bindDims(ctx, m, n, k);
+    auto indexingMaps = op.getIndexingMapsArray();
+    if (indexingMaps.size() != 3 ||
+        indexingMaps[0] != AffineMap::get(3, 0, {m, k}, ctx) ||
+        indexingMaps[1] != AffineMap::get(3, 0, {n, k}, ctx) ||
+        indexingMaps[2] != AffineMap::get(3, 0, {m, n}, ctx))
+      return rewriter.notifyMatchFailure(
+          op, "expected (m, k) x (n, k) -> (m, n) indexing maps");
+
+    // m and n are parallel; k is the reduction dimension.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(
+          op, "expected [parallel, parallel, reduction] iterators");
+
+    // vector.contract with a non-add combining kind is not a matrix multiply.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "expected add combining kind");
+
+    // Given the maps above, lhs is M x K and rhs is N x K (transposed RHS).
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return rewriter.notifyMatchFailure(op, "scalable vectors unsupported");
+    if (lhsType.getDimSize(0) != kTileM || lhsType.getDimSize(1) != kTileK ||
+        rhsType.getDimSize(0) != kTileN || rhsType.getDimSize(1) != kTileK)
+      return rewriter.notifyMatchFailure(op, "expected a 2x2x8 contraction");
+
+    // smmla accumulates i8 x i8 products into i32 lanes, so the contraction
+    // must compute in i32 for the rewrite to preserve its result.
+    Type i32Type = rewriter.getI32Type();
+    auto accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getElementType() != i32Type ||
+        lhsType.getElementType() != i32Type ||
+        rhsType.getElementType() != i32Type)
+      return rewriter.notifyMatchFailure(
+          op, "expected i32 operands and accumulator");
+
+    // Both operands must be sign extensions of integers of at most 8 bits so
+    // that they can be passed to smmla as i8.
+    Value lhsSource = getNarrowExtSISource(op.getLhs());
+    Value rhsSource = getNarrowExtSISource(op.getRhs());
+    if (!lhsSource || !rhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "expected operands sign-extended from at most 8-bit integers");
+
+    // Match succeeded; rewrite the IR from here on.
+    Location loc = op.getLoc();
+    Value lhsI8 = signExtendToI8(rewriter, loc, lhsSource);
+    Value rhsI8 = signExtendToI8(rewriter, loc, rhsSource);
+
+    // The smmla intrinsic takes 1-D operands: vector<16xi8> inputs and a
+    // vector<4xi32> accumulator.
+    auto flatInputType =
+        VectorType::get({kTileM * kTileK}, rewriter.getI8Type());
+    auto flatAccType = VectorType::get({kTileM * kTileN}, i32Type);
+    Value flatLhs =
+        rewriter.create<vector::ShapeCastOp>(loc, flatInputType, lhsI8);
+    Value flatRhs =
+        rewriter.create<vector::ShapeCastOp>(loc, flatInputType, rhsI8);
+    Value flatAcc = rewriter.create<vector::ShapeCastOp>(loc, flatAccType, acc);
+
+    Value smmla = rewriter.create<arm_neon::SmmlaOp>(loc, flatAccType, flatAcc,
+                                                     flatLhs, flatRhs);
+
+    // Reshape the result back to the contraction's 2-D result type.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     smmla);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+}



More information about the Mlir-commits mailing list