[llvm] [mlir] [mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (PR #81895)

Diego Caballero via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 11:37:30 PST 2024


================
@@ -0,0 +1,153 @@
+//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops
+//-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-lower-vector"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+// Return the shaped type with new element type.
+static Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = dyn_cast<ShapedType>(container))
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+// Lowering from vector::contractOp directly to the arm neon
+// intrinsic.
+class LowerVectorToArmNeonPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // Check index maps represent M N K and aren't transposed.
+    auto indexingMaps = op.getIndexingMapsArray();
+    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
+          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
+                 affineMap.getNumResults() != 2;
+        })) {
+      return failure();
+    }
+
+    // Check iterator types for contract
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() != 3 ||
+        iteratorTypes[0] != vector::IteratorType::parallel ||
+        iteratorTypes[1] != vector::IteratorType::parallel ||
+        iteratorTypes[2] != vector::IteratorType::reduction) {
+      return failure();
+    }
+
+    // Check the tile size by mapping the dimensions of the contract
+    //  -- Tile size: [2, 2, 8]
+    // Infer tile sizes from operands. Check required tile size
+    // Note: RHS is not transposed
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+    auto dimM = lhsType.getDimSize(0);
+    auto dimN = rhsType.getDimSize(0);
+    auto dimK = lhsType.getDimSize(1);
+    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+      return failure();
+    }
+
+    // Check two extsi inputs Rhs Lhs
+    arith::ExtSIOp origLhsExtOp;
+    arith::ExtSIOp origRhsExtOp;
+    if (!(origLhsExtOp =
+              dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
+        !(origRhsExtOp =
+              dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
+      return failure();
+    }
+
+    arith::ExtSIOp extsiLhs;
+    arith::ExtSIOp extsiRhs;
+    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
+    // following neon instruction. Check inputs for extsi are <=i8
+    if (auto lhsExtInType =
+            origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
----------------
dcaballe wrote:

Let me check that I understood this correctly: if `lhsExtInType.getElementTypeBitWidth()` is `8`, `rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy, origLhsExtOp.getIn())` would create an i8 -> i8 extension, which is redundant.

If that is the case, you can just assign the input of `origLhsExtOp` (i.e. `origLhsExtOp.getIn()`) to `extsiLhs` when the element type is already i8.

So perhaps you could do something like:
```
    arith::ExtSIOp extsiLhs = origLhsExtOp.getIn(); // Perhaps we want to change the name of `extsiLhs` since it may not be an extsi op for some cases.
      if (lhsExtInType.getElementTypeBitWidth() < 8) {  // Note that I removed the `=`.
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                   origLhsExtOp.getIn());
      }
```

Sorry if I got this wrong! :)

https://github.com/llvm/llvm-project/pull/81895


More information about the llvm-commits mailing list