[Mlir-commits] [mlir] 20daeda - 2d Arm Neon sdot op, and lowering to the intrinsic.

Ahmed Taei llvmlistbot at llvm.org
Thu Jun 10 14:36:52 PDT 2021


Author: Benoit Jacob
Date: 2021-06-10T14:36:39-07:00
New Revision: 20daedacca803b81db6d8773b705345702bf0fc3

URL: https://github.com/llvm/llvm-project/commit/20daedacca803b81db6d8773b705345702bf0fc3
DIFF: https://github.com/llvm/llvm-project/commit/20daedacca803b81db6d8773b705345702bf0fc3.diff

LOG: 2d Arm Neon sdot op, and lowering to the intrinsic.

This adds the Sdot2d op, which is similar to the usual Neon
intrinsic except that it takes 2d vector operands, reflecting the
structure of the arithmetic that it's performing: 4 (or 2) separate
4-dimensional dot products, whence the vector<4x4xi8> (or vector<2x4xi8>) shape.

This also adds a new pass, arm-neon-2d-to-intr, lowering
this new 2d op to the 1d intrinsic.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D102504

Added: 
    mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
    mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
    mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
    mlir/test/Dialect/ArmNeon/invalid.mlir
    mlir/test/Target/LLVMIR/arm-neon-2d.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
new file mode 100644
index 0000000000000..a27e91f96e8d9
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
@@ -0,0 +1,31 @@
+//===- ArmNeon2dToIntr.h - convert Arm Neon 2d ops to intrinsics ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+#define MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+class RewritePatternSet;
+
+/// Adds to `patterns` the patterns lowering Arm NEON 2D ops to intrinsics.
+/// See createConvertArmNeon2dToIntrPass.
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
+/// equivalent ops operating on flattened 1D vectors and mapping more
+/// directly to the corresponding Arm NEON instruction.
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index d75e617a902c6..a78b72894c49c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_PASSES_H
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 508a0084015fc..ba5e27a2a87c6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -607,4 +607,19 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmNeon2dToIntr
+//===----------------------------------------------------------------------===//
+
+def ConvertArmNeon2dToIntr : Pass<"arm-neon-2d-to-intr", "FuncOp"> {
+  let summary = "Convert Arm NEON structured ops to intrinsics";
+  let description = [{
+    Lowers Arm NEON 2d ops, such as `arm_neon.2d.sdot`, to the corresponding
+    1d intrinsic ops, such as `arm_neon.intr.sdot`, by flattening their 2d
+    vector operands with `vector.shape_cast`.
+  }];
+  let constructor = "mlir::createConvertArmNeon2dToIntrPass()";
+  let dependentDialects = ["arm_neon::ArmNeonDialect", "vector::VectorDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index a9d9f6f539dd4..b530c62f902ba 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
 // ArmNeon dialect definition
@@ -117,4 +118,64 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
   }
 
+// Base class for Arm NEON ops taking 2d vector operands. Their mnemonics are
+// prefixed with "2d.", e.g. `arm_neon.2d.sdot`; they are intended to be
+// lowered to the corresponding 1d intrinsic ops by arm-neon-2d-to-intr.
+class ArmNeon_2dOp<string mnemonic, list<OpTrait> traits = []>
+    : Op</*dialect=*/ArmNeon_Dialect,
+         /*opName=*/"2d." # mnemonic,
+         /*traits=*/traits>;
+
+def Sdot2dOp : ArmNeon_2dOp<"sdot", [
+      NoSideEffect,
+      AllTypesMatch<["b", "c"]>,
+      AllTypesMatch<["a", "res"]>,
+      PredOpTrait<
+        "operand `a` should be 1-dimensional",
+        CPred<"a().getType().cast<VectorType>().getShape().size() == 1">
+      >,
+      PredOpTrait<
+        "operand `b` should be 2-dimensional",
+        CPred<"b().getType().cast<VectorType>().getShape().size() == 2">
+      >,
+      PredOpTrait<
+        "operand `b` should have 4 columns",
+        CPred<"b().getType().cast<VectorType>().getShape()[1] == 4">
+      >,
+      PredOpTrait<
+        "operand `b` should have as many rows as the size of operand `a`",
+        CPred<"b().getType().cast<VectorType>().getShape()[0] == a().getType().cast<VectorType>().getShape()[0]">
+      >,
+      ]
+  > {
+  let summary = "sdot op";
+  let description = [{
+    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
+    or 4 rows, each row having length 4. This operation computes the pair-wise
+    dot-products of the rows of `b` and `c` and accumulates them with the
+    corresponding entry of `a`:
+
+    ```
+    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
+    ```
+
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
+  // The type constraints below only check element type and total length; the
+  // 2D shape requirements are enforced by the PredOpTrait constraints above.
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+  let extraClassDeclaration = [{
+    // Number of elements reduced by each dot product, i.e. the number of
+    // columns of `b` and `c`.
+    static constexpr int kReductionSize = 4;
+  }];
+}
+
 #endif // ARMNEON_OPS

diff  --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
new file mode 100644
index 0000000000000..6d7e7c102b042
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -0,0 +1,78 @@
+//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+/// Lowers arm_neon.2d.sdot to arm_neon.intr.sdot.
+class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Flattens the 2d `b` and `c` operands to 1d vectors with
+  /// vector.shape_cast, as required by arm_neon.intr.sdot, and replaces `op`
+  /// with that intrinsic op.
+  LogicalResult matchAndRewrite(Sdot2dOp op,
+                                PatternRewriter &rewriter) const override {
+    // `b` and `c` have the same type (AllTypesMatch), so a single flattened
+    // type serves both. The verifier guarantees `kReductionSize` columns, so
+    // the element count equals rows * kReductionSize.
+    auto vecType = op.b().getType().cast<VectorType>();
+    VectorType flattenedVectorType = VectorType::get(
+        {vecType.getNumElements()}, vecType.getElementType());
+    Location loc = op.getLoc();
+    Value b1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, op.b());
+    Value c1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, op.c());
+    rewriter.replaceOpWithNewOp<SdotOp>(op, op.res().getType(), op.a(), b1d,
+                                        c1d);
+    return success();
+  }
+};
+
+/// Pass applying the Arm NEON 2d-to-intrinsic lowering patterns greedily.
+class ConvertArmNeon2dToIntr
+    : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateConvertArmNeon2dToIntrPatterns(patterns);
+
+    // Propagate a failure of the greedy rewrite driver as a pass failure.
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
+  patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass() {
+  return std::make_unique<ConvertArmNeon2dToIntr>();
+}
+
+} // namespace mlir

diff  --git a/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
new file mode 100644
index 0000000000000..5c729c86373a3
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Library implementing the arm-neon-2d-to-intr conversion pass.
+add_mlir_conversion_library(MLIRArmNeon2dToIntr
+  ArmNeon2dToIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeon2dToIntr
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRPass
+  MLIRTransforms
+  MLIRIR
+  MLIRVector
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b89be2a569417..72cfb08405ace 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 5287c6c1490b5..993cac8dad9dd 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -80,6 +80,10 @@ namespace vector {
 class VectorDialect;
 } // end namespace vector
 
+namespace arm_neon {
+class ArmNeonDialect;
+} // end namespace arm_neon
+
 #define GEN_PASS_CLASSES
 #include "mlir/Conversion/Passes.h.inc"
 

diff  --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
new file mode 100644
index 0000000000000..460cdc019aefa
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func @a_is_2d(%a : vector<2x2xi32>, %b : vector<4x4xi8>) -> vector<2x2xi32> {
+    // expected-error @+1 {{operand `a` should be 1-dimensional}}
+    %0 = arm_neon.2d.sdot %a, %b, %b : vector<4x4xi8>, vector<4x4xi8> to vector<2x2xi32>
+    return %0 : vector<2x2xi32>
+}
+
+// -----
+
+func @b_is_3d(%a : vector<4xi32>, %b : vector<1x4x4xi8>) -> vector<4xi32> {
+    // expected-error @+1 {{operand `b` should be 2-dimensional}}
+    %0 = arm_neon.2d.sdot %a, %b, %b : vector<1x4x4xi8>, vector<1x4x4xi8> to vector<4xi32>
+    return %0 : vector<4xi32>
+}
+
+// -----
+
+func @b_has_2_columns(%a : vector<4xi32>, %b : vector<4x2xi8>) -> vector<4xi32> {
+    // expected-error @+1 {{operand `b` should have 4 columns}}
+    %0 = arm_neon.2d.sdot %a, %b, %b : vector<4x2xi8>, vector<4x2xi8> to vector<4xi32>
+    return %0 : vector<4xi32>
+}
+
+// -----
+
+func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi8>) -> vector<4xi32> {
+    // expected-error @+1 {{operand `b` should have as many rows as the size of operand `a`}}
+    %0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
+    return %0 : vector<4xi32>
+}
+
+// -----
+
+func @b_and_c_types_differ(%a : vector<4xi32>, %b : vector<4x4xi8>, %c : vector<2x4xi8>) -> vector<4xi32> {
+    // expected-error @+1 {{failed to verify that all of {b, c} have same type}}
+    %0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<2x4xi8> to vector<4xi32>
+    return %0 : vector<4xi32>
+}

diff  --git a/mlir/test/Target/LLVMIR/arm-neon-2d.mlir b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
new file mode 100644
index 0000000000000..b75afdc05af4b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -arm-neon-2d-to-intr %s | FileCheck %s
+
+// CHECK-LABEL: arm_neon_sdot2d_4x4_i8i8
+func @arm_neon_sdot2d_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> {
+  // CHECK: %[[B1D:.*]] = vector.shape_cast %{{.*}} : vector<4x4xi8> to vector<16xi8>
+  // CHECK: %[[C1D:.*]] = vector.shape_cast %{{.*}} : vector<4x4xi8> to vector<16xi8>
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %[[B1D]], %[[C1D]] : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<4xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot2d_2x4_i8i8
+func @arm_neon_sdot2d_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> {
+  // CHECK: %[[B1D:.*]] = vector.shape_cast %{{.*}} : vector<2x4xi8> to vector<8xi8>
+  // CHECK: %[[C1D:.*]] = vector.shape_cast %{{.*}} : vector<2x4xi8> to vector<8xi8>
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %[[B1D]], %[[C1D]] : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<2xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}


        


More information about the Mlir-commits mailing list