[Mlir-commits] [llvm] [mlir] [mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (PR #81895)

Kojo Acquah llvmlistbot at llvm.org
Fri Mar 8 14:34:16 PST 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/81895

>From 2341c57b18eda6edb5f143bc2eb010aa98464dbf Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 15 Feb 2024 17:59:46 +0000
Subject: [PATCH 1/4] Implement LowerVectorToArmNeon

---
 .../include/mlir/Dialect/ArmNeon/Transforms.h |  21 +++
 .../Conversion/VectorToLLVM/CMakeLists.txt    |   1 +
 mlir/lib/Dialect/ArmNeon/CMakeLists.txt       |  15 +-
 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt    |  13 ++
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |  14 ++
 .../Transforms/LowerVectorToArmNeon.cpp       | 153 ++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  21 +++
 7 files changed, 225 insertions(+), 13 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmNeon/Transforms.h
 create mode 100644 mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp

diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
new file mode 100644
index 00000000000000..41dbc2633d52c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -0,0 +1,21 @@
+//===- Transforms.h - ArmNeon Dialect Transformation Entrypoints -*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+
+namespace arm_neon {
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSMEDialect
   MLIRArmSMETransforms
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
-  IR/ArmNeonDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
-  DEPENDS
-  MLIRArmNeonIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSideEffectInterfaces
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+  ArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..dcd806e981479d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRArmNeonTransforms
+  LowerVectorToArmNeon.cpp
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeonDialect
+  MLIRFuncDialect
+  MLIRVectorDialect
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..01806f43a10056
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,153 @@
+//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops
+//-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-vector-lowering"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+// Return the shaped type with new element type.
+static Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+// Lowering from vector::contractOp directly to the arm neon
+// intrinsic.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // Check index maps represent M N K and aren't transposed.
+    auto indexingMaps = op.getIndexingMapsArray();
+    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
+          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
+                 affineMap.getNumResults() != 2;
+        })) {
+      return failure();
+    }
+
+    // Check iterator types for contract
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction) {
+      return failure();
+    }
+
+    // Check the tile size by mapping the dimensions of the contract
+    //  -- Tile size: [2, 2, 8]
+    // Infer tile sizes from operands. Check required tile size
+    // Note: RHS is not transposed
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+    auto dimM = lhsType.getDimSize(0);
+    auto dimN = rhsType.getDimSize(0);
+    auto dimK = lhsType.getDimSize(1);
+    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+      return failure();
+    }
+
+    // Check two extsi inputs Rhs Lhs
+    arith::ExtSIOp origLhsExtOp;
+    arith::ExtSIOp origRhsExtOp;
+    if (!(origLhsExtOp =
+              dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
+        !(origRhsExtOp =
+              dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
+      return failure();
+    }
+
+    arith::ExtSIOp extsiLhs;
+    arith::ExtSIOp extsiRhs;
+    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
+    // following neon instruction. Check inputs for extsi are <=i8
+    if (auto lhsExtInType =
+            origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
+        // Target lhs type with i8. This is likely redundant
+        Type targetLhsExtTy =
+            matchContainerType(rewriter.getI8Type(), lhsExtInType);
+        extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
+                                                   origLhsExtOp.getIn());
+      }
+    }
+    if (auto rhsExtInType =
+            origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
+        // Target rhs type with i8
+        Type targetRhsExtTy =
+            matchContainerType(rewriter.getI8Type(), rhsExtInType);
+        extsiRhs = rewriter.create<arith::ExtSIOp>(loc, targetRhsExtTy,
+                                                   origRhsExtOp.getIn());
+      }
+    }
+
+    if (!extsiLhs || !extsiRhs) {
+      return failure();
+    }
+
+    // Collapse to 1D vectors required by smmla intrinsic
+    auto collapsedInputType = VectorType::get(
+        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
+    auto collapsedOutputType =
+        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
+    auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
+        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
+    auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
+        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
+    auto collapsedRes = rewriter.create<vector::ShapeCastOp>(
+        res.getLoc(), collapsedOutputType, res);
+
+    // Replace the contract with a neon op
+    auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
+        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+        collapsedRhs);
+
+    // Reshape output back to 2D
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+                                                     smmlaOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f33f165992213..59dc1100ea8df5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1929,6 +1929,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArmNeonTransforms",
+    srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
+    hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArmNeonIncGen",
+        ":ArmNeonDialect",
+        ":FuncDialect",
+        ":IR",
+        ":LLVMDialect",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":VectorDialect",
+        ":Transforms",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 gentbl_cc_library(
     name = "ArmNeonConversionIncGen",
     tbl_outs = [

>From eae425d2138d21b0846c6159a27e841f7e7f6f3d Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Tue, 5 Mar 2024 22:38:21 +0000
Subject: [PATCH 2/4] implement TestLowerToArmNeon

---
 .../Transforms/LowerVectorToArmNeon.cpp       |  2 +-
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 16 +++++
 mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt  | 13 ++++
 .../Dialect/ArmNeon/TestLowerToArmNeon.cpp    | 63 +++++++++++++++++++
 mlir/test/lib/Dialect/CMakeLists.txt          |  1 +
 mlir/tools/mlir-opt/CMakeLists.txt            |  1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 7 files changed, 97 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
 create mode 100644 mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
index 01806f43a10056..f3ac36d08fc82b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -22,7 +22,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#define DEBUG_TYPE "arm-neon-vector-lowering"
+#define DEBUG_TYPE "arm-neon-lower-vector"
 
 using namespace mlir;
 using namespace mlir::arm_neon;
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
new file mode 100644
index 00000000000000..fe7259a9919ccf
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -test-lower-vector-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: test_lower_vector_arm_neon_mixed_types
+// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
+// CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[D0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32>
+func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 00000000000000..21548ca57701f9
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArmNeonTestPasses
+  TestLowerToArmNeon.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeonDialect
+  MLIRArmNeonTransforms
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
new file mode 100644
index 00000000000000..6398b69cc82816
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
@@ -0,0 +1,63 @@
+//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the lowering to ArmNeon as a
+// generally usable sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define PASS_NAME "test-lower-vector-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+struct TestLowerToArmNeon
+    : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests lower vector to arm Neon.";
+  }
+  TestLowerToArmNeon() = default;
+  TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arm_neon::ArmNeonDialect>();
+  }
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void TestLowerToArmNeon::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  populateLowerVectorToArmNeonPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    return signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+
+void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index e20cd4473a3580..29fb4441a24fd2 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(Affine)
 add_subdirectory(Arith)
+add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
 add_subdirectory(Bufferization)
 add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 701fc461b3b4e9..d4faa087bc6a0c 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestFuncToLLVM
     MLIRAffineTransformsTestPasses
     MLIRArithTestPasses
+    MLIRArmNeonTestPasses
     MLIRArmSMETestPasses
     MLIRBufferizationTestPasses
     MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0ba1a3a534e35c..c099bbd97eaecf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
 void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLowerToArmNeon();
 void registerTestLowerToArmSME();
 void registerTestLowerToLLVM();
 void registerTestMakeIsolatedFromAbovePass();
@@ -236,6 +237,7 @@ void registerTestPasses() {
   mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestLowerToArmNeon();
   mlir::test::registerTestLowerToArmSME();
   mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMakeIsolatedFromAbovePass();

>From 339def8576cc95bdd7ce0768cca93ff663b89b49 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Wed, 6 Mar 2024 19:01:38 +0000
Subject: [PATCH 3/4] diego comments

---
 .../include/mlir/Dialect/ArmNeon/Transforms.h |  6 ++--
 .../Dialect/ArmNeon/Transforms/CMakeLists.txt |  2 +-
 ...cpp => LowerContractionToSMMLAPattern.cpp} | 32 +++++++++----------
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 28 +++++++++++++++-
 .../Dialect/ArmNeon/TestLowerToArmNeon.cpp    |  8 ++---
 5 files changed, 49 insertions(+), 27 deletions(-)
 rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerVectorToArmNeon.cpp => LowerContractionToSMMLAPattern.cpp} (87%)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 41dbc2633d52c6..49cad22defec03 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -1,5 +1,4 @@
-//===- Transforms.h - ArmNeon Dialect Transformation Entrypoints -*- C++
-//-*-===//
+//===- Transforms.h - ArmNeon Transformation Entrypoints --------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -13,7 +12,8 @@
 namespace mlir {
 
 namespace arm_neon {
-void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToSMMLAPatternPatterns(
+    RewritePatternSet &patterns);
 } // namespace arm_neon
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index dcd806e981479d..84fb1b0116d2ae 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerVectorToArmNeon.cpp
+  LowerContractionToSMMLAPattern.cpp
 
   DEPENDS
   MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
similarity index 87%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index f3ac36d08fc82b..fc9b84bbfda453 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -1,5 +1,4 @@
-//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops
-//-----------===//
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -22,24 +21,24 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#define DEBUG_TYPE "arm-neon-lower-vector"
+#define DEBUG_TYPE "lower-contract-to-arm-neon"
 
 using namespace mlir;
 using namespace mlir::arm_neon;
 
 namespace {
 
-// Return the shaped type with new element type.
+/// Return the shaped type with new element type.
 static Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = dyn_cast<ShapedType>(container))
+  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
     return shapedTy.clone(element);
-
+  }
   return element;
 }
 
-// Lowering from vector::contractOp directly to the arm neon
-// intrinsic.
-class LowerVectorToArmNeonPattern
+/// Lowering from vector::contractOp directly to the arm neon
+/// intrinsic.
+class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -82,12 +81,11 @@ class LowerVectorToArmNeonPattern
     }
 
     // Check two extsi inputs Rhs Lhs
-    arith::ExtSIOp origLhsExtOp;
-    arith::ExtSIOp origRhsExtOp;
-    if (!(origLhsExtOp =
-              dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
-        !(origRhsExtOp =
-              dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
+    arith::ExtSIOp origLhsExtOp =
+        dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+    arith::ExtSIOp origRhsExtOp =
+        dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+    if (!origLhsExtOp || !origRhsExtOp) {
       return failure();
     }
 
@@ -146,8 +144,8 @@ class LowerVectorToArmNeonPattern
 
 } // namespace
 
-void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
 }
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index fe7259a9919ccf..cba7b00ba77a82 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-lower-vector-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: test_lower_vector_arm_neon_mixed_types
 // CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
@@ -14,3 +14,29 @@ func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: ve
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_same_types
+// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
+// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
+func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_without_extsi
+// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
+// CHECK-DAG: %[[D0:.*]] = vector.contract
+func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
index 6398b69cc82816..2e5a4a98888211 100644
--- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
+++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
@@ -20,7 +20,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#define PASS_NAME "test-lower-vector-to-arm-neon"
+#define PASS_NAME "test-lower-to-arm-neon"
 
 using namespace mlir;
 using namespace mlir::arm_neon;
@@ -31,9 +31,7 @@ struct TestLowerToArmNeon
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
 
   StringRef getArgument() const final { return PASS_NAME; }
-  StringRef getDescription() const final {
-    return "Tests lower vector to arm Neon.";
-  }
+  StringRef getDescription() const final { return "Tests lower to arm Neon."; }
   TestLowerToArmNeon() = default;
   TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
 
@@ -49,7 +47,7 @@ struct TestLowerToArmNeon
 void TestLowerToArmNeon::runOnOperation() {
   MLIRContext *context = &getContext();
   RewritePatternSet patterns(context);
-  populateLowerVectorToArmNeonPatterns(patterns);
+  populateLowerContractionToSMMLAPatternPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }

>From 91b508afbf95d2403a83a8cf4d4b1ed5e89a92d5 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 8 Mar 2024 22:34:05 +0000
Subject: [PATCH 4/4] further cleanup

---
 .../LowerContractionToSMMLAPattern.cpp        | 33 ++++++++-----------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index fc9b84bbfda453..953bc59ed5dc97 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -36,8 +36,8 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from vector::contractOp directly to the arm neon
-/// intrinsic.
+/// Lowers a single vector::ContractionOp directly to the arm_neon smmla
+/// intrinsic. The shapes of the contract and intrinsic must match.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -49,7 +49,7 @@ class LowerContractionToSMMLAPattern
     Value rhs = op.getRhs();
     Value res = op.getAcc();
 
-    // Check index maps represent M N K and aren't transposed.
+    // Check each indexing map takes 3 dims (M, N, K) and yields 2 results.
     auto indexingMaps = op.getIndexingMapsArray();
     if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
           return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
@@ -58,7 +58,7 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Check iterator types for contract
+    // Check iterator types for contract.
     auto iteratorTypes = op.getIteratorTypesArray();
     if (iteratorTypes.size() != 3 ||
         iteratorTypes[0] != vector::IteratorType::parallel ||
@@ -67,10 +67,7 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Check the tile size by mapping the dimensions of the contract
-    //  -- Tile size: [2, 2, 8]
-    // Infer tile sizes from operands. Check required tile size
-    // Note: RHS is not transposed
+    // Check the operand shapes give the required [M, N, K] = [2, 2, 8] tile.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
     auto dimM = lhsType.getDimSize(0);
@@ -80,7 +77,7 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Check two extsi inputs Rhs Lhs
+    // Check that both the LHS and RHS of the contract are produced by extsi.
     arith::ExtSIOp origLhsExtOp =
         dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
     arith::ExtSIOp origRhsExtOp =
@@ -89,27 +86,25 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    arith::ExtSIOp extsiLhs;
-    arith::ExtSIOp extsiRhs;
     // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
     // following neon instruction. Check inputs for extsi are <=i8
+    Value extsiLhs;
+    Value extsiRhs;
     if (auto lhsExtInType =
             origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        // Target lhs type with i8. This is likely redundant
         Type targetLhsExtTy =
             matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
+        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                    origLhsExtOp.getIn());
       }
     }
     if (auto rhsExtInType =
             origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        // Target rhs type with i8
         Type targetRhsExtTy =
             matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        extsiRhs = rewriter.create<arith::ExtSIOp>(loc, targetRhsExtTy,
+        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                    origRhsExtOp.getIn());
       }
     }
@@ -123,15 +118,15 @@ class LowerContractionToSMMLAPattern
         {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
     auto collapsedOutputType =
         VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
+    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
         extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
+    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
         extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.create<vector::ShapeCastOp>(
+    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
         res.getLoc(), collapsedOutputType, res);
 
     // Replace the contract with a neon op
-    auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
+    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
         op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
         collapsedRhs);
 



More information about the Mlir-commits mailing list