[Mlir-commits] [mlir] 001d601 - [mlir][ArmSVE] Add basic arithmetic operations

Alex Zinenko llvmlistbot at llvm.org
Wed May 5 00:50:30 PDT 2021


Author: Javier Setoain
Date: 2021-05-05T09:50:18+02:00
New Revision: 001d601ac4fb1ee02d4bb3990f2f5a8afacd4932

URL: https://github.com/llvm/llvm-project/commit/001d601ac4fb1ee02d4bb3990f2f5a8afacd4932
DIFF: https://github.com/llvm/llvm-project/commit/001d601ac4fb1ee02d4bb3990f2f5a8afacd4932.diff

LOG: [mlir][ArmSVE] Add basic arithmetic operations

While we figure out how to best add Standard support for scalable
vectors, these instructions provide a workaround for basic arithmetic
between scalable vectors.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100837

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Target/LLVMIR/arm-sve.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 4e75a5628fec..33c60ba7c8a5 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -122,6 +122,42 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                   /*list<OpTrait> traits=*/traits,
                   /*int numResults=*/1>;
 
+class ScalableFOp<string mnemonic, string op_description,
+                  list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
+  let summary = op_description # " for scalable vectors of floats";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+    returns one scalable vector with the result of the }] # op_description # [{.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+class ScalableIOp<string mnemonic, string op_description,
+                  list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
+  let summary = op_description # " for scalable vectors of integers";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+    returns one scalable vector with the result of the }] # op_description # [{.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
 def SdotOp : ArmSVE_Op<"sdot",
                [NoSideEffect,
                AllTypesMatch<["src1", "src2"]>,
@@ -266,6 +302,25 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
     "attr-dict `:` type($res)";
 }
 
+
+def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
+
+def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
+
+def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
+
+def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
+
+def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
+
+def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
+
+def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
+
+def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
+
+def ScalableDivFOp : ScalableFOp<"divf", "division">;
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;

diff  --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index b0197cbe8a9f..b258f2ad9315 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -84,6 +84,38 @@ using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
 
+static void
+populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
+               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
+               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
+              >(converter);
+  // clang-format on
+}
+
+static void
+configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
+  // clang-format off
+  target.addIllegalOp<ScalableAddIOp,
+                      ScalableAddFOp,
+                      ScalableSubIOp,
+                      ScalableSubFOp,
+                      ScalableMulIOp,
+                      ScalableMulFOp,
+                      ScalableSDivIOp,
+                      ScalableUDivIOp,
+                      ScalableDivFOp>();
+  // clang-format on
+}
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -106,20 +138,14 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                UmmlaOpLowering,
                VectorScaleOpLowering>(converter);
   // clang-format on
+  populateBasicSVEArithmeticExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<SdotIntrOp>();
-  target.addIllegalOp<SdotOp>();
-  target.addLegalOp<SmmlaIntrOp>();
-  target.addIllegalOp<SmmlaOp>();
-  target.addLegalOp<UdotIntrOp>();
-  target.addIllegalOp<UdotOp>();
-  target.addLegalOp<UmmlaIntrOp>();
-  target.addIllegalOp<UmmlaOp>();
-  target.addLegalOp<VectorScaleIntrOp>();
-  target.addIllegalOp<VectorScaleOp>();
+  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
+                    VectorScaleIntrOp>();
+  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)
       if (type.isa<arm_sve::ScalableVectorType>())
@@ -135,4 +161,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
         return !hasScalableVectorType(op->getOperandTypes()) &&
                !hasScalableVectorType(op->getResultTypes());
       });
+  configureBasicSVEArithmeticLegalizations(target);
 }

diff  --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 247b53e9c6d0..f81196f4928f 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -40,6 +40,40 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>,
+                     %d: !arm_sve.vector<4xi32>,
+                     %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
+  %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
+  %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
+  // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
+  %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
+  return %3 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>,
+                     %d: !arm_sve.vector<4xf32>,
+                     %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
+  %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
+  %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index

diff  --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 8834ef87207b..44cc2fa12217 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -36,6 +36,26 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  return %1 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  return %1 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index

diff  --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 46fc8845e204..71d4b0aee9b4 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,30 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
   llvm.return %0 : !llvm.vec<?x4 x i32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
+llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
+                          %arg1: !llvm.vec<? x 4 x i32>,
+                          %arg2: !llvm.vec<? x 4 x i32>)
+                          -> !llvm.vec<? x 4 x i32> {
+  // CHECK: mul <vscale x 4 x i32>
+  %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  // CHECK: add <vscale x 4 x i32>
+  %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
+  llvm.return %1 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
+llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
+                          %arg1: !llvm.vec<? x 4 x f32>,
+                          %arg2: !llvm.vec<? x 4 x f32>)
+                          -> !llvm.vec<? x 4 x f32> {
+  // CHECK: fmul <vscale x 4 x float>
+  %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+  // CHECK: fadd <vscale x 4 x float>
+  %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
+  llvm.return %1 : !llvm.vec<? x 4 x f32>
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()


        


More information about the Mlir-commits mailing list