[Mlir-commits] [mlir] f596394 - Add arm_neon.sdot operation

Ahmed Taei llvmlistbot at llvm.org
Wed Mar 17 08:26:12 PDT 2021


Author: Ahmed Taei
Date: 2021-03-17T08:24:58-07:00
New Revision: f5963944d97d40300eeec8b43ae67aea2115398c

URL: https://github.com/llvm/llvm-project/commit/f5963944d97d40300eeec8b43ae67aea2115398c
DIFF: https://github.com/llvm/llvm-project/commit/f5963944d97d40300eeec8b43ae67aea2115398c.diff

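LOG: Add arm_neon.sdot operation (this also renames the ArmNeon intrinsic ops to `arm_neon.intr.*`, e.g. `arm_neon.smull` becomes `arm_neon.intr.smull`)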

Differential Revision: https://reviews.llvm.org/D98198

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
    mlir/test/Dialect/ArmNeon/roundtrip.mlir
    mlir/test/Target/LLVMIR/arm-neon.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index b87c4b1cc12f..d57337bc8253 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -39,7 +39,7 @@ class ArmNeon_IntrOp<string mnemonic, list<int> overloadedResults,
                       list<int> overloadedOperands, int numResults,
                       list<OpTrait> traits = [], bit requiresAccessGroup = 0>
     : LLVM_IntrOpBase</*dialect=*/ArmNeon_Dialect,
-                      /*opName=*/mnemonic,
+                      /*opName=*/"intr." # mnemonic,
                       /*enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
                       /*overloadedResults=*/overloadedResults,
                       /*overloadedOperands=*/overloadedOperands,
@@ -53,6 +53,13 @@ class ArmNeon_OverloadedOneResultIntrOp<string mnemonic,
                                          list<OpTrait> traits = []>
   : ArmNeon_IntrOp<mnemonic, [0], [], 1, traits>;
 
+// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with one
+// overloaded result and a list of overloaded operand positions.
+class ArmNeon_OverloadedOperandsWithOneResultIntrOp<string mnemonic,
+                                                    list<int> overloadedOperands,
+                                                    list<OpTrait> traits = []>
+  : ArmNeon_IntrOp<mnemonic, [0], overloadedOperands, 1, traits>;
+
 def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
        NoSideEffect,
        AllTypesMatch<["a", "b"]>,
@@ -82,5 +89,32 @@ def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
     "$a `,` $b attr-dict `:` type($a) `to` type($res)";
 }
 
+def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot", [1], [
+      NoSideEffect,
+      AllTypesMatch<["b", "c"]>,
+      AllTypesMatch<["a", "res"]>,
+      TypesMatchWith<"res has one quarter as many elements as operand b",
+                     "b", "res",
+                     "VectorType::get({$_self.cast<VectorType>().getShape()[0] / 4},"
+                     "IntegerType::get($_self.getContext(), 32))">]> {
+  let summary = "sdot op";
+  let description = [{
+    Signed integer dot product with accumulation (vector). For each i32 lane i:
+    res[i] = a[i] + sum_{k=0..3} b[4*i+k] * c[4*i+k], i.e. b and c are split
+    into groups of four signed i8 elements, one group per result lane.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<8xi8>, vector<8xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+}
 
 #endif // ARMNEON_OPS

diff  --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 014da313a089..2252d2c857fe 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -3,18 +3,32 @@
 // CHECK-LABEL: arm_neon_smull
 func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
     -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
-  // CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16>
-  %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
+  // CHECK: arm_neon.intr.smull {{.*}}: vector<8xi8> to vector<8xi16>
+  %0 = arm_neon.intr.smull %a, %b : vector<8xi8> to vector<8xi16>
   %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
     vector<8xi16> to vector<4xi16>
 
-  // CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32>
-  %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
+  // CHECK: arm_neon.intr.smull {{.*}}: vector<4xi16> to vector<4xi32>
+  %1 = arm_neon.intr.smull %00, %00 : vector<4xi16> to vector<4xi32>
   %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
     vector<4xi32> to vector<2xi32>
 
-  // CHECK: arm_neon.smull {{.*}}: vector<2xi32> to vector<2xi64>
-  %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
+  // CHECK: arm_neon.intr.smull {{.*}}: vector<2xi32> to vector<2xi64>
+  %2 = arm_neon.intr.smull %11, %11 : vector<2xi32> to vector<2xi64>
 
   return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
 }
+
+// CHECK-LABEL: arm_neon_sdot
+func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
+  // CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
+  %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot_v4i32
+func @arm_neon_sdot_v4i32(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.sdot {{.*}}: vector<16xi8>, vector<16xi8> to vector<4xi32>
+  %0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}

diff  --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index 3cd7641c5798..d99f573c8bec 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -4,16 +4,16 @@
 llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
   //      CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}})
   // CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %3, <8 x i16> %[[V0]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
-  %0 = arm_neon.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16>
+  %0 = arm_neon.intr.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16>
   %1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : vector<8xi16>, vector<8xi16>
 
   // CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]])
   // CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> <i32 1, i32 2>
-  %2 = arm_neon.smull %1, %1 : vector<4xi16> to vector<4xi32>
+  %2 = arm_neon.intr.smull %1, %1 : vector<4xi16> to vector<4xi32>
   %3 = llvm.shufflevector %2, %2 [1, 2] : vector<4xi32>, vector<4xi32>
 
   // CHECK-NEXT: %[[V1:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]])
-  %4 = arm_neon.smull %3, %3 : vector<2xi32> to vector<2xi64>
+  %4 = arm_neon.intr.smull %3, %3 : vector<2xi32> to vector<2xi64>
 
   %5 = llvm.mlir.undef : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
   %6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
@@ -23,3 +23,19 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
   //      CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> }
   llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
 }
+
+// CHECK-LABEL: arm_neon_sdot_v2i32_v8i8
+llvm.func @arm_neon_sdot_v2i32_v8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
+  // CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
+  // CHECK-NEXT: ret <2 x i32>
+  %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  llvm.return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot_v4i32_v16i8
+llvm.func @arm_neon_sdot_v4i32_v16i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
+  // CHECK-NEXT: ret <4 x i32>
+  %0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}


        


More information about the Mlir-commits mailing list