[Mlir-commits] [mlir] 365434a - [mlir] [VectorOps] Merge OUTER/AXPY vector.contract lowering into single case

llvmlistbot at llvm.org
Fri Jul 10 13:12:12 PDT 2020


Author: aartbik
Date: 2020-07-10T13:11:54-07:00
New Revision: 365434a584078577a7af6b91ffd2640c72c6d265

URL: https://github.com/llvm/llvm-project/commit/365434a584078577a7af6b91ffd2640c72c6d265
DIFF: https://github.com/llvm/llvm-project/commit/365434a584078577a7af6b91ffd2640c72c6d265.diff

LOG: [mlir] [VectorOps] Merge OUTER/AXPY vector.contract lowering into single case

We temporarily had separate OUTER lowering (for matmat flavors) and
AXPY lowering (for matvec flavors). With the new generalized
"vector.outerproduct" semantics, these cases can be merged into
a single lowering method. This refactoring will simplify future
decisions on cost models and lowering heuristics.
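
For context: the generalized "vector.outerproduct" accepts a scalar as its
right-hand operand, in which case it degenerates to an AXPY step, and this
is what lets the matvec flavors share the matmat lowering. A minimal sketch
of one unrolled reduction step, modeled on the 2x2 test cases updated below
(the %matT/%vec/%acc names are illustrative, not from the patch):

  %row  = vector.extract %matT[0] : vector<2x2xf32>  // row of the transposed matrix
  %s    = vector.extract %vec[0] : vector<2xf32>     // scalar element of the vector
  // With a scalar rhs this is an AXPY: %acc1 = %row * %s + %acc
  %acc1 = vector.outerproduct %row, %s, %acc : vector<2xf32>, f32

The same unrolled loop emits vector-valued rhs outer products for the
matmat flavors, so a single rewrite now covers both.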

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D83585

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 29c320903aec..0f6aa66e926c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,8 +53,6 @@ enum class VectorContractLowering {
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
-  /// Lower to series of AXPY chained through FMA.
-  AXPY = 3,
 };
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {

diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index ca67be9512f7..e95329c3e505 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -177,33 +177,6 @@ class ContractionOpToOuterProductOpLowering
   vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
-/// Progressive lowering of a `vector.contract %a, %b, %c` with
-/// matvec semantics to series of AXPY operations that are chained
-/// through FMA operations.
-///
-/// This only kicks in when VectorTransformsOptions is set to AXPY.
-//
-// TODO: this is very similar, but not quite the same as the outerproduct
-// lowering above; merge the two?
-class ContractionOpToAXPYLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  ContractionOpToAXPYLowering(
-      vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
-
-  LogicalResult match(vector::ContractionOp op) const override;
-  void rewrite(vector::ContractionOp op,
-               PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
-};
-
 /// Progressive lowering of ContractionOp.
 ///
 /// One:

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index aa5264ae0d33..2f77fd5ff60a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1638,7 +1638,7 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult
-ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
+ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
@@ -1647,30 +1647,46 @@ ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
-  // Transpose arguments to make them ready for lowering to OuterProduct. The
-  // constraint to match is that we must load full rows at a time with
-  // vector::ExtractOp.
+  // Determine if the parallel/reduction structure matches something
+  // that can be expressed as a reduction_size unrolled sequence.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(op.getContext(), m, n, k);
   auto iteratorTypes = op.iterator_types().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isParallelIterator(iteratorTypes[1]) ||
-      !isReductionIterator(iteratorTypes[2]))
-    return failure();
   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  // When lowering to outerproduct we can support all permutations.
-  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
-      maps != infer({{m, k}, {n, k}, {m, n}}) &&
-      maps != infer({{k, m}, {k, n}, {m, n}}) &&
-      maps != infer({{k, m}, {n, k}, {m, n}}) &&
-      maps != infer({{m, k}, {k, n}, {n, m}}) &&
-      maps != infer({{m, k}, {n, k}, {n, m}}) &&
-      maps != infer({{k, m}, {k, n}, {n, m}}) &&
-      maps != infer({{k, m}, {n, k}, {n, m}}))
-    return failure();
-  return success();
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    // When lowering to outerproduct we can support all permutations.
+    //
+    if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
+        maps != infer({{m, k}, {n, k}, {m, n}}) &&
+        maps != infer({{k, m}, {k, n}, {m, n}}) &&
+        maps != infer({{k, m}, {n, k}, {m, n}}) &&
+        maps != infer({{m, k}, {k, n}, {n, m}}) &&
+        maps != infer({{m, k}, {n, k}, {n, m}}) &&
+        maps != infer({{k, m}, {k, n}, {n, m}}) &&
+        maps != infer({{k, m}, {n, k}, {n, m}}))
+      return failure();
+    return success();
+  } else if (isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor).
+    // See if a series of AXPY operations chained through FMA operations
+    // could replace the default DOT implementation.
+    //
+    if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
+        maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
+        maps != infer({{n}, {m, n}, {m}}) && // vec-mat
+        maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
+      return failure();
+    return success();
+  }
+  return failure();
 }
 
 void ContractionOpToOuterProductOpLowering::rewrite(
@@ -1680,61 +1696,87 @@ void ContractionOpToOuterProductOpLowering::rewrite(
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
-  // Transpose arguments to make them ready for lowering to OuterProduct. The
-  // constraint to match is that we must load full rows at a time with
-  // vector::ExtractOp.
+  // Set up the parallel/reduction structure in the right form.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
   SmallVector<int64_t, 2> perm{1, 0};
+  auto iteratorTypes = op.iterator_types().getValue();
   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  // First batch of cases, no need to output permute.
-  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-    // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    // No need to permute anything.
-    reductionSize = lhsType.getDimSize(0);
-  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-    // Just permute the rhs.
-    reductionSize = lhsType.getDimSize(0);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  }
-  // Second batch of cases, reshuffle to avoid output permute.
-  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-    // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getDimSize(1);
-    Value tmp = rhs;
-    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    lhs = tmp;
-  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    reductionSize = lhsType.getDimSize(1);
-    Value tmp = rhs;
-    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-    // No need to permute anything, but still swap lhs and rhs.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-    // Just permute the rhs.
-    reductionSize = lhsType.getDimSize(0);
-    Value tmp = lhs;
-    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    rhs = tmp;
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      // No need to permute anything.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      // No need to permute anything, but still swap lhs and rhs.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    }
+  } else {
+    //
+    // One outer parallel, one inner reduction (matvec flavor).
+    //
+    assert(isParallelIterator(iteratorTypes[0]) &&
+           isReductionIterator(iteratorTypes[1]));
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // Case mat-vec: transpose.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      // Case mat-trans-vec: ready to go.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      // Case vec-mat: swap and transpose.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      // Case vec-mat-trans: swap and ready to go.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    }
   }
   assert(reductionSize > 0);
 
-  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
-  for (unsigned k = 0; k < reductionSize; ++k) {
+  // Unroll outer-products along reduction.
+  for (int64_t k = 0; k < reductionSize; ++k) {
     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
     Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
     res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
@@ -1742,88 +1784,6 @@ void ContractionOpToOuterProductOpLowering::rewrite(
   rewriter.replaceOp(op, res);
 }
 
-/// Progressive lowering of a `vector.contract %a, %b, %c` with
-/// matvec semantics to series of AXPY operations that are chained
-/// through FMA operations.
-///
-/// This only kicks in when VectorTransformsOptions is set to AXPY.
-//
-// TODO: this is very similar, but not quite the same as the outerproduct
-// lowering above; merge the two?
-LogicalResult
-ContractionOpToAXPYLowering::match(vector::ContractionOp op) const {
-  // TODO: implement masks
-  if (llvm::size(op.masks()) != 0)
-    return failure();
-
-  if (vectorTransformsOptions.vectorContractLowering !=
-      vector::VectorContractLowering::AXPY)
-    return failure();
-
-  auto iteratorTypes = op.iterator_types().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isReductionIterator(iteratorTypes[1]))
-    return failure();
-
-  // See if a series of AXPY operations chained through FMA operations
-  // could replace the default DOT implementation.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n;
-  bindDims(op.getContext(), m, n);
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
-      maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
-      maps != infer({{n}, {m, n}, {m}}) && // vec-mat
-      maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
-    return failure();
-  return success();
-}
-
-void ContractionOpToAXPYLowering::rewrite(vector::ContractionOp op,
-                                          PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  VectorType lhsType = op.getLhsType();
-  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
-
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n;
-  bindDims(op.getContext(), m, n);
-  SmallVector<int64_t, 2> perm{1, 0};
-  //
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  int64_t reductionSize = 0;
-  if (maps == infer({{m, n}, {n}, {m}})) {
-    // Case mat-vec: transpose.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{n, m}, {n}, {m}})) {
-    // Case mat-trans-vec: ready to go.
-    reductionSize = lhsType.getDimSize(0);
-  } else if (maps == infer({{n}, {m, n}, {m}})) {
-    // Case vec-mat: swap and transpose.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{n}, {n, m}, {m}})) {
-    // Case vec-mat-trans: swap and ready to go.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-  }
-  assert(reductionSize > 0);
-
-  // A direct series of AXPY operations, chained through FMA.
-  Type resType = op.getResultType();
-  for (int64_t k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
-    Value s = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-    Value b = rewriter.create<vector::BroadcastOp>(loc, resType, s);
-    res = rewriter.create<vector::FMAOp>(loc, a, b, res);
-  }
-  rewriter.replaceOp(op, res);
-}
-
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1862,9 +1822,6 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
   if (succeeded(pat2.match(op)))
     return failure();
-  ContractionOpToAXPYLowering pat3(vectorTransformsOptions, ctx);
-  if (succeeded(pat3.match(op)))
-    return failure();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
@@ -2050,7 +2007,6 @@ void mlir::vector::populateVectorContractLoweringPatterns(
   patterns.insert<TransposeOpLowering,
                   ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering,
-                  ContractionOpToAXPYLowering>(parameters, context);
+                  ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
 }

diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 81c70983cded..3e2896c82bbe 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
 
 #matvec_accesses = [
   affine_map<(i, j) -> (i, j)>,
@@ -44,27 +44,18 @@
 // CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
 // CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
 // CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
 // CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
-// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
-// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                                                 %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -84,13 +75,12 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
-// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                                                      %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -105,27 +95,18 @@ func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>
 // CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
 // CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
 // CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
 // CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
-// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
-// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                                                 %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -145,13 +126,12 @@ func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
-// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                                                      %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>

diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 1af6c3564b80..7e28ebbd9b72 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,9 +59,6 @@ struct TestVectorContractionConversion
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
-  Option<bool> lowerToAXPY{*this, "vector-axpy",
-                           llvm::cl::desc("Lower vector.contract to AXPY"),
-                           llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
@@ -80,8 +77,6 @@ struct TestVectorContractionConversion
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)
       contractLowering = VectorContractLowering::Matmul;
-    else if (lowerToAXPY)
-      contractLowering = VectorContractLowering::AXPY;
     VectorTransposeLowering transposeLowering =
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)


        

