[Mlir-commits] [mlir] 63b3933 - [mlir] [VectorOps] Replace zero fma with mult for vector.contract
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 30 09:04:37 PDT 2020
Author: aartbik
Date: 2020-06-30T09:04:20-07:00
New Revision: 63b3933d0c3381447a706193d3c0d84927a0fbed
URL: https://github.com/llvm/llvm-project/commit/63b3933d0c3381447a706193d3c0d84927a0fbed
DIFF: https://github.com/llvm/llvm-project/commit/63b3933d0c3381447a706193d3c0d84927a0fbed.diff
LOG: [mlir] [VectorOps] Replace zero fma with mult for vector.contract
Implements the multiply-reduce pair more efficiently: the base case
now emits a plain multiply instead of an fma that adds in a zero
vector. Microbenchmarking on AVX2 yields the following difference
in vector.contract speedup (over a strict-order scalar reduction).
SPEEDUP   SIMD-fma   SIMD-mul
4x4       1.45       2.00
8x8       1.40       1.90
32x32     5.32       5.80
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D82833
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index effae86e4597..d8cd5a7fe0e8 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1291,8 +1291,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value m;
if (acc) {
- Value z = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
- m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), z);
+ Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+ m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
} else {
m = rewriter.create<MulFOp>(loc, b, op.rhs());
}
@@ -1732,7 +1732,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a fma/reduction op.
+/// which is replaced by a dot-product/reduction pair.
///
/// TODO(ajcbik): break down into transpose/reshape/cast ops
/// when they become available to avoid code dup
@@ -1882,11 +1882,9 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
- Value zero = rewriter.create<ConstantOp>(loc, lhsType,
- rewriter.getZeroAttr(lhsType));
- Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+ Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
StringAttr kind = rewriter.getStringAttr("add");
- return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
+ return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
op.acc());
}
// Construct new iterator types and affine map array attribute.
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index a07675515a9b..0ebe049f5351 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -16,8 +16,7 @@
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
+// CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
// CHECK: return %[[R]] : f32
@@ -42,15 +41,14 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
@@ -78,15 +76,14 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
@@ -124,7 +121,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
@@ -134,7 +131,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
@@ -147,7 +144,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
-// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
@@ -157,7 +154,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
-// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
+// CHECK: %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
@@ -185,14 +182,13 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
-// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
// CHECK: return %[[T7]] : f32
@@ -229,7 +225,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
-// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
@@ -241,7 +237,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
-// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK: return %[[T23]] : f32
More information about the Mlir-commits mailing list