[Mlir-commits] [mlir] 9bf6354 - [mlir] [VectorOps] Allow AXPY to be expressed as special case of OUTERPRODUCT
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 10 12:23:45 PDT 2020
Author: aartbik
Date: 2020-07-10T12:23:24-07:00
New Revision: 9bf6354301ac4e1c7a00e4ef46decba38840fe62
URL: https://github.com/llvm/llvm-project/commit/9bf6354301ac4e1c7a00e4ef46decba38840fe62
DIFF: https://github.com/llvm/llvm-project/commit/9bf6354301ac4e1c7a00e4ef46decba38840fe62.diff
LOG: [mlir] [VectorOps] Allow AXPY to be expressed as special case of OUTERPRODUCT
This specialization allows sharing more code where an AXPY follows naturally
in cases where an OUTERPRODUCT on a scalar would be generated.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D83453
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b205e5a2e286..57490c378041 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -91,7 +91,7 @@ def Vector_ContractionOp :
Example:
```mlir
- // Simple dot product (K = 0).
+ // Simple DOT product (K = 0).
#contraction_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
@@ -668,19 +668,36 @@ def Vector_InsertStridedSliceOp :
}
def Vector_OuterProductOp :
- Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
- Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
+ Vector_Op<"outerproduct", [NoSideEffect,
+ PredOpTrait<"lhs operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ PredOpTrait<"rhs operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 1>>]>,
+ Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic<AnyVector>:$acc)>,
Results<(outs AnyVector)> {
let summary = "vector outerproduct with optional fused add";
let description = [{
- Takes 2 1-D vectors and returns the 2-D vector containing the outer-product.
+ Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
+ as illustrated below:
+ ```
+ outer | [c, d]
+ ------+------------
+ [a, | [ [a*c, a*d],
+ b] | [b*c, b*d] ]
+ ```
+ This operation also accepts a 1-D vector lhs and a scalar rhs. In this
+ case a simple AXPY operation is performed, which returns a 1-D vector.
+ ```
+ [a, b] * c = [a*c, b*c]
+ ```
- An optional extra 2-D vector argument may be specified in which case the
- operation returns the sum of the outer-product and the extra vector. In this
- multiply-accumulate scenario, the rounding mode is that obtained by
- guaranteeing that a fused-multiply add operation is emitted. When lowered to
- the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to
- lower to actual `fma` instructions on x86.
+ An optional extra vector argument with the same shape as the output
+ vector may be specified in which case the operation returns the sum of
+ the outer-product and the extra vector. In this multiply-accumulate
+ scenario for floating-point arguments, the rounding mode is enforced
+ by guaranteeing that a fused-multiply add operation is emitted. When
+ lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which
+ is guaranteed to lower to actual `fma` instructions on x86.
Example:
@@ -691,6 +708,10 @@ def Vector_OuterProductOp :
%3 = vector.outerproduct %0, %1, %2:
vector<4xf32>, vector<8xf32>, vector<4x8xf32>
return %3: vector<4x8xf32>
+
+ %6 = vector.outerproduct %4, %5: vector<10xf32>, f32
+ return %6: vector<10xf32>
+
```
}];
let builders = [
@@ -702,12 +723,13 @@ def Vector_OuterProductOp :
VectorType getOperandVectorTypeLHS() {
return lhs().getType().cast<VectorType>();
}
- VectorType getOperandVectorTypeRHS() {
- return rhs().getType().cast<VectorType>();
+ Type getOperandTypeRHS() {
+ return rhs().getType();
}
VectorType getOperandVectorTypeACC() {
- return (llvm::size(acc()) == 0) ? VectorType() :
- (*acc().begin()).getType().cast<VectorType>();
+ return (llvm::size(acc()) == 0)
+ ? VectorType()
+ : (*acc().begin()).getType().cast<VectorType>();
}
VectorType getVectorType() {
return getResult().getType().cast<VectorType>();
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
index a6470c656f8c..9b86976e4901 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
@@ -11,6 +11,8 @@
!vector_type_Y = type vector<3xf32>
!vector_type_Z = type vector<2x3xf32>
+!vector_type_R = type vector<7xf32>
+
func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C {
%a = splat %fa: !vector_type_A
%b = splat %fb: !vector_type_B
@@ -33,6 +35,7 @@ func @vector_outerproduct_vec_2x3_acc(%x : !vector_type_X,
}
func @entry() {
+ %f0 = constant 0.0: f32
%f1 = constant 1.0: f32
%f2 = constant 2.0: f32
%f3 = constant 3.0: f32
@@ -72,5 +75,26 @@ func @entry() {
//
// CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) )
+ %3 = vector.broadcast %f0 : f32 to !vector_type_R
+ %4 = vector.insert %f1, %3[1] : f32 into !vector_type_R
+ %5 = vector.insert %f2, %4[2] : f32 into !vector_type_R
+ %6 = vector.insert %f3, %5[3] : f32 into !vector_type_R
+ %7 = vector.insert %f4, %6[4] : f32 into !vector_type_R
+ %8 = vector.insert %f5, %7[5] : f32 into !vector_type_R
+ %9 = vector.insert %f10, %8[6] : f32 into !vector_type_R
+
+ %o = vector.broadcast %f1 : f32 to !vector_type_R
+
+ %axpy1 = vector.outerproduct %9, %f2 : !vector_type_R, f32
+ %axpy2 = vector.outerproduct %9, %f2, %o : !vector_type_R, f32
+
+ vector.print %axpy1 : !vector_type_R
+ vector.print %axpy2 : !vector_type_R
+ //
+ // axpy operations:
+ //
+ // CHECK: ( 0, 2, 4, 6, 8, 10, 20 )
+ // CHECK: ( 1, 3, 5, 7, 9, 11, 21 )
+
return
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
index 12a71d85b3f7..c2724aec9b37 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
@@ -11,6 +11,8 @@
!vector_type_Y = type vector<3xi64>
!vector_type_Z = type vector<2x3xi64>
+!vector_type_R = type vector<7xi64>
+
func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
%a = splat %ia: !vector_type_A
%b = splat %ib: !vector_type_B
@@ -33,6 +35,7 @@ func @vector_outerproduct_vec_2x3_acc(%x : !vector_type_X,
}
func @entry() {
+ %i0 = constant 0: i64
%i1 = constant 1: i64
%i2 = constant 2: i64
%i3 = constant 3: i64
@@ -72,5 +75,26 @@ func @entry() {
//
// CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) )
+ %3 = vector.broadcast %i0 : i64 to !vector_type_R
+ %4 = vector.insert %i1, %3[1] : i64 into !vector_type_R
+ %5 = vector.insert %i2, %4[2] : i64 into !vector_type_R
+ %6 = vector.insert %i3, %5[3] : i64 into !vector_type_R
+ %7 = vector.insert %i4, %6[4] : i64 into !vector_type_R
+ %8 = vector.insert %i5, %7[5] : i64 into !vector_type_R
+ %9 = vector.insert %i10, %8[6] : i64 into !vector_type_R
+
+ %o = vector.broadcast %i1 : i64 to !vector_type_R
+
+ %axpy1 = vector.outerproduct %9, %i2 : !vector_type_R, i64
+ %axpy2 = vector.outerproduct %9, %i2, %o : !vector_type_R, i64
+
+ vector.print %axpy1 : !vector_type_R
+ vector.print %axpy2 : !vector_type_R
+ //
+ // axpy operations:
+ //
+ // CHECK: ( 0, 2, 4, 6, 8, 10, 20 )
+ // CHECK: ( 1, 3, 5, 7, 9, 11, 21 )
+
return
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index cdf09c4a8f68..ca85625c7e4c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1203,10 +1203,13 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
"expected at least 2 operands");
VectorType vLHS = tLHS.dyn_cast<VectorType>();
VectorType vRHS = tRHS.dyn_cast<VectorType>();
- if (!vLHS || !vRHS)
- return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
- VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType());
+ if (!vLHS)
+ return parser.emitError(parser.getNameLoc(),
+ "expected vector type for operand #1");
+ VectorType resType =
+ vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+ vLHS.getElementType())
+ : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
return failure(
parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
@@ -1216,19 +1219,32 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
}
static LogicalResult verify(OuterProductOp op) {
+ Type tRHS = op.getOperandTypeRHS();
VectorType vLHS = op.getOperandVectorTypeLHS(),
- vRHS = op.getOperandVectorTypeRHS(),
+ vRHS = tRHS.dyn_cast<VectorType>(),
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+
if (vLHS.getRank() != 1)
return op.emitOpError("expected 1-d vector for operand #1");
- if (vRHS.getRank() != 1)
- return op.emitOpError("expected 1-d vector for operand #2");
- if (vRES.getRank() != 2)
- return op.emitOpError("expected 2-d vector result");
- if (vLHS.getDimSize(0) != vRES.getDimSize(0))
- return op.emitOpError("expected #1 operand dim to match result dim #1");
- if (vRHS.getDimSize(0) != vRES.getDimSize(1))
- return op.emitOpError("expected #2 operand dim to match result dim #2");
+
+ if (vRHS) {
+ // Proper OUTER operation.
+ if (vRHS.getRank() != 1)
+ return op.emitOpError("expected 1-d vector for operand #2");
+ if (vRES.getRank() != 2)
+ return op.emitOpError("expected 2-d vector result");
+ if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+ return op.emitOpError("expected #1 operand dim to match result dim #1");
+ if (vRHS.getDimSize(0) != vRES.getDimSize(1))
+ return op.emitOpError("expected #2 operand dim to match result dim #2");
+ } else {
+ // An AXPY operation.
+ if (vRES.getRank() != 1)
+ return op.emitOpError("expected 1-d vector result");
+ if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+ return op.emitOpError("expected #1 operand dim to match result dim #1");
+ }
+
if (vACC && vACC != vRES)
return op.emitOpError("expected operand #3 of same type as result type");
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index ebad34fcd593..aa5264ae0d33 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1262,7 +1262,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
/// %0 = vector.extract %lhs[0]
/// %1 = vector.broadcast %0
/// %2 = vector.extract %acc[0]
-/// %3 = vector.fma %1, %arg1, %2
+/// %3 = vector.fma %1, %rhs, %2
/// %4 = vector.insert %3, %z[0]
/// ..
/// %x = vector.insert %.., %..[N-1]
@@ -1275,36 +1275,49 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- VectorType rhsType = op.getOperandVectorTypeRHS();
+ VectorType lhsType = op.getOperandVectorTypeLHS();
+ VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
VectorType resType = op.getVectorType();
Type eltType = resType.getElementType();
+ bool isInt = eltType.isa<IntegerType>();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
+ if (!rhsType) {
+ // Special case: AXPY operation.
+ Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
+ rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
+ return success();
+ }
+
Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
- Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
- Value m;
- if (acc) {
- Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
- if (eltType.isa<IntegerType>())
- m = rewriter.create<AddIOp>(
- loc, rewriter.create<MulIOp>(loc, b, op.rhs()), e);
- else
- m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
- } else {
- if (eltType.isa<IntegerType>())
- m = rewriter.create<MulIOp>(loc, b, op.rhs());
- else
- m = rewriter.create<MulFOp>(loc, b, op.rhs());
- }
+ Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+ Value r = nullptr;
+ if (acc)
+ r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+ Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
+
+private:
+ static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
+ PatternRewriter &rewriter) {
+ if (acc) {
+ if (isInt)
+ return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
+ acc);
+ return rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ }
+ if (isInt)
+ return rewriter.create<MulIOp>(loc, x, y);
+ return rewriter.create<MulFOp>(loc, x, y);
+ }
};
/// Progressive lowering of ConstantMaskOp.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 84d596bf512f..916403800fe1 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -187,7 +187,7 @@ func @outerproduct_num_operands(%arg0: f32) {
// -----
func @outerproduct_non_vector_operand(%arg0: f32) {
- // expected-error@+1 {{expected 2 vector types}}
+ // expected-error@+1 {{expected vector type for operand #1}}
%1 = vector.outerproduct %arg0, %arg0 : f32, f32
}
@@ -228,6 +228,27 @@ func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf
// -----
+func @outerproduct_axpy_operand(%arg0: vector<4x8xf32>, %arg1: f32) {
+ // expected-error@+1 {{expected 1-d vector for operand #1}}
+ %1 = vector.outerproduct %arg0, %arg1 : vector<4x8xf32>, f32
+}
+
+// -----
+
+func @outerproduct_axpy_result_generic(%arg0: vector<4xf32>, %arg1: f32) {
+ // expected-error@+1 {{expected 1-d vector result}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, f32) -> (vector<4x8xf32>)
+}
+
+// -----
+
+func @outerproduct_axpy_operand_dim_generic(%arg0: vector<8xf32>, %arg1: f32) {
+ // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<8xf32>, f32) -> (vector<16xf32>)
+}
+
+// -----
+
func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
// expected-error@+1 {{expected operand #3 of same type as result type}}
%1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index f6f215a50616..82faadf100e9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -326,6 +326,53 @@ func @outerproduct_acc_int(%arg0: vector<2xi32>,
return %0: vector<2x3xi32>
}
+// CHECK-LABEL: func @axpy_fp(
+// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
+// CHECK-SAME: %[[B:.*1]]: f32)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T1:.*]] = mulf %[[A]], %[[T0]] : vector<16xf32>
+// CHECK: return %[[T1]] : vector<16xf32>
+func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
+ %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
+ return %0: vector<16xf32>
+}
+
+// CHECK-LABEL: func @axpy_fp_add(
+// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
+// CHECK-SAME: %[[B:.*1]]: f32,
+// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
+// CHECK: return %[[T1]] : vector<16xf32>
+func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
+ return %0: vector<16xf32>
+}
+
+// CHECK-LABEL: func @axpy_int(
+// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
+// CHECK-SAME: %[[B:.*1]]: i32)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32>
+// CHECK: return %[[T1]] : vector<16xi32>
+func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
+ %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
+ return %0: vector<16xi32>
+}
+
+// CHECK-LABEL: func @axpy_int_add(
+// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
+// CHECK-SAME: %[[B:.*1]]: i32,
+// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32>
+// CHECK: %[[T2:.*]] = addi %[[T1]], %[[C]] : vector<16xi32>
+// CHECK: return %[[T2]] : vector<16xi32>
+func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
+ %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
+ return %0: vector<16xi32>
+}
+
// CHECK-LABEL: func @transpose23
// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
More information about the Mlir-commits
mailing list