[Mlir-commits] [mlir] 63b3933 - [mlir] [VectorOps] Replace zero fma with mult for vector.contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 30 09:04:37 PDT 2020


Author: aartbik
Date: 2020-06-30T09:04:20-07:00
New Revision: 63b3933d0c3381447a706193d3c0d84927a0fbed

URL: https://github.com/llvm/llvm-project/commit/63b3933d0c3381447a706193d3c0d84927a0fbed
DIFF: https://github.com/llvm/llvm-project/commit/63b3933d0c3381447a706193d3c0d84927a0fbed.diff

LOG: [mlir] [VectorOps] Replace zero fma with mult for vector.contract

Implement the multiply-reduce pair more efficiently: the base case now
emits a plain multiply, so there is no need to add in a zero vector.
Microbenchmarking on AVX2 yields the following difference in
vector.contract speedup (over strict-order scalar reduction).

SPEEDUP  SIMD-fma  SIMD-mul
4x4      1.45      2.00
8x8      1.40      1.90
32x32    5.32      5.80

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D82833

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index effae86e4597..d8cd5a7fe0e8 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1291,8 +1291,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
       Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
       Value m;
       if (acc) {
-        Value z = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-        m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), z);
+        Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+        m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
       } else {
         m = rewriter.create<MulFOp>(loc, b, op.rhs());
       }
@@ -1732,7 +1732,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
 ///   ..
 ///   %x = combine %a %b ..
 /// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a fma/reduction op.
+/// which is replaced by a dot-product/reduction pair.
 ///
 /// TODO(ajcbik): break down into transpose/reshape/cast ops
 ///               when they become available to avoid code dup
@@ -1882,11 +1882,9 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   // Base case.
   if (lhsType.getRank() == 1) {
     assert(rhsType.getRank() == 1 && "corrupt contraction");
-    Value zero = rewriter.create<ConstantOp>(loc, lhsType,
-                                             rewriter.getZeroAttr(lhsType));
-    Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+    Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
     StringAttr kind = rewriter.getStringAttr("add");
-    return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
+    return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
                                                 op.acc());
   }
   // Construct new iterator types and affine map array attribute.

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index a07675515a9b..0ebe049f5351 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -16,8 +16,7 @@
 // CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
-// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
-// CHECK:      %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
+// CHECK:      %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32>
 // CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
 // CHECK:      return %[[R]] : f32
 
@@ -42,15 +41,14 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
 // CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
-// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
 // CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK:      %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK:      %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
 // CHECK:      return %[[T9]] : vector<2xf32>
@@ -78,15 +76,14 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
 // CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
-// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
 // CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK:      %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK:      %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32>
 // CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
 // CHECK:      return %[[T9]] : vector<2xf32>
@@ -124,7 +121,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
 // CHECK:    %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK:    %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
 // CHECK:    %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
 // CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
 // CHECK:    %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
@@ -134,7 +131,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
 // CHECK:    %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK:    %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
 // CHECK:    %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
 // CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
@@ -147,7 +144,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
 // CHECK:    %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
-// CHECK:    %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32>
 // CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
 // CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
 // CHECK:    %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
@@ -157,7 +154,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
 // CHECK:    %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
-// CHECK:    %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
 // CHECK:    %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
 // CHECK:    %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
@@ -185,14 +182,13 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
 // CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
-// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
 // CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
-// CHECK:      %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK:      %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
 // CHECK:      return %[[T7]] : f32
 
@@ -229,7 +225,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
-// CHECK:      %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
@@ -241,7 +237,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
 // CHECK:      %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
 // CHECK:      %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
-// CHECK:      %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
 // CHECK:      %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
 // CHECK:      return %[[T23]] : f32
 


        


More information about the Mlir-commits mailing list