[Mlir-commits] [mlir] 3f906c5 - [mlir][Vector] Add 2-D vector contract lowering to ReduceOp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 7 03:21:26 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-07T06:17:48-04:00
New Revision: 3f906c54a2def26ba0c407a309c7c30ce0e0cc83

URL: https://github.com/llvm/llvm-project/commit/3f906c54a2def26ba0c407a309c7c30ce0e0cc83
DIFF: https://github.com/llvm/llvm-project/commit/3f906c54a2def26ba0c407a309c7c30ce0e0cc83.diff

LOG: [mlir][Vector] Add 2-D vector contract lowering to ReduceOp

This new pattern mixes vector.transpose and direct lowering to vector.reduction.
This allows more progressive lowering than immediately going to insert/extract and
composes more nicely with other canonicalizations.
This has 2 use cases:
1. for very wide vectors the generated IR may be much smaller
2. when we have a custom lowering for transpose ops we can target it directly
rather than relying on LLVM

Differential Revision: https://reviews.llvm.org/D85428

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 35855b3b2137..9587c56c0255 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -258,6 +258,59 @@ class ContractionOpToOuterProductOpLowering
   FilterConstraintType filter;
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to an output-size-unrolled sequence:
+/// ```
+///    %out = constant ... : vector<MxNxelt_type>
+///    %bt = vector.transpose %b, [1, 0]
+///    %aRow0 = vector.extract %a[0]
+///    %btRow0 = vector.extract %bt[0]
+///    %c00 = vector.reduction "add", (mulf %aRow0, %btRow0)
+///    %out00 = vector.insert %c00, %out[0, 0]
+///    ...
+///    %aRowLast = vector.extract %a[M-1]
+///    %btRowLast = vector.extract %bt[N-1]
+///    %cLastLast = vector.reduction "add", (mulf %aRowLast, %btRowLast)
+///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+///    %res = addf %outLastLast, %c
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot and
+/// the vector.contract op is a row-major matmul or matvec. Contractions whose
+/// operands or result are transposed are first rewritten into this form with
+/// `vector.transpose` and/or by swapping the operands.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  /// Predicate that must succeed on an op for the pattern to apply.
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Filter that accepts every op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  /// `constraint` is stored and consulted by matchAndRewrite; ops it rejects
+  /// are left untouched.
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions),
+        filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+  /// User-supplied predicate restricting which ops this pattern rewrites.
+  FilterConstraintType filter;
+};
+
 /// Progressive lowering of ContractionOp.
 ///
 /// One:

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 218715318a86..b60b6b39a2b7 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -137,8 +137,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodName=*/"getVectorType",
       /*args=*/(ins),
       /*methodBody=*/"",
-      /*defaultImplementation=*/
-        "return $_op.vector().getType().template cast<VectorType>();"
+      /*defaultImplementation=*/[{
+        return $_op.vector().getType().template dyn_cast<VectorType>();
+        }]
     >,
     InterfaceMethod<
       /*desc=*/[{ Return the number of dimensions that participate in the

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 922168947ccf..16d10e558b5e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1754,6 +1754,148 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   return success();
 }
 
+/// Lowers a `vector.contract` with matmat (parallel, parallel, reduction) or
+/// matvec (parallel, reduction) iterators to per-output-element `mulf` +
+/// `vector.reduction "add"`, followed by a single `addf` of the accumulator.
+/// Operands are first transposed and/or swapped so that the reduction
+/// dimension is innermost in both the new lhs and rhs. Returns failure
+/// without touching the IR when the op has masks, is rejected by `filter`,
+/// the lowering strategy is not `Dot`, the element types are not one common
+/// float type, or the indexing maps are not one of the supported permutations.
+LogicalResult
+ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) const {
+  // TODO: implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Dot)
+    return failure();
+
+  // The unrolled code below emits `mulf`, `vector.reduction "add"` and `addf`
+  // on the result element type, so all operands must share one float element
+  // type. Check this before creating any op: a pattern must not modify the IR
+  // and then report failure.
+  VectorType dstType = op.getResultType().dyn_cast<VectorType>();
+  if (!dstType)
+    return failure();
+  Type elementType = dstType.getElementType();
+  if (!elementType.isa<FloatType>() ||
+      op.getLhsType().getElementType() != elementType ||
+      op.getRhsType().getElementType() != elementType)
+    return failure();
+
+  auto iteratorTypes = op.iterator_types().getValue();
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  Location loc = op.getLoc();
+  Value lhs = op.lhs(), rhs = op.rhs();
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(rewriter.getContext(), m, n, k);
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  //
+  // In the following we wish to make the reduction dimension innermost so we
+  // can load vectors and just fmul + reduce into a scalar: afterwards, result
+  // element (r, c) is the dot product of row r of `lhs` with row c of `rhs`
+  // (or with `rhs` itself in the matvec case). The iterator count is checked
+  // first so that `iteratorTypes` is never indexed out of bounds.
+  //
+  if (iteratorTypes.size() == 3 && isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // Transposed (n, m) output: the new lhs is B^T (n x k) and the new rhs
+      // is A (m x k).
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else {
+      return failure();
+    }
+  } else if (iteratorTypes.size() == 2 &&
+             isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor)
+    //
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else {
+      return failure();
+    }
+  } else {
+    return failure();
+  }
+
+  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
+         "Expected dst type of rank 1 or 2");
+
+  unsigned rank = dstType.getRank();
+  unsigned dstRows = dstType.getShape()[0];
+  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
+
+  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+  Value res =
+      rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
+  for (unsigned r = 0; r < dstRows; ++r) {
+    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, r);
+    for (unsigned c = 0; c < dstColumns; ++c) {
+      // In the matvec case `rhs` is already the 1-D vector to dot with.
+      Value b = rank == 1
+                    ? rhs
+                    : rewriter.create<vector::ExtractOp>(loc, rhs, c);
+      Value prod = rewriter.create<MulFOp>(loc, a, b);
+      Value reduced = rewriter.create<vector::ReductionOp>(
+          loc, elementType, rewriter.getStringAttr("add"), prod, ValueRange{});
+
+      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
+                                              : SmallVector<int64_t, 2>{r, c};
+      res = rewriter.create<vector::InsertOp>(loc, reduced, res, pos);
+    }
+  }
+  // The reductions start from zero; the accumulator is folded in once with a
+  // single vector-wide add.
+  if (auto acc = op.acc())
+    res = rewriter.create<AddFOp>(loc, res, acc);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1795,6 +1937,9 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
     return success();
+  ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
+  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
+    return success();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 6dae907b8bb0..e34e3428c185 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -43,16 +43,15 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
 // CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
-// CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
 // CHECK:      %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32>
-// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
+// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
 // CHECK:      %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
-// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
+// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:      return %[[T9]] : vector<2xf32>
+// CHECK:      %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32>
+// CHECK:      return %[[T10]] : vector<2xf32>
 
 func @extract_contract2(%arg0: vector<2x3xf32>,
                         %arg1: vector<3xf32>,
@@ -78,16 +77,15 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
 // CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
-// CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
-// CHECK:      %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32>
-// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
+// CHECK:      %[[T2:.*]] = mulf %[[T0]], %[[A]] : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
-// CHECK:      %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32>
-// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
+// CHECK:      %[[T7:.*]] = mulf %[[T5]], %[[A]] : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:      return %[[T9]] : vector<2xf32>
+// CHECK:      %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32>
+// CHECK:      return %[[T10]] : vector<2xf32>
 
 func @extract_contract3(%arg0: vector<3xf32>,
                         %arg1: vector<2x3xf32>,
@@ -112,47 +110,31 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
 // CHECK:    %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
-// CHECK:    %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// ... a series of vector.extract/vector.insert ops that transpose B into Bt
+// CHECK:    %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
 // CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
-// CHECK:    %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
-// CHECK:    %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
-// CHECK:    %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32>
-// CHECK:    %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
-// CHECK:    %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
-// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
+// CHECK:    %[[T9:.*]] = mulf %[[T0]], %[[T2]] : vector<2xf32>
+// CHECK:    %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xf32> into f32
+// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
 //
-// CHECK:    %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
-// CHECK:    %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
-// CHECK:    %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32>
-// CHECK:    %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
-// CHECK:    %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
-// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK:    %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
+// CHECK:    %[[T19:.*]] = mulf %[[T0]], %[[T12]] : vector<2xf32>
+// CHECK:    %[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xf32> into f32
+// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
 //
 // CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
-// CHECK:    %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
-// CHECK:    %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
-// CHECK:    %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32>
-// CHECK:    %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32>
-// CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32
-// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
+// CHECK:    %[[T32:.*]] = mulf %[[T23]], %[[T24]] : vector<2xf32>
+// CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xf32> into f32
+// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
 //
-// CHECK:    %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
-// CHECK:    %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
-// CHECK:    %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32>
-// CHECK:    %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32>
-// CHECK:    %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32
+// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
+// CHECK:    %[[T41:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
+// CHECK:    %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xf32> into f32
+// CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
 //
-// CHECK:    %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK:    %[[T52:.*]] = addf %[[T43]], %[[C]] : vector<2x2xf32>
 // CHECK:    return %[[T52]] : vector<2x2xf32>
 
 func @extract_contract4(%arg0: vector<2x2xf32>,
@@ -574,6 +556,31 @@ func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
 //
 //      OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>
+
+// REDUCE-LABEL: func @matmul
+// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//
+//      REDUCE: %[[RES:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+//      REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// REDUCE-SAME:  : vector<4x3xf32> to vector<3x4xf32>
+//
+//      REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
+// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32>
+// REDUCE-NEXT: %[[ab00:.*]] = mulf %[[a0]], %[[b0]] : vector<4xf32>
+// REDUCE-NEXT: %[[s00:.*]] = vector.reduction "add", %[[ab00]] : vector<4xf32> into f32
+// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32>
+//
+//      ...
+//
+//      REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32>
+// REDUCE-NEXT: %[[ab12:.*]] = mulf %[[a1]], %[[b2]] : vector<4xf32>
+// REDUCE-NEXT: %[[s12:.*]] = vector.reduction "add", %[[ab12]] : vector<4xf32> into f32
+// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32>
+//      REDUCE: %[[res:.*]] = addf %[[r12]], %[[C]] : vector<2x3xf32>
+//      REDUCE: return %[[res]] : vector<2x3xf32>
 func @matmul(%arg0: vector<2x4xf32>,
                           %arg1: vector<4x3xf32>,
                           %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
@@ -1056,7 +1063,3 @@ func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg
     : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
-
-
-
-


        


More information about the Mlir-commits mailing list