[Mlir-commits] [mlir] 001d601 - [mlir][ArmSVE] Add basic arithmetic operations
Alex Zinenko
llvmlistbot at llvm.org
Wed May 5 00:50:30 PDT 2021
Author: Javier Setoain
Date: 2021-05-05T09:50:18+02:00
New Revision: 001d601ac4fb1ee02d4bb3990f2f5a8afacd4932
URL: https://github.com/llvm/llvm-project/commit/001d601ac4fb1ee02d4bb3990f2f5a8afacd4932
DIFF: https://github.com/llvm/llvm-project/commit/001d601ac4fb1ee02d4bb3990f2f5a8afacd4932.diff
LOG: [mlir][ArmSVE] Add basic arithmetic operations
While we figure out how best to add Standard-dialect support for scalable
vectors, these operations provide a workaround for basic arithmetic
between scalable vectors.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D100837
Added:
Modified:
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Target/LLVMIR/arm-sve.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 4e75a5628fec..33c60ba7c8a5 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -122,6 +122,42 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
+// Base class for binary ops whose two operands and result share one scalable vector-of-float type.
+class ScalableFOp<string mnemonic, string op_description,
+ list<OpTrait> traits = []> :
+ ArmSVE_Op<mnemonic, !listconcat(traits,
+ [AllTypesMatch<["src1", "src2", "dst"]>])> {
+ let summary = op_description # " for scalable vectors of floats";
+ let description = [{
+ The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+ returns one scalable vector with the result of the }] # op_description # [{.
+ }];
+ let arguments = (ins
+ ScalableVectorOf<[AnyFloat]>:$src1,
+ ScalableVectorOf<[AnyFloat]>:$src2
+ );
+ let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+ let assemblyFormat = "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+class ScalableIOp<string mnemonic, string op_description,
+ list<OpTrait> traits = []> :
+ ArmSVE_Op<mnemonic, !listconcat(traits,
+ [AllTypesMatch<["src1", "src2", "dst"]>])> {
+ // Binary integer op: both operands and the result have the same type.
+ let arguments = (ins
+ ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+ ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+ );
+ let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+ let assemblyFormat = "$src1 `,` $src2 attr-dict `:` type($src1)";
+ let summary = op_description # " for scalable vectors of integers";
+ let description = [{
+ The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+ returns one scalable vector with the result of the }] # op_description # [{.
+ }];
+}
+
def SdotOp : ArmSVE_Op<"sdot",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
@@ -266,6 +302,25 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
"attr-dict `:` type($res)";
}
+// Basic arithmetic on scalable vectors; only add and mul are marked Commutative.
+def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
+
+def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
+
+def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
+
+def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
+
+def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
+
+def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
+// Integer division has separate signed and unsigned ops; floats have one.
+def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
+
+def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
+
+def ScalableDivFOp : ScalableFOp<"divf", "division">;
+
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index b0197cbe8a9f..b258f2ad9315 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -84,6 +84,38 @@ using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
+/// Shorthand for the one-to-one ArmSVE -> LLVM dialect rewrite used by the
+/// basic arithmetic lowerings.
+template <typename SourceOp, typename TargetOp>
+using SVEArithOpLowering = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
+
+/// Appends to `patterns` one rewrite per basic ArmSVE arithmetic op; each one
+/// replaces the op with the corresponding LLVM dialect op (for example
+/// `arm_sve.divi_signed` -> `llvm.sdiv`). `converter` is handed to every
+/// pattern.
+static void
+populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns) {
+ patterns.add<SVEArithOpLowering<ScalableAddIOp, LLVM::AddOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableAddFOp, LLVM::FAddOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableSubIOp, LLVM::SubOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableSubFOp, LLVM::FSubOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableMulIOp, LLVM::MulOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableMulFOp, LLVM::FMulOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableSDivIOp, LLVM::SDivOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableUDivIOp, LLVM::UDivOp>>(converter);
+ patterns.add<SVEArithOpLowering<ScalableDivFOp, LLVM::FDivOp>>(converter);
+}
+
+/// Marks every basic ArmSVE arithmetic op as illegal on `target`, i.e. each
+/// of them has to be rewritten (into the LLVM dialect) during conversion.
+static void
+configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
+ // Integer variants.
+ target.addIllegalOp<ScalableAddIOp, ScalableSubIOp, ScalableMulIOp,
+ ScalableSDivIOp, ScalableUDivIOp>();
+ // Floating-point variants.
+ target.addIllegalOp<ScalableAddFOp, ScalableSubFOp, ScalableMulFOp,
+ ScalableDivFOp>();
+}
+
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -106,20 +138,14 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
UmmlaOpLowering,
VectorScaleOpLowering>(converter);
// clang-format on
+ populateBasicSVEArithmeticExportPatterns(converter, patterns);
}
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addLegalOp<SdotIntrOp>();
- target.addIllegalOp<SdotOp>();
- target.addLegalOp<SmmlaIntrOp>();
- target.addIllegalOp<SmmlaOp>();
- target.addLegalOp<UdotIntrOp>();
- target.addIllegalOp<UdotOp>();
- target.addLegalOp<UmmlaIntrOp>();
- target.addIllegalOp<UmmlaOp>();
- target.addLegalOp<VectorScaleIntrOp>();
- target.addIllegalOp<VectorScaleOp>();
+ target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
+ VectorScaleIntrOp>();
+ target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
@@ -135,4 +161,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
+ configureBasicSVEArithmeticLegalizations(target);
}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 247b53e9c6d0..f81196f4928f 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -40,6 +40,40 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
return %0 : !arm_sve.vector<4xi32>
}
+// Each integer arithmetic op is expected to lower to its LLVM dialect
+// counterpart on the converted scalable vector type.
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+ %b: !arm_sve.vector<4xi32>,
+ %c: !arm_sve.vector<4xi32>,
+ %d: !arm_sve.vector<4xi32>,
+ %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
+ %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+ // CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
+ %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+ // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
+ %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
+ // CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
+ %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
+ // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
+ %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
+ // NOTE: %4 is not returned; this op appears to exist only to exercise the udiv lowering.
+ return %3 : !arm_sve.vector<4xi32>
+}
+
+// Each floating-point arithmetic op is expected to lower to its LLVM dialect
+// counterpart on the converted scalable vector type.
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+ %b: !arm_sve.vector<4xf32>,
+ %c: !arm_sve.vector<4xf32>,
+ %d: !arm_sve.vector<4xf32>,
+ %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+ // CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
+ %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+ // CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
+ %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+ // CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
+ %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
+ // CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
+ %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
+ return %3 : !arm_sve.vector<4xf32>
+}
+
func @get_vector_scale() -> index {
// CHECK: arm_sve.vscale
%0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 8834ef87207b..44cc2fa12217 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -36,6 +36,26 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
return %0 : !arm_sve.vector<4xi32>
}
+// Round-trips every integer arithmetic op; results are chained so none is dead.
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+ %b: !arm_sve.vector<4xi32>,
+ %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
+ %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+ // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
+ %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+ // CHECK: arm_sve.subi {{.*}}: !arm_sve.vector<4xi32>
+ %2 = arm_sve.subi %1, %a : !arm_sve.vector<4xi32>
+ // CHECK: arm_sve.divi_signed {{.*}}: !arm_sve.vector<4xi32>
+ %3 = arm_sve.divi_signed %2, %b : !arm_sve.vector<4xi32>
+ // CHECK: arm_sve.divi_unsigned {{.*}}: !arm_sve.vector<4xi32>
+ %4 = arm_sve.divi_unsigned %3, %c : !arm_sve.vector<4xi32>
+ return %4 : !arm_sve.vector<4xi32>
+}
+
+// Round-trips every floating-point arithmetic op; results are chained so none is dead.
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+ %b: !arm_sve.vector<4xf32>,
+ %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+ // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
+ %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+ // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
+ %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+ // CHECK: arm_sve.subf {{.*}}: !arm_sve.vector<4xf32>
+ %2 = arm_sve.subf %1, %a : !arm_sve.vector<4xf32>
+ // CHECK: arm_sve.divf {{.*}}: !arm_sve.vector<4xf32>
+ %3 = arm_sve.divf %2, %b : !arm_sve.vector<4xf32>
+ return %3 : !arm_sve.vector<4xf32>
+}
+
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 46fc8845e204..71d4b0aee9b4 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,30 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
llvm.return %0 : !llvm.vec<?x4 x i32>
}
+// LLVM dialect integer arithmetic on scalable vectors is expected to translate
+// to LLVM IR instructions on <vscale x 4 x i32>.
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
+llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
+ %arg1: !llvm.vec<? x 4 x i32>,
+ %arg2: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: mul <vscale x 4 x i32>
+ %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+ // CHECK: add <vscale x 4 x i32>
+ %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
+ llvm.return %1 : !llvm.vec<? x 4 x i32>
+}
+
+// LLVM dialect floating-point arithmetic on scalable vectors is expected to
+// translate to LLVM IR instructions on <vscale x 4 x float>.
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
+llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
+ %arg1: !llvm.vec<? x 4 x f32>,
+ %arg2: !llvm.vec<? x 4 x f32>)
+ -> !llvm.vec<? x 4 x f32> {
+ // CHECK: fmul <vscale x 4 x float>
+ %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+ // CHECK: fadd <vscale x 4 x float>
+ %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
+ llvm.return %1 : !llvm.vec<? x 4 x f32>
+}
+
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()
More information about the Mlir-commits mailing list