[Mlir-commits] [mlir] 9bf6354 - [mlir] [VectorOps] Allow AXPY to be expressed as special case of OUTERPRODUCT

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 10 12:23:45 PDT 2020


Author: aartbik
Date: 2020-07-10T12:23:24-07:00
New Revision: 9bf6354301ac4e1c7a00e4ef46decba38840fe62

URL: https://github.com/llvm/llvm-project/commit/9bf6354301ac4e1c7a00e4ef46decba38840fe62
DIFF: https://github.com/llvm/llvm-project/commit/9bf6354301ac4e1c7a00e4ef46decba38840fe62.diff

LOG: [mlir] [VectorOps] Allow AXPY to be expressed as special case of OUTERPRODUCT

This specialization allows sharing more code where an AXPY follows naturally
in cases where an OUTERPRODUCT on a scalar would be generated.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D83453

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b205e5a2e286..57490c378041 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -91,7 +91,7 @@ def Vector_ContractionOp :
     Example:
 
     ```mlir
-    // Simple dot product (K = 0).
+    // Simple DOT product (K = 0).
     #contraction_accesses = [
      affine_map<(i) -> (i)>,
      affine_map<(i) -> (i)>,
@@ -668,19 +668,36 @@ def Vector_InsertStridedSliceOp :
 }
 
 def Vector_OuterProductOp :
-  Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
+  Vector_Op<"outerproduct", [NoSideEffect,
+    PredOpTrait<"lhs operand and result have same element type",
+                TCresVTEtIsSameAsOpBase<0, 0>>,
+    PredOpTrait<"rhs operand and result have same element type",
+                TCresVTEtIsSameAsOpBase<0, 1>>]>,
+    Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic<AnyVector>:$acc)>,
     Results<(outs AnyVector)> {
   let summary = "vector outerproduct with optional fused add";
   let description = [{
-    Takes 2 1-D vectors and returns the 2-D vector containing the outer-product.
+    Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
+    as illustrated below:
+    ```
+     outer |   [c, d]
+     ------+------------
+       [a, | [ [a*c, a*d],
+        b] |   [b*c, b*d] ]
+    ```
+    This operation also accepts a 1-D vector lhs and a scalar rhs. In this
+    case a simple AXPY operation is performed, which returns a 1-D vector.
+    ```
+        [a, b] * c = [a*c, b*c]
+    ```
 
-    An optional extra 2-D vector argument may be specified in which case the
-    operation returns the sum of the outer-product and the extra vector. In this
-    multiply-accumulate scenario, the rounding mode is that obtained by
-    guaranteeing that a fused-multiply add operation is emitted. When lowered to
-    the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to
-    lower to actual `fma` instructions on x86.
+    An optional extra vector argument with the same shape as the output
+    vector may be specified in which case the operation returns the sum of
+    the outer-product and the extra vector. In this multiply-accumulate
+    scenario for floating-point arguments, the rounding mode is enforced
+    by guaranteeing that a fused-multiply add operation is emitted. When
+    lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which
+    is guaranteed to lower to actual `fma` instructions on x86.
 
     Example:
 
@@ -691,6 +708,10 @@ def Vector_OuterProductOp :
     %3 = vector.outerproduct %0, %1, %2:
       vector<4xf32>, vector<8xf32>, vector<4x8xf32>
     return %3: vector<4x8xf32>
+
+    %6 = vector.outerproduct %4, %5: vector<10xf32>, f32
+    return %6: vector<10xf32>
+
     ```
   }];
   let builders = [
@@ -702,12 +723,13 @@ def Vector_OuterProductOp :
     VectorType getOperandVectorTypeLHS() {
       return lhs().getType().cast<VectorType>();
     }
-    VectorType getOperandVectorTypeRHS() {
-      return rhs().getType().cast<VectorType>();
+    Type getOperandTypeRHS() {
+      return rhs().getType();
     }
     VectorType getOperandVectorTypeACC() {
-      return (llvm::size(acc()) == 0) ? VectorType() :
-        (*acc().begin()).getType().cast<VectorType>();
+      return (llvm::size(acc()) == 0)
+        ? VectorType()
+        : (*acc().begin()).getType().cast<VectorType>();
     }
     VectorType getVectorType() {
       return getResult().getType().cast<VectorType>();

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
index a6470c656f8c..9b86976e4901 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir
@@ -11,6 +11,8 @@
 !vector_type_Y = type vector<3xf32>
 !vector_type_Z = type vector<2x3xf32>
 
+!vector_type_R = type vector<7xf32>
+
 func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C {
   %a = splat %fa: !vector_type_A
   %b = splat %fb: !vector_type_B
@@ -33,6 +35,7 @@ func @vector_outerproduct_vec_2x3_acc(%x : !vector_type_X,
 }
 
 func @entry() {
+  %f0 = constant 0.0: f32
   %f1 = constant 1.0: f32
   %f2 = constant 2.0: f32
   %f3 = constant 3.0: f32
@@ -72,5 +75,26 @@ func @entry() {
   //
   // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) )
 
+  %3 = vector.broadcast %f0 : f32 to !vector_type_R
+  %4 = vector.insert %f1,  %3[1] : f32 into !vector_type_R
+  %5 = vector.insert %f2,  %4[2] : f32 into !vector_type_R
+  %6 = vector.insert %f3,  %5[3] : f32 into !vector_type_R
+  %7 = vector.insert %f4,  %6[4] : f32 into !vector_type_R
+  %8 = vector.insert %f5,  %7[5] : f32 into !vector_type_R
+  %9 = vector.insert %f10, %8[6] : f32 into !vector_type_R
+
+  %o = vector.broadcast %f1 : f32 to !vector_type_R
+
+  %axpy1 = vector.outerproduct %9, %f2     : !vector_type_R, f32
+  %axpy2 = vector.outerproduct %9, %f2, %o : !vector_type_R, f32
+
+  vector.print %axpy1 : !vector_type_R
+  vector.print %axpy2 : !vector_type_R
+  //
+  // axpy operations:
+  //
+  // CHECK: ( 0, 2, 4, 6, 8, 10, 20 )
+  // CHECK: ( 1, 3, 5, 7, 9, 11, 21 )
+
   return
 }

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
index 12a71d85b3f7..c2724aec9b37 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir
@@ -11,6 +11,8 @@
 !vector_type_Y = type vector<3xi64>
 !vector_type_Z = type vector<2x3xi64>
 
+!vector_type_R = type vector<7xi64>
+
 func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
   %a = splat %ia: !vector_type_A
   %b = splat %ib: !vector_type_B
@@ -33,6 +35,7 @@ func @vector_outerproduct_vec_2x3_acc(%x : !vector_type_X,
 }
 
 func @entry() {
+  %i0 = constant 0: i64
   %i1 = constant 1: i64
   %i2 = constant 2: i64
   %i3 = constant 3: i64
@@ -72,5 +75,26 @@ func @entry() {
   //
   // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) )
 
+  %3 = vector.broadcast %i0 : i64 to !vector_type_R
+  %4 = vector.insert %i1,  %3[1] : i64 into !vector_type_R
+  %5 = vector.insert %i2,  %4[2] : i64 into !vector_type_R
+  %6 = vector.insert %i3,  %5[3] : i64 into !vector_type_R
+  %7 = vector.insert %i4,  %6[4] : i64 into !vector_type_R
+  %8 = vector.insert %i5,  %7[5] : i64 into !vector_type_R
+  %9 = vector.insert %i10, %8[6] : i64 into !vector_type_R
+
+  %o = vector.broadcast %i1 : i64 to !vector_type_R
+
+  %axpy1 = vector.outerproduct %9, %i2     : !vector_type_R, i64
+  %axpy2 = vector.outerproduct %9, %i2, %o : !vector_type_R, i64
+
+  vector.print %axpy1 : !vector_type_R
+  vector.print %axpy2 : !vector_type_R
+  //
+  // axpy operations:
+  //
+  // CHECK: ( 0, 2, 4, 6, 8, 10, 20 )
+  // CHECK: ( 1, 3, 5, 7, 9, 11, 21 )
+
   return
 }

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index cdf09c4a8f68..ca85625c7e4c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1203,10 +1203,13 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
                             "expected at least 2 operands");
   VectorType vLHS = tLHS.dyn_cast<VectorType>();
   VectorType vRHS = tRHS.dyn_cast<VectorType>();
-  if (!vLHS || !vRHS)
-    return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
-  VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                                       vLHS.getElementType());
+  if (!vLHS)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected vector type for operand #1");
+  VectorType resType =
+      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+                             vLHS.getElementType())
+           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
   return failure(
       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
@@ -1216,19 +1219,32 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
 }
 
 static LogicalResult verify(OuterProductOp op) {
+  Type tRHS = op.getOperandTypeRHS();
   VectorType vLHS = op.getOperandVectorTypeLHS(),
-             vRHS = op.getOperandVectorTypeRHS(),
+             vRHS = tRHS.dyn_cast<VectorType>(),
              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+
   if (vLHS.getRank() != 1)
     return op.emitOpError("expected 1-d vector for operand #1");
-  if (vRHS.getRank() != 1)
-    return op.emitOpError("expected 1-d vector for operand #2");
-  if (vRES.getRank() != 2)
-    return op.emitOpError("expected 2-d vector result");
-  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
-    return op.emitOpError("expected #1 operand dim to match result dim #1");
-  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
-    return op.emitOpError("expected #2 operand dim to match result dim #2");
+
+  if (vRHS) {
+    // Proper OUTER operation.
+    if (vRHS.getRank() != 1)
+      return op.emitOpError("expected 1-d vector for operand #2");
+    if (vRES.getRank() != 2)
+      return op.emitOpError("expected 2-d vector result");
+    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+      return op.emitOpError("expected #1 operand dim to match result dim #1");
+    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
+      return op.emitOpError("expected #2 operand dim to match result dim #2");
+  } else {
+    // An AXPY operation.
+    if (vRES.getRank() != 1)
+      return op.emitOpError("expected 1-d vector result");
+    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+      return op.emitOpError("expected #1 operand dim to match result dim #1");
+  }
+
   if (vACC && vACC != vRES)
     return op.emitOpError("expected operand #3 of same type as result type");
   return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index ebad34fcd593..aa5264ae0d33 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1262,7 +1262,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 ///   %0 = vector.extract %lhs[0]
 ///   %1 = vector.broadcast %0
 ///   %2 = vector.extract %acc[0]
-///   %3 = vector.fma %1, %arg1, %2
+///   %3 = vector.fma %1, %rhs, %2
 ///   %4 = vector.insert %3, %z[0]
 ///   ..
 ///   %x = vector.insert %.., %..[N-1]
@@ -1275,36 +1275,49 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
 
-    VectorType rhsType = op.getOperandVectorTypeRHS();
+    VectorType lhsType = op.getOperandVectorTypeLHS();
+    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
     VectorType resType = op.getVectorType();
     Type eltType = resType.getElementType();
+    bool isInt = eltType.isa<IntegerType>();
     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
 
+    if (!rhsType) {
+      // Special case: AXPY operation.
+      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
+      rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
+      return success();
+    }
+
     Value result = rewriter.create<ConstantOp>(loc, resType,
                                                rewriter.getZeroAttr(resType));
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
       auto pos = rewriter.getI64ArrayAttr(d);
       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
-      Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
-      Value m;
-      if (acc) {
-        Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-        if (eltType.isa<IntegerType>())
-          m = rewriter.create<AddIOp>(
-              loc, rewriter.create<MulIOp>(loc, b, op.rhs()), e);
-        else
-          m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
-      } else {
-        if (eltType.isa<IntegerType>())
-          m = rewriter.create<MulIOp>(loc, b, op.rhs());
-        else
-          m = rewriter.create<MulFOp>(loc, b, op.rhs());
-      }
+      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+      Value r = nullptr;
+      if (acc)
+        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+      Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
       result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
     }
     rewriter.replaceOp(op, result);
     return success();
   }
+
+private:
+  static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
+                       PatternRewriter &rewriter) {
+    if (acc) {
+      if (isInt)
+        return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
+                                       acc);
+      return rewriter.create<vector::FMAOp>(loc, x, y, acc);
+    }
+    if (isInt)
+      return rewriter.create<MulIOp>(loc, x, y);
+    return rewriter.create<MulFOp>(loc, x, y);
+  }
 };
 
 /// Progressive lowering of ConstantMaskOp.

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 84d596bf512f..916403800fe1 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -187,7 +187,7 @@ func @outerproduct_num_operands(%arg0: f32) {
 // -----
 
 func @outerproduct_non_vector_operand(%arg0: f32) {
-  // expected-error@+1 {{expected 2 vector types}}
+  // expected-error@+1 {{expected vector type for operand #1}}
   %1 = vector.outerproduct %arg0, %arg0 : f32, f32
 }
 
@@ -228,6 +228,27 @@ func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf
 
 // -----
 
+func @outerproduct_axpy_operand(%arg0: vector<4x8xf32>, %arg1: f32) {
+  // expected-error@+1 {{expected 1-d vector for operand #1}}
+  %1 = vector.outerproduct %arg0, %arg1 : vector<4x8xf32>, f32
+}
+
+// -----
+
+func @outerproduct_axpy_result_generic(%arg0: vector<4xf32>, %arg1: f32) {
+  // expected-error@+1 {{expected 1-d vector result}}
+  %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, f32) -> (vector<4x8xf32>)
+}
+
+// -----
+
+func @outerproduct_axpy_operand_dim_generic(%arg0: vector<8xf32>, %arg1: f32) {
+  // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+  %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<8xf32>, f32) -> (vector<16xf32>)
+}
+
+// -----
+
 func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
   // expected-error@+1 {{expected operand #3 of same type as result type}}
   %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index f6f215a50616..82faadf100e9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -326,6 +326,53 @@ func @outerproduct_acc_int(%arg0: vector<2xi32>,
   return %0: vector<2x3xi32>
 }
 
+// CHECK-LABEL: func @axpy_fp(
+// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
+// CHECK-SAME: %[[B:.*1]]: f32)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T1:.*]] = mulf %[[A]], %[[T0]] : vector<16xf32>
+// CHECK: return %[[T1]] : vector<16xf32>
+func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
+   %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
+   return %0: vector<16xf32>
+}
+
+// CHECK-LABEL: func @axpy_fp_add(
+// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
+// CHECK-SAME: %[[B:.*1]]: f32,
+// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
+// CHECK: return %[[T1]] : vector<16xf32>
+func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
+   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
+   return %0: vector<16xf32>
+}
+
+// CHECK-LABEL: func @axpy_int(
+// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
+// CHECK-SAME: %[[B:.*1]]: i32)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32>
+// CHECK: return %[[T1]] : vector<16xi32>
+func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
+   %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
+   return %0: vector<16xi32>
+}
+
+// CHECK-LABEL: func @axpy_int_add(
+// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
+// CHECK-SAME: %[[B:.*1]]: i32,
+// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
+// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32>
+// CHECK: %[[T2:.*]] = addi %[[T1]], %[[C]] : vector<16xi32>
+// CHECK: return %[[T2]] : vector<16xi32>
+func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
+   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
+   return %0: vector<16xi32>
+}
+
 // CHECK-LABEL: func @transpose23
 // CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
 // CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>


        


More information about the Mlir-commits mailing list