[Mlir-commits] [mlir] 519f591 - [mlir] Add fma operation to std dialect
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Feb 17 10:06:10 PST 2021
Author: Eugene Zhulenev
Date: 2021-02-17T10:06:01-08:00
New Revision: 519f5917b458e51d4d12e034490d1a6f42d72f77
URL: https://github.com/llvm/llvm-project/commit/519f5917b458e51d4d12e034490d1a6f42d72f77
DIFF: https://github.com/llvm/llvm-project/commit/519f5917b458e51d4d12e034490d1a6f42d72f77.diff
LOG: [mlir] Add fma operation to std dialect
Will remove `vector.fma` operation in the followup CLs.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D96801
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 4e6ff2e359c0..29863c82c502 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -103,7 +103,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
-// types. Individual classes will have `lhs` and `rhs` accessor to operands.
+// types.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect,
@@ -122,6 +122,32 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
}];
}
+// Base class for standard binary arithmetic operations (two operands).
+class ArithmeticBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+ ArithmeticOp<mnemonic, traits> {
+
+ let parser = [{
+ return impl::parseOneResultSameOperandTypeOp(parser, result);
+ }]; // Shared with ArithmeticTernaryOp; one type for all operands.
+
+ let printer = [{
+ return printStandardBinaryOp(this->getOperation(), p);
+ }]; // Defined in lib/Dialect/StandardOps/IR/Ops.cpp.
+}
+
+// Base class for standard ternary arithmetic operations (three operands).
+class ArithmeticTernaryOp<string mnemonic, list<OpTrait> traits = []> :
+ ArithmeticOp<mnemonic, traits> {
+
+ let parser = [{
+ return impl::parseOneResultSameOperandTypeOp(parser, result);
+ }]; // Same generic parser as ArithmeticBinaryOp.
+
+ let printer = [{
+ return printStandardTernaryOp(this->getOperation(), p);
+ }]; // Defined in lib/Dialect/StandardOps/IR/Ops.cpp.
+}
+
// Base class for standard arithmetic operations on integers, vectors and
// tensors thereof. This operation takes two operands and returns one result,
// each of these is required to be of the same type. This type may be an
@@ -130,8 +156,8 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
//
// <op>i %0, %1 : i32
//
-class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- ArithmeticOp<mnemonic,
+class IntBinaryOp<string mnemonic, list<OpTrait> traits = []> : // Integer binary op ($lhs, $rhs).
+ ArithmeticBinaryOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -145,12 +171,27 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
//
// <op>f %0, %1 : f32
//
-class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- ArithmeticOp<mnemonic,
+class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> : // Float binary op ($lhs, $rhs).
+ ArithmeticBinaryOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
+// Base class for standard arithmetic ternary operations on floats, vectors and
+// tensors thereof. The operation has three operands ($a, $b, $c) and returns
+// one result, all of which must have the same type. This type may be a
+// floating point scalar type, a vector whose element type is a floating point
+// type, or a floating point tensor. The custom assembly form of the operation
+// is as follows:
+//
+// <op> %0, %1, %2 : f32
+//
+class FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
+ ArithmeticTernaryOp<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+ Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>;
+
// Base class for memref allocating ops: alloca and alloc.
//
// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -257,7 +298,7 @@ def AbsFOp : FloatUnaryOp<"absf"> {
// AddFOp
//===----------------------------------------------------------------------===//
-def AddFOp : FloatArithmeticOp<"addf"> {
+def AddFOp : FloatBinaryOp<"addf"> {
let summary = "floating point addition operation";
let description = [{
Syntax:
@@ -294,7 +335,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
// AddIOp
//===----------------------------------------------------------------------===//
-def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
+def AddIOp : IntBinaryOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Syntax:
@@ -418,7 +459,7 @@ def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
// AndOp
//===----------------------------------------------------------------------===//
-def AndOp : IntArithmeticOp<"and", [Commutative]> {
+def AndOp : IntBinaryOp<"and", [Commutative]> {
let summary = "integer binary and";
let description = [{
Syntax:
@@ -1269,7 +1310,7 @@ def ConstantOp : Std_Op<"constant",
// CopySignOp
//===----------------------------------------------------------------------===//
-def CopySignOp : FloatArithmeticOp<"copysign"> {
+def CopySignOp : FloatBinaryOp<"copysign"> {
let summary = "A copysign operation";
let description = [{
Syntax:
@@ -1384,11 +1425,49 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
// DivFOp
//===----------------------------------------------------------------------===//
-def DivFOp : FloatArithmeticOp<"divf"> {
+def DivFOp : FloatBinaryOp<"divf"> {
let summary = "floating point division operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// FmaFOp
+//===----------------------------------------------------------------------===//
+
+def FmaFOp : FloatTernaryOp<"fmaf"> {
+ let summary = "floating point fused multiply-add operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type
+ ```
+
+ The `fmaf` operation takes three operands and returns one result, each of
+ these is required to be of the same type. This type may be a floating point
+ scalar type, a vector whose element type is a floating point type, or a
+ floating point tensor.
+
+ Example:
+
+ ```mlir
+ // Scalar fused multiply-add: d = a*b + c, with a single rounding.
+ %d = fmaf %a, %b, %c : f64
+
+ // SIMD vector fused multiply-add, e.g. for x86 FMA3.
+ %i = fmaf %f, %g, %h : vector<4xf32>
+
+ // Tensor fused multiply-add.
+ %w = fmaf %x, %y, %z : tensor<4x?xbf16>
+ ```
+
+ The semantics of the operation correspond to those of the `llvm.fma`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
+ particular case of lowering to LLVM, this is guaranteed to lower
+ to the `llvm.fma.*` intrinsic.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
@@ -1854,7 +1933,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
// MulFOp
//===----------------------------------------------------------------------===//
-def MulFOp : FloatArithmeticOp<"mulf"> {
+def MulFOp : FloatBinaryOp<"mulf"> {
let summary = "floating point multiplication operation";
let description = [{
Syntax:
@@ -1891,7 +1970,7 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
// MulIOp
//===----------------------------------------------------------------------===//
-def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
+def MulIOp : IntBinaryOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasFolder = 1;
}
@@ -1933,7 +2012,7 @@ def NegFOp : FloatUnaryOp<"negf"> {
// OrOp
//===----------------------------------------------------------------------===//
-def OrOp : IntArithmeticOp<"or", [Commutative]> {
+def OrOp : IntBinaryOp<"or", [Commutative]> {
let summary = "integer binary or";
let description = [{
Syntax:
@@ -2040,7 +2119,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
// RemFOp
//===----------------------------------------------------------------------===//
-def RemFOp : FloatArithmeticOp<"remf"> {
+def RemFOp : FloatBinaryOp<"remf"> {
let summary = "floating point division remainder operation";
}
@@ -2141,7 +2220,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
// ShiftLeftOp
//===----------------------------------------------------------------------===//
-def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
+def ShiftLeftOp : IntBinaryOp<"shift_left"> {
let summary = "integer left-shift";
let description = [{
The shift_left operation shifts an integer value to the left by a variable
@@ -2161,7 +2240,7 @@ def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
// SignedDivIOp
//===----------------------------------------------------------------------===//
-def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
+def SignedDivIOp : IntBinaryOp<"divi_signed"> {
let summary = "signed integer division operation";
let description = [{
Syntax:
@@ -2196,7 +2275,7 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
// SignedFloorDivIOp
//===----------------------------------------------------------------------===//
-def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> {
let summary = "signed floor integer division operation";
let description = [{
Syntax:
@@ -2225,7 +2304,7 @@ def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
// SignedCeilDivIOp
//===----------------------------------------------------------------------===//
-def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> {
let summary = "signed ceil integer division operation";
let description = [{
Syntax:
@@ -2253,7 +2332,7 @@ def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
// SignedRemIOp
//===----------------------------------------------------------------------===//
-def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
+def SignedRemIOp : IntBinaryOp<"remi_signed"> {
let summary = "signed integer division remainder operation";
let description = [{
Syntax:
@@ -2288,7 +2367,7 @@ def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
// SignedShiftRightOp
//===----------------------------------------------------------------------===//
-def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
+def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
let summary = "signed integer right-shift";
let description = [{
The shift_right_signed operation shifts an integer value to the right by
@@ -2488,7 +2567,7 @@ def StoreOp : Std_Op<"store",
// SubFOp
//===----------------------------------------------------------------------===//
-def SubFOp : FloatArithmeticOp<"subf"> {
+def SubFOp : FloatBinaryOp<"subf"> {
let summary = "floating point subtraction operation";
let hasFolder = 1;
}
@@ -2497,7 +2576,7 @@ def SubFOp : FloatArithmeticOp<"subf"> {
// SubIOp
//===----------------------------------------------------------------------===//
-def SubIOp : IntArithmeticOp<"subi"> {
+def SubIOp : IntBinaryOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
}
@@ -3173,7 +3252,7 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
// UnsignedDivIOp
//===----------------------------------------------------------------------===//
-def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
+def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> {
let summary = "unsigned integer division operation";
let description = [{
Syntax:
@@ -3208,7 +3287,7 @@ def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
// UnsignedRemIOp
//===----------------------------------------------------------------------===//
-def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
+def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> {
let summary = "unsigned integer division remainder operation";
let description = [{
Syntax:
@@ -3243,7 +3322,7 @@ def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
// UnsignedShiftRightOp
//===----------------------------------------------------------------------===//
-def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
+def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> {
let summary = "unsigned integer right-shift";
let description = [{
The shift_right_unsigned operation shifts an integer value to the right by
@@ -3332,7 +3411,7 @@ def ViewOp : Std_Op<"view", [
// XOrOp
//===----------------------------------------------------------------------===//
-def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
+def XOrOp : IntBinaryOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let description = [{
The `xor` operation takes two operands and returns one result, each of these
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1cccab0f9e93..0b28d2589a46 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1662,6 +1662,7 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
+using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>; // std.fmaf -> "llvm.intr.fma".
using Log10OpLowering =
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -3775,6 +3776,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
ExpOpLowering,
Exp2OpLowering,
FloorFOpLowering,
+ FmaFOpLowering,
GenericAtomicRMWOpLowering,
LogOpLowering,
Log10OpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5582c0bde555..52b41ca305d1 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -158,6 +158,32 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
p << " : " << op->getResult(0).getType();
}
+/// A custom ternary operation printer that omits the "std." prefix from the
+/// operation names, printing `<op> %0, %1, %2 [attr-dict] : type`.
+static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
+ assert(op->getNumOperands() == 3 && "ternary op should have three operands");
+ assert(op->getNumResults() == 1 && "ternary op should have one result");
+
+ // If the three operand types and the result type are not all equal, use
+ // the generic assembly form so that no type information is lost.
+ auto resultType = op->getResult(0).getType();
+ if (op->getOperand(0).getType() != resultType ||
+ op->getOperand(1).getType() != resultType ||
+ op->getOperand(2).getType() != resultType) {
+ p.printGenericOp(op);
+ return;
+ }
+
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; // Namespace plus the '.' separator.
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+ << op->getOperand(0) << ", " << op->getOperand(1) << ", "
+ << op->getOperand(2);
+ p.printOptionalAttrDict(op->getAttrs()); // Attributes, if any, follow the operands.
+
+ // All four types are equal (checked above), so print the type only once.
+ p << " : " << op->getResult(0).getType();
+}
+
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index edf0425a93fe..f4eba23c9d38 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -223,3 +223,16 @@ func @powf(%arg0 : f64) {
%0 = math.powf %arg0, %arg0 : f64
std.return
}
+
+// -----
+
+// CHECK-LABEL: func @fmaf(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: f32, %[[B:[a-zA-Z0-9]+]]: f32, %[[C:[a-zA-Z0-9]+]]: f32,
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]: vector<4xf32>
+func @fmaf(%a: f32, %b: f32, %c: f32, %vec: vector<4xf32>) {
+ // CHECK: %[[S:.*]] = "llvm.intr.fma"(%[[A]], %[[B]], %[[C]]) : (f32, f32, f32) -> f32
+ %0 = fmaf %a, %b, %c : f32
+ // CHECK: %[[V:.*]] = "llvm.intr.fma"(%[[VEC]], %[[VEC]], %[[VEC]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+ %1 = fmaf %vec, %vec, %vec : vector<4xf32>
+ std.return
+}
More information about the Mlir-commits
mailing list