[Mlir-commits] [llvm] [mlir] [mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (PR #81895)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 5 15:37:00 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Kojo Acquah (KoolJBlack)

<details>
<summary>Changes</summary>

This patch adds the `LowerVectorToArmNeonPattern` pattern to the ArmNeon dialect.

This pattern inspects `vector.contract` ops that can be mapped 1:1 to an `arm_neon.intr.smmla` intrinsic. The contract ops must already be tiled so that their inputs fit a single smmla op (`2x8xi32` inputs and a `2x2xi32` output). The `vector.contract` inputs must be sign-extended from narrow types (<= i8) to be converted. If all conditions are met, an smmla op is inserted, together with `vector.shape_cast` ops that linearize the input and output dimensions.



---
Full diff: https://github.com/llvm/llvm-project/pull/81895.diff


13 Files Affected:

- (added) mlir/include/mlir/Dialect/ArmNeon/Transforms.h (+21) 
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/ArmNeon/CMakeLists.txt (+2-13) 
- (added) mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt (+13) 
- (added) mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt (+14) 
- (added) mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp (+153) 
- (added) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+16) 
- (added) mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt (+13) 
- (added) mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp (+63) 
- (modified) mlir/test/lib/Dialect/CMakeLists.txt (+1) 
- (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
new file mode 100644
index 00000000000000..41dbc2633d52c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -0,0 +1,24 @@
+//===- Transforms.h - ArmNeon Transformation Entrypoints -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace arm_neon {
+/// Populates `patterns` with a rewrite that lowers vector.contract ops
+/// computing a 2x2 tile with a reduction depth of 8, whose operands are
+/// sign-extended from integers of at most 8 bits, to arm_neon.intr.smmla.
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSMEDialect
   MLIRArmSMETransforms
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
-  IR/ArmNeonDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
-  DEPENDS
-  MLIRArmNeonIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSideEffectInterfaces
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+  ArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..dcd806e981479d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRArmNeonTransforms
+  LowerVectorToArmNeon.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArmNeonDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..f3ac36d08fc82b
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,174 @@
+//===- LowerVectorToArmNeon.cpp - Lower vector.contract to ArmNeon --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-lower-vector"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+// Returns `container` with its element type replaced by `element` when
+// `container` is a shaped type; otherwise returns `element` unchanged.
+static Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+// Returns the input of the arith.extsi that defines `v` when that input is a
+// vector whose elements are at most 8 bits wide, or a null Value otherwise.
+// Does not modify the IR, so it is safe to call while still matching.
+static Value getNarrowSignExtendedInput(Value v) {
+  auto extOp = v.getDefiningOp<arith::ExtSIOp>();
+  if (!extOp)
+    return Value();
+  Value in = extOp.getIn();
+  auto inType = dyn_cast<VectorType>(in.getType());
+  if (!inType || inType.getElementTypeBitWidth() > 8)
+    return Value();
+  return in;
+}
+
+// Returns `in` (a vector of integers of at most 8 bits) as a vector of i8.
+// Vectors that are already i8 are returned unchanged: arith.extsi requires a
+// strictly wider result type, so an i8 -> i8 extension would be invalid IR.
+// Narrower element types are sign-extended to i8.
+static Value signExtendToI8(PatternRewriter &rewriter, Location loc, Value in) {
+  Type i8Type = rewriter.getI8Type();
+  auto inType = cast<VectorType>(in.getType());
+  if (inType.getElementType() == i8Type)
+    return in;
+  return rewriter.create<arith::ExtSIOp>(
+      loc, matchContainerType(i8Type, inType), in);
+}
+
+namespace {
+
+// Lowers a vector.contract computing a single 2x2 (M x N) output tile with a
+// reduction depth of 8 (K) directly to the arm_neon.intr.smmla intrinsic.
+//
+// The contraction is only rewritten when it:
+//   * uses the non-transposed indexing maps
+//       lhs: (m, n, k) -> (m, k), rhs: (m, n, k) -> (n, k),
+//       acc: (m, n, k) -> (m, n)
+//     with iterator types [parallel, parallel, reduction];
+//   * combines with the `add` kind;
+//   * has fixed-length 2x8 lhs/rhs operands produced by arith.extsi from
+//     integers of at most 8 bits;
+//   * accumulates into a vector of i32.
+// The 2-D operands are linearized with vector.shape_cast into the 1-D vectors
+// the intrinsic expects, and the result is reshaped back to 2-D.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value acc = op.getAcc();
+
+    // smmla only accumulates by addition; other combining kinds (mul, max,
+    // ...) must not be lowered to it.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Require exactly the (M x K, N x K) -> M x N maps. Checking only the map
+    // arity is not enough: e.g. a transposed accumulator map would otherwise
+    // be silently miscompiled.
+    MLIRContext *ctx = rewriter.getContext();
+    AffineExpr m = rewriter.getAffineDimExpr(0);
+    AffineExpr n = rewriter.getAffineDimExpr(1);
+    AffineExpr k = rewriter.getAffineDimExpr(2);
+    AffineMap expectedMaps[] = {AffineMap::get(3, 0, {m, k}, ctx),
+                                AffineMap::get(3, 0, {n, k}, ctx),
+                                AffineMap::get(3, 0, {m, n}, ctx)};
+    if (!llvm::equal(op.getIndexingMapsArray(), expectedMaps))
+      return failure();
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the tile size [M, N, K] = [2, 2, 8]. With the maps verified above
+    // the lhs is M x K and the rhs is N x K. smmla is a fixed-length Neon
+    // instruction, so scalable vectors are rejected.
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
+    int64_t dimM = lhsType.getDimSize(0);
+    int64_t dimN = rhsType.getDimSize(0);
+    int64_t dimK = lhsType.getDimSize(1);
+    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8)
+      return failure();
+
+    // smmla accumulates into i32 lanes only.
+    auto accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.isScalable() ||
+        !accType.getElementType().isSignlessInteger(32))
+      return failure();
+
+    // Both multiplicands must be sign extensions from integers of at most 8
+    // bits. All matching is done before any op is created, so a failed match
+    // leaves the IR untouched, as required of a rewrite pattern.
+    Value lhsIn = getNarrowSignExtendedInput(op.getLhs());
+    Value rhsIn = getNarrowSignExtendedInput(op.getRhs());
+    if (!lhsIn || !rhsIn)
+      return failure();
+
+    // Materialize both inputs as i8, the element type smmla consumes.
+    Value lhsI8 = signExtendToI8(rewriter, loc, lhsIn);
+    Value rhsI8 = signExtendToI8(rewriter, loc, rhsIn);
+
+    // Collapse to the 1-D vectors required by the smmla intrinsic.
+    auto collapsedInputType = VectorType::get({16}, rewriter.getI8Type());
+    auto collapsedOutputType = VectorType::get({4}, accType.getElementType());
+    Value collapsedLhs =
+        rewriter.create<vector::ShapeCastOp>(loc, collapsedInputType, lhsI8);
+    Value collapsedRhs =
+        rewriter.create<vector::ShapeCastOp>(loc, collapsedInputType, rhsI8);
+    Value collapsedAcc =
+        rewriter.create<vector::ShapeCastOp>(loc, collapsedOutputType, acc);
+
+    // Replace the contract with the Neon op.
+    Value smmla = rewriter.create<arm_neon::SmmlaOp>(
+        loc, collapsedOutputType, collapsedAcc, collapsedLhs, collapsedRhs);
+
+    // Reshape the 1-D result back to the contract's 2-D result type.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     smmla);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+}
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
new file mode 100644
index 00000000000000..fe7259a9919ccf
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -test-lower-vector-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @test_lower_vector_arm_neon_mixed_types
+// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
+// CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[D0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:     return %[[D5]] : vector<2x2xi32>
+func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_lower_vector_arm_neon_without_extsi
+// CHECK-NOT:     arm_neon.intr.smmla
+// CHECK:         vector.contract
+func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 00000000000000..21548ca57701f9
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArmNeonTestPasses
+  TestLowerToArmNeon.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeonDialect
+  MLIRArmNeonTransforms
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
new file mode 100644
index 00000000000000..6398b69cc82816
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
@@ -0,0 +1,61 @@
+//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that applies the vector-to-ArmNeon
+// lowering patterns greedily to every func.func it runs on.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define PASS_NAME "test-lower-vector-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+/// Test pass that runs the vector-to-ArmNeon lowering patterns to a fixed
+/// point on each function.
+struct TestLowerToArmNeon
+    : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
+
+  TestLowerToArmNeon() = default;
+  TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests lower vector to arm Neon.";
+  }
+
+  // The patterns create arm_neon ops, so that dialect must be loaded before
+  // the pass runs.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arm_neon::ArmNeonDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet lowering(&getContext());
+    populateLowerVectorToArmNeonPatterns(lowering);
+    FrozenRewritePatternSet frozenLowering(std::move(lowering));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozenLowering)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index e20cd4473a3580..29fb4441a24fd2 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(Affine)
 add_subdirectory(Arith)
+add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
 add_subdirectory(Bufferization)
 add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 701fc461b3b4e9..d4faa087bc6a0c 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestFuncToLLVM
     MLIRAffineTransformsTestPasses
     MLIRArithTestPasses
+    MLIRArmNeonTestPasses
     MLIRArmSMETestPasses
     MLIRBufferizationTestPasses
     MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0ba1a3a534e35c..c099bbd97eaecf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
 void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLowerToArmNeon();
 void registerTestLowerToArmSME();
 void registerTestLowerToLLVM();
 void registerTestMakeIsolatedFromAbovePass();
@@ -236,6 +237,7 @@ void registerTestPasses() {
   mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestLowerToArmNeon();
   mlir::test::registerTestLowerToArmSME();
   mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMakeIsolatedFromAbovePass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f33f165992213..59dc1100ea8df5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1929,6 +1929,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArmNeonTransforms",
+    srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
+    hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArmNeonDialect",
+        ":ArmNeonIncGen",
+        ":FuncDialect",
+        ":IR",
+        ":LLVMDialect",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":Transforms",
+        ":VectorDialect",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "ArmNeonConversionIncGen",
     tbl_outs = [

``````````

</details>


https://github.com/llvm/llvm-project/pull/81895


More information about the Mlir-commits mailing list