[Mlir-commits] [mlir] 519f591 - [mlir] Add fma operation to std dialect

Eugene Zhulenev llvmlistbot at llvm.org
Wed Feb 17 10:06:10 PST 2021


Author: Eugene Zhulenev
Date: 2021-02-17T10:06:01-08:00
New Revision: 519f5917b458e51d4d12e034490d1a6f42d72f77

URL: https://github.com/llvm/llvm-project/commit/519f5917b458e51d4d12e034490d1a6f42d72f77
DIFF: https://github.com/llvm/llvm-project/commit/519f5917b458e51d4d12e034490d1a6f42d72f77.diff

LOG: [mlir] Add fma operation to std dialect

The `vector.fma` operation will be removed in follow-up CLs.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96801

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 4e6ff2e359c0..29863c82c502 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -103,7 +103,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
 
 // Base class for standard arithmetic operations.  Requires operands and
 // results to be of the same type, but does not constrain them to specific
-// types.  Individual classes will have `lhs` and `rhs` accessor to operands.
+// types.
 class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     Op<StandardOps_Dialect, mnemonic,
        !listconcat(traits, [NoSideEffect,
@@ -122,6 +122,32 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
   }];
 }
 
+// Shared base for standard arithmetic operations that take two operands.
+class ArithmeticBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits> {
+
+  let printer = [{
+    return printStandardBinaryOp(this->getOperation(), p);
+  }];
+
+  let parser = [{
+    return impl::parseOneResultSameOperandTypeOp(parser, result);
+  }];
+}
+
+// Shared base for standard arithmetic operations that take three operands.
+class ArithmeticTernaryOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits> {
+
+  let printer = [{
+    return printStandardTernaryOp(this->getOperation(), p);
+  }];
+
+  let parser = [{
+    return impl::parseOneResultSameOperandTypeOp(parser, result);
+  }];
+}
+
 // Base class for standard arithmetic operations on integers, vectors and
 // tensors thereof.  This operation takes two operands and returns one result,
 // each of these is required to be of the same type.  This type may be an
@@ -130,8 +156,8 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 //
 //     <op>i %0, %1 : i32
 //
-class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    ArithmeticOp<mnemonic,
+class IntBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticBinaryOp<mnemonic,
       !listconcat(traits,
                   [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -145,12 +171,27 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 //
 //     <op>f %0, %1 : f32
 //
-class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    ArithmeticOp<mnemonic,
+class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticBinaryOp<mnemonic,
       !listconcat(traits,
                   [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
 
+// Base class for standard ternary arithmetic operations on floats, vectors and
+// tensors thereof.  The operation has three operands and one result, all of
+// which are required to be of the same type.  This type may be a floating
+// point scalar type, a vector whose element type is a floating point type, or
+// a floating point tensor.  The operands are named `a`, `b` and `c`.
+// The custom assembly form of the operation is as follows:
+//
+//     <op> %0, %1, %2 : f32
+//
+class FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticTernaryOp<mnemonic,
+      !listconcat(traits,
+                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+    Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>;
+
 // Base class for memref allocating ops: alloca and alloc.
 //
 //   %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -257,7 +298,7 @@ def AbsFOp : FloatUnaryOp<"absf"> {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-def AddFOp : FloatArithmeticOp<"addf"> {
+def AddFOp : FloatBinaryOp<"addf"> {
   let summary = "floating point addition operation";
   let description = [{
     Syntax:
@@ -294,7 +335,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
+def AddIOp : IntBinaryOp<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
     Syntax:
@@ -418,7 +459,7 @@ def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
 // AndOp
 //===----------------------------------------------------------------------===//
 
-def AndOp : IntArithmeticOp<"and", [Commutative]> {
+def AndOp : IntBinaryOp<"and", [Commutative]> {
   let summary = "integer binary and";
   let description = [{
     Syntax:
@@ -1269,7 +1310,7 @@ def ConstantOp : Std_Op<"constant",
 // CopySignOp
 //===----------------------------------------------------------------------===//
 
-def CopySignOp : FloatArithmeticOp<"copysign"> {
+def CopySignOp : FloatBinaryOp<"copysign"> {
   let summary = "A copysign operation";
   let description = [{
     Syntax:
@@ -1384,11 +1425,49 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
 // DivFOp
 //===----------------------------------------------------------------------===//
 
-def DivFOp : FloatArithmeticOp<"divf"> {
+def DivFOp : FloatBinaryOp<"divf"> {
   let summary = "floating point division operation";
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FmaFOp
+//===----------------------------------------------------------------------===//
+
+def FmaFOp : FloatTernaryOp<"fmaf"> {
+  let summary = "floating point fused multiply-add operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type
+    ```
+
+    The `fmaf` operation takes three operands and returns one result; all of
+    them are required to have the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    Example:
+
+    ```mlir
+    // Scalar fused multiply-add: d = a*b + c
+    %d = fmaf %a, %b, %c : f64
+
+    // SIMD vector fused multiply-add, e.g. for x86 FMA3.
+    %i = fmaf %f, %g, %h : vector<4xf32>
+
+    // Tensor fused multiply-add.
+    %w = fmaf %x, %y, %z : tensor<4x?xbf16>
+    ```
+
+    The semantics of the operation correspond to those of the `llvm.fma`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
+    particular case of lowering to LLVM, this is guaranteed to lower to the
+    `llvm.fma.*` intrinsic, which computes `a*b + c` with a single rounding.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // FPExtOp
 //===----------------------------------------------------------------------===//
@@ -1854,7 +1933,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-def MulFOp : FloatArithmeticOp<"mulf"> {
+def MulFOp : FloatBinaryOp<"mulf"> {
   let summary = "floating point multiplication operation";
   let description = [{
     Syntax:
@@ -1891,7 +1970,7 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
+def MulIOp : IntBinaryOp<"muli", [Commutative]> {
   let summary = "integer multiplication operation";
   let hasFolder = 1;
 }
@@ -1933,7 +2012,7 @@ def NegFOp : FloatUnaryOp<"negf"> {
 // OrOp
 //===----------------------------------------------------------------------===//
 
-def OrOp : IntArithmeticOp<"or", [Commutative]> {
+def OrOp : IntBinaryOp<"or", [Commutative]> {
   let summary = "integer binary or";
   let description = [{
     Syntax:
@@ -2040,7 +2119,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-def RemFOp : FloatArithmeticOp<"remf"> {
+def RemFOp : FloatBinaryOp<"remf"> {
   let summary = "floating point division remainder operation";
 }
 
@@ -2141,7 +2220,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
 // ShiftLeftOp
 //===----------------------------------------------------------------------===//
 
-def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
+def ShiftLeftOp : IntBinaryOp<"shift_left"> {
   let summary = "integer left-shift";
   let description = [{
     The shift_left operation shifts an integer value to the left by a variable
@@ -2161,7 +2240,7 @@ def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
 // SignedDivIOp
 //===----------------------------------------------------------------------===//
 
-def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
+def SignedDivIOp : IntBinaryOp<"divi_signed"> {
   let summary = "signed integer division operation";
   let description = [{
     Syntax:
@@ -2196,7 +2275,7 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
 // SignedFloorDivIOp
 //===----------------------------------------------------------------------===//
 
-def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> {
   let summary = "signed floor integer division operation";
   let description = [{
     Syntax:
@@ -2225,7 +2304,7 @@ def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
 // SignedCeilDivIOp
 //===----------------------------------------------------------------------===//
 
-def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> {
   let summary = "signed ceil integer division operation";
   let description = [{
     Syntax:
@@ -2253,7 +2332,7 @@ def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
 // SignedRemIOp
 //===----------------------------------------------------------------------===//
 
-def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
+def SignedRemIOp : IntBinaryOp<"remi_signed"> {
   let summary = "signed integer division remainder operation";
   let description = [{
     Syntax:
@@ -2288,7 +2367,7 @@ def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
 // SignedShiftRightOp
 //===----------------------------------------------------------------------===//
 
-def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
+def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
   let summary = "signed integer right-shift";
   let description = [{
     The shift_right_signed operation shifts an integer value to the right by
@@ -2488,7 +2567,7 @@ def StoreOp : Std_Op<"store",
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-def SubFOp : FloatArithmeticOp<"subf"> {
+def SubFOp : FloatBinaryOp<"subf"> {
   let summary = "floating point subtraction operation";
   let hasFolder = 1;
 }
@@ -2497,7 +2576,7 @@ def SubFOp : FloatArithmeticOp<"subf"> {
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-def SubIOp : IntArithmeticOp<"subi"> {
+def SubIOp : IntBinaryOp<"subi"> {
   let summary = "integer subtraction operation";
   let hasFolder = 1;
 }
@@ -3173,7 +3252,7 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
 // UnsignedDivIOp
 //===----------------------------------------------------------------------===//
 
-def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
+def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> {
   let summary = "unsigned integer division operation";
   let description = [{
     Syntax:
@@ -3208,7 +3287,7 @@ def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
 // UnsignedRemIOp
 //===----------------------------------------------------------------------===//
 
-def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
+def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> {
   let summary = "unsigned integer division remainder operation";
   let description = [{
     Syntax:
@@ -3243,7 +3322,7 @@ def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
 // UnsignedShiftRightOp
 //===----------------------------------------------------------------------===//
 
-def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
+def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> {
   let summary = "unsigned integer right-shift";
   let description = [{
     The shift_right_unsigned operation shifts an integer value to the right by
@@ -3332,7 +3411,7 @@ def ViewOp : Std_Op<"view", [
 // XOrOp
 //===----------------------------------------------------------------------===//
 
-def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
+def XOrOp : IntBinaryOp<"xor", [Commutative]> {
   let summary = "integer binary xor";
   let description = [{
     The `xor` operation takes two operands and returns one result, each of these

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1cccab0f9e93..0b28d2589a46 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1662,6 +1662,7 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
+using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
 using Log10OpLowering =
     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -3775,6 +3776,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       ExpOpLowering,
       Exp2OpLowering,
       FloorFOpLowering,
+      FmaFOpLowering,
       GenericAtomicRMWOpLowering,
       LogOpLowering,
       Log10OpLowering,

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5582c0bde555..52b41ca305d1 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -158,6 +158,32 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
   p << " : " << op->getResult(0).getType();
 }
 
+/// Prints a three-operand standard op in its short form (no "std." prefix,
+/// a single trailing type), or in generic form when the types disagree.
+static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
+  assert(op->getNumOperands() == 3 && "ternary op should have three operands");
+  assert(op->getNumResults() == 1 && "ternary op should have one result");
+
+  auto type = op->getResult(0).getType();
+  auto a = op->getOperand(0), b = op->getOperand(1), c = op->getOperand(2);
+  bool sameTypes = a.getType() == type && b.getType() == type &&
+                   c.getType() == type;
+  // The short form prints one type for all operands and the result, so it
+  // would drop information if any operand type differed from the result type.
+  if (!sameTypes) {
+    p.printGenericOp(op);
+    return;
+  }
+
+  // Skip the dialect namespace and the '.' that follows it.
+  auto mnemonic = op->getName().getStringRef().drop_front(
+      StandardOpsDialect::getDialectNamespace().size() + 1);
+  p << mnemonic << ' ' << a << ", " << b << ", " << c;
+  p.printOptionalAttrDict(op->getAttrs());
+
+  p << " : " << type;
+}
+
 /// A custom cast operation printer that omits the "std." prefix from the
 /// operation names.
 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index edf0425a93fe..f4eba23c9d38 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -223,3 +223,16 @@ func @powf(%arg0 : f64) {
   %0 = math.powf %arg0, %arg0 : f64
   std.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fmaf(
+// CHECK-SAME: %[[ARG0:.*]]: f32
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
+  // CHECK: %{{.*}} = "llvm.intr.fma"(%[[ARG0]], %[[ARG0]], %[[ARG0]]) : (f32, f32, f32) -> f32
+  %0 = fmaf %arg0, %arg0, %arg0 : f32
+  // CHECK: %{{.*}} = "llvm.intr.fma"(%[[ARG1]], %[[ARG1]], %[[ARG1]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+  %1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
+  std.return
+}


        


More information about the Mlir-commits mailing list