[Mlir-commits] [mlir] 078776a - [mlir] [VectorOps] Progressively lower vector.outerproduct to LLVM

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 13:45:53 PDT 2020


Author: aartbik
Date: 2020-03-12T13:45:42-07:00
New Revision: 078776a679b94c1fc970febe3c72f0b337af9a97

URL: https://github.com/llvm/llvm-project/commit/078776a679b94c1fc970febe3c72f0b337af9a97
DIFF: https://github.com/llvm/llvm-project/commit/078776a679b94c1fc970febe3c72f0b337af9a97.diff

LOG: [mlir] [VectorOps] Progressively lower vector.outerproduct to LLVM

Summary:
This replaces the direct lowering of vector.outerproduct to LLVM with progressive lowering into elementary vector ops, to avoid duplicating similar lowering logic in several places.

NOTE1: with the new progressive rule, the lowered LLVM IR is slightly more elaborate than with the direct lowering, but the generated assembly is just as optimized; still, if we want to stay closer to the original, we should add a "broadcast on extract" to shuffle rewrite (rather than special-casing all the lowering steps).

NOTE2: the original outerproduct lowering code should now be removed, but some linalg tests work directly on vectors and contain some dead code, so this requires another CL.

Reviewers: nicolasvasilache, andydavis1

Reviewed By: nicolasvasilache, andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75956

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/VectorOps/VectorOps.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index 85ea7b9f1b83..e32752fe6030 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -55,9 +55,13 @@ void populateVectorToVectorTransformationPatterns(
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
-/// Collect a set of vector contraction transformation patterns
-/// that express all vector.contract ops in terms of more elementary
-/// extraction and reduction ops.
+/// Collect a set of transformation patterns that are related to contracting
+/// or expanding vector operations:
+///   ContractionOpLowering,
+///   ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern,
+///   OuterProductOpLowering
+/// These transformations express higher level vector ops in terms of more
+/// elementary extraction, insertion, reduction, product, and broadcast ops.
 void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context);
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d2167c52a2d2..a41a9d257417 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -817,6 +817,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   }
 };
 
+// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
 class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorOuterProductOpConversion(MLIRContext *context,
@@ -1176,7 +1177,7 @@ struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
 } // namespace
 
 void LowerVectorToLLVMPass::runOnModule() {
-  // Perform progressive lowering of operations on "slices" and
+  // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;

diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index fba0f00d0679..b164df276c3b 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -864,6 +864,53 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
   }
 };
 
+/// Progressive lowering of OuterProductOp, one result row at a time.
+/// One:
+///   %x = vector.outerproduct %lhs, %rhs, %acc
+/// is replaced by:
+///   %z = splat of zero (element type of the result)
+///   %0 = vector.extract %lhs[0]
+///   %1 = vector.broadcast %0
+///   %2 = vector.extract %acc[0]
+///   %3 = vector.fma %1, %rhs, %2   (mulf %1, %rhs when %acc is absent)
+///   %4 = vector.insert %3, %z[0]
+///   ...
+///   %x = vector.insert %.., %..[N-1]
+/// where N is the leading dimension of the result vector type.
+class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
+public:
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::OuterProductOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType rhsType = op.getOperandVectorTypeRHS();
+    VectorType resType = op.getVectorType();
+    Type eltType = resType.getElementType();
+    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
+
+    Value zero = rewriter.create<ConstantOp>(loc, eltType,
+                                             rewriter.getZeroAttr(eltType));
+    Value result = rewriter.create<SplatOp>(loc, resType, zero);
+    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
+      auto pos = rewriter.getI64ArrayAttr(d);
+      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
+      Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+      Value m;
+      if (acc) {
+        Value z = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+        m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), z);
+      } else {
+        m = rewriter.create<MulFOp>(loc, b, op.rhs());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
+    }
+    rewriter.replaceOp(op, result);
+    return matchSuccess();
+  }
+};
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1256,9 +1303,7 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
 
 void mlir::vector::populateVectorContractLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ContractionOpLowering,
-                  // Shape 2d up/down casts are used as part of contraction
-                  // lowering.
-                  ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern>(context);
+  patterns.insert<ContractionOpLowering, ShapeCastOp2DDownCastRewritePattern,
+                  ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
+      context);
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f70fb0cac6da..0cc6789a0619 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -204,39 +204,64 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
   %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
   return %2 : vector<2x3xf32>
 }
-// CHECK-LABEL: llvm.func @outerproduct
-//       CHECK:   llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
-//       CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
-//       CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
+// CHECK-LABEL: llvm.func @outerproduct(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">)
+//      CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x3xf32>)
+//      CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+//      CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>">
+//      CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+//      CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+//      CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>">
+//      CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+//      CHECK: %[[T7:.*]] = llvm.fmul %[[T6]], %[[B]] : !llvm<"<3 x float>">
+//      CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
+//      CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
+//      CHECK: %[[T10:.*]] = llvm.extractelement %[[A]][%[[T9]] : !llvm.i64] : !llvm<"<2 x float>">
+//      CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+//      CHECK: %[[T12:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+//      CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%[[T12]] : !llvm.i64] : !llvm<"<3 x float>">
+//      CHECK: %[[T14:.*]] = llvm.shufflevector %[[T13]], %[[T11]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+//      CHECK: %[[T15:.*]] = llvm.fmul %[[T14]], %[[B]] : !llvm<"<3 x float>">
+//      CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[2 x <3 x float>]">
+//      CHECK: llvm.return %[[T16]] : !llvm<"[2 x <3 x float>]">
 
 func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
   return %2 : vector<2x3xf32>
 }
-// CHECK-LABEL: llvm.func @outerproduct_add
-//       CHECK:   llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
-//       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
-//       CHECK:   llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
+// CHECK-LABEL: llvm.func @outerproduct_add(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">,
+// CHECK-SAME: %[[C:.*]]: !llvm<"[2 x <3 x float>]">)
+//      CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x3xf32>)
+//      CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+//      CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>">
+//      CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+//      CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+//      CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>">
+//      CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+//      CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
+//      CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
+//      CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
+//      CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
+//      CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
+//      CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+//      CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+//      CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i64] : !llvm<"<3 x float>">
+//      CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+//      CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
+//      CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
+//      CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]">
+//      CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]">
 
 func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> {
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xf32>
   return %1 : vector<2xf32>
 }
-// CHECK-LABEL: llvm.func @shuffle_1D_direct
+// CHECK-LABEL: llvm.func @shuffle_1D_direct(
 // CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
-// CHECK-SAME: %[[B:.*]]: !llvm<"<2 x float>">
+// CHECK-SAME: %[[B:.*]]: !llvm<"<2 x float>">)
 //       CHECK:   %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
 //       CHECK:   llvm.return %[[s]] : !llvm<"<2 x float>">
 
@@ -244,9 +269,9 @@ func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
   %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
   return %1 : vector<5xf32>
 }
-// CHECK-LABEL: llvm.func @shuffle_1D
+// CHECK-LABEL: llvm.func @shuffle_1D(
 // CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
-// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">
+// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">)
 //       CHECK:   %[[u0:.*]] = llvm.mlir.undef : !llvm<"<5 x float>">
 //       CHECK:   %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
 //       CHECK:   %[[e1:.*]] = llvm.extractelement %[[B]][%[[c2]] : !llvm.i64] : !llvm<"<3 x float>">
@@ -274,9 +299,9 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
   %1 = vector.shuffle %a, %b[1, 0, 2] : vector<1x4xf32>, vector<2x4xf32>
   return %1 : vector<3x4xf32>
 }
-// CHECK-LABEL: llvm.func @shuffle_2D
+// CHECK-LABEL: llvm.func @shuffle_2D(
 // CHECK-SAME: %[[A:.*]]: !llvm<"[1 x <4 x float>]">,
-// CHECK-SAME: %[[B:.*]]: !llvm<"[2 x <4 x float>]">
+// CHECK-SAME: %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
 //       CHECK:   %[[u0:.*]] = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
 //       CHECK:   %[[e1:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
 //       CHECK:   %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm<"[3 x <4 x float>]">
@@ -291,8 +316,8 @@ func @extract_element(%arg0: vector<16xf32>) -> f32 {
   %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
   return %1 : f32
 }
-// CHECK-LABEL: llvm.func @extract_element
-// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
+// CHECK-LABEL: llvm.func @extract_element(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
 //       CHECK:   %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : !llvm.i32] : !llvm<"<16 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm.float
@@ -337,9 +362,9 @@ func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
   return %1 : vector<4xf32>
 }
-// CHECK-LABEL: llvm.func @insert_element
+// CHECK-LABEL: llvm.func @insert_element(
 // CHECK-SAME: %[[A:.*]]: !llvm.float,
-// CHECK-SAME: %[[B:.*]]: !llvm<"<4 x float>">
+// CHECK-SAME: %[[B:.*]]: !llvm<"<4 x float>">)
 //       CHECK:   %[[c:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
 //       CHECK:   %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : !llvm.i32] : !llvm<"<4 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm<"<4 x float>">
@@ -399,8 +424,8 @@ func @vector_print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar_i32
-// CHECK-SAME: %[[A:.*]]: !llvm.i32
+// CHECK-LABEL: llvm.func @vector_print_scalar_i32(
+// CHECK-SAME: %[[A:.*]]: !llvm.i32)
 //       CHECK:    llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
@@ -408,8 +433,8 @@ func @vector_print_scalar_i64(%arg0: i64) {
   vector.print %arg0 : i64
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar_i64
-// CHECK-SAME: %[[A:.*]]: !llvm.i64
+// CHECK-LABEL: llvm.func @vector_print_scalar_i64(
+// CHECK-SAME: %[[A:.*]]: !llvm.i64)
 //       CHECK:    llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
@@ -417,8 +442,8 @@ func @vector_print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar_f32
-// CHECK-SAME: %[[A:.*]]: !llvm.float
+// CHECK-LABEL: llvm.func @vector_print_scalar_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm.float)
 //       CHECK:    llvm.call @print_f32(%[[A]]) : (!llvm.float) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
@@ -426,8 +451,8 @@ func @vector_print_scalar_f64(%arg0: f64) {
   vector.print %arg0 : f64
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar_f64
-// CHECK-SAME: %[[A:.*]]: !llvm.double
+// CHECK-LABEL: llvm.func @vector_print_scalar_f64(
+// CHECK-SAME: %[[A:.*]]: !llvm.double)
 //       CHECK:    llvm.call @print_f64(%[[A]]) : (!llvm.double) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
@@ -435,8 +460,8 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_vector
-// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <2 x float>]">
+// CHECK-LABEL: llvm.func @vector_print_vector(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <2 x float>]">)
 //       CHECK:    llvm.call @print_open() : () -> ()
 //       CHECK:    %[[x0:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <2 x float>]">
 //       CHECK:    llvm.call @print_open() : () -> ()
@@ -575,9 +600,9 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
         vector<2x4xf32> into vector<16x4x8xf32>
   return %0 : vector<16x4x8xf32>
 }
-// CHECK-LABEL: llvm.func @insert_strided_slice3
+// CHECK-LABEL: llvm.func @insert_strided_slice3(
 // CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <4 x float>]">,
-// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">
+// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">)
 //      CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]">
 //      CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]">
@@ -626,8 +651,8 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
   %1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
   return %1 : vector<1x1xf32>
 }
-// CHECK-LABEL: llvm.func @extract_strides
-// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <3 x float>]">
+// CHECK-LABEL: llvm.func @extract_strides(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <3 x float>]">)
 //      CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
 //      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[3 x <3 x float>]">
 //      CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
@@ -667,8 +692,8 @@ func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
   %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
   return %0 : f32
 }
-// CHECK-LABEL: llvm.func @reduce_f32
-// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
+// CHECK-LABEL: llvm.func @reduce_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
 //      CHECK: llvm.return %[[V]] : !llvm.float
@@ -677,8 +702,8 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
   %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64
   return %0 : f64
 }
-// CHECK-LABEL: llvm.func @reduce_f64
-// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">
+// CHECK-LABEL: llvm.func @reduce_f64(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
 //      CHECK: llvm.return %[[V]] : !llvm.double
@@ -687,8 +712,8 @@ func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
   %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
   return %0 : i32
 }
-// CHECK-LABEL: llvm.func @reduce_i32
-// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">
+// CHECK-LABEL: llvm.func @reduce_i32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">)
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
 //      CHECK: llvm.return %[[V]] : !llvm.i32
 
@@ -696,8 +721,8 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
   %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64
   return %0 : i64
 }
-// CHECK-LABEL: llvm.func @reduce_i64
-// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">
+// CHECK-LABEL: llvm.func @reduce_i64(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">)
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
 //      CHECK: llvm.return %[[V]] : !llvm.i64
 

diff  --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index c5e40a7c18ca..8d6b0d8d9ccd 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -251,6 +251,50 @@ func @full_contract2(%arg0: vector<2x3xf32>,
   return %0 : f32
 }
 
+// CHECK-LABEL: func @outerproduct_noacc
+// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
+// CHECK:      %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
+// CHECK:      %[[T2:.*]] = mulf %[[T1]], %[[B]] : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
+// CHECK:      %[[T6:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK:      return %[[T7]] : vector<2x3xf32>
+
+func @outerproduct_noacc(%arg0: vector<2xf32>,
+                         %arg1: vector<3xf32>) -> vector<2x3xf32> {
+  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
+  return %0: vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @outerproduct_acc
+// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
+// CHECK:      %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
+// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
+// CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
+// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
+// CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK:      return %[[T9]] : vector<2x3xf32>
+
+func @outerproduct_acc(%arg0: vector<2xf32>,
+                       %arg1: vector<3xf32>,
+                       %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
+  return %0: vector<2x3xf32>
+}
+
 // Shape up and downcasts for 2-D vectors, for supporting conversion to
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts


        


More information about the Mlir-commits mailing list