[Mlir-commits] [llvm] [mlir] [mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (PR #81895)
Diego Caballero
llvmlistbot at llvm.org
Tue Mar 5 17:21:19 PST 2024
================
@@ -0,0 +1,153 @@
+//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-lower-vector"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+// Builds the type that carries `element` in the same container as `container`.
+// If `container` is a ShapedType, the result is that shaped type cloned with
+// `element` as its element type; otherwise `element` is returned unchanged.
+static Type matchContainerType(Type element, Type container) {
+  auto shaped = dyn_cast<ShapedType>(container);
+  // Not a shaped container: there is nothing to wrap the element in.
+  if (!shaped)
+    return element;
+
+  return shaped.clone(element);
+}
+
+// Lowers a vector::ContractionOp directly to the Arm Neon SMMLA
+// intrinsic.
+class LowerVectorToArmNeonPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+ Value res = op.getAcc();
+
+ // Check index maps represent M N K and aren't transposed.
+ auto indexingMaps = op.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
+ return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
+ affineMap.getNumResults() != 2;
+ })) {
+ return failure();
+ }
+
+ // Check iterator types for contract
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() != 3 ||
+ iteratorTypes[0] != vector::IteratorType::parallel ||
+ iteratorTypes[1] != vector::IteratorType::parallel ||
+ iteratorTypes[2] != vector::IteratorType::reduction) {
+ return failure();
+ }
+
+ // Check the tile size by mapping the dimensions of the contract
+ // -- Tile size: [2, 2, 8]
+ // Infer tile sizes from operands. Check required tile size
+ // Note: RHS is not transposed
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+ auto dimM = lhsType.getDimSize(0);
+ auto dimN = rhsType.getDimSize(0);
+ auto dimK = lhsType.getDimSize(1);
+ if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+ return failure();
+ }
+
+ // Check that both the LHS and RHS operands are defined by arith.extsi ops.
+ arith::ExtSIOp origLhsExtOp;
+ arith::ExtSIOp origRhsExtOp;
+ if (!(origLhsExtOp =
+ dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
+ !(origRhsExtOp =
+ dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
----------------
dcaballe wrote:
It would probably be good to move the assignments outside the `if` condition. We generally follow that pattern when there are multiple `dyn_cast`s.
https://github.com/llvm/llvm-project/pull/81895
More information about the Mlir-commits
mailing list