[Mlir-commits] [mlir] ba10daa - [mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.

Nicolas Vasilache llvmlistbot at llvm.org
Tue May 26 12:45:16 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-26T15:40:55-04:00
New Revision: ba10daa820fa868816eed2b85e70197d354ebfe6

URL: https://github.com/llvm/llvm-project/commit/ba10daa820fa868816eed2b85e70197d354ebfe6
DIFF: https://github.com/llvm/llvm-project/commit/ba10daa820fa868816eed2b85e70197d354ebfe6.diff

LOG: [mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.

This revision expands the types of vector contractions that can be lowered to vector.outerproduct.
All 8 permutation cases are supported.
The idiomatic, declarative manipulation of AffineMap makes this straightforward; a short illustrative sketch follows.
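For illustration only, here is a minimal C++ sketch of the declarative layout matching used by the outerproduct lowering in this patch. The helper name and the exact includes are assumptions for a self-contained example, not part of the patch; the AffineMap calls mirror those in VectorTransforms.cpp below.

  #include "mlir/Dialect/Vector/VectorOps.h"
  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"

  using namespace mlir;

  // Hypothetical helper: returns true if `op` uses one of the eight (m, n, k)
  // layout permutations handled by the contract -> outerproduct lowering.
  static bool hasSupportedMatmulLayout(vector::ContractionOp op) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList l) { return AffineMap::inferFromExprList(l); };
    AffineExpr m, n, k;
    bindDims(op.getContext(), m, n, k);
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    // Enumerate the eight lhs/rhs/result layout permutations declaratively.
    return maps == infer({{m, k}, {k, n}, {m, n}}) ||
           maps == infer({{m, k}, {n, k}, {m, n}}) ||
           maps == infer({{k, m}, {k, n}, {m, n}}) ||
           maps == infer({{k, m}, {n, k}, {m, n}}) ||
           maps == infer({{m, k}, {k, n}, {n, m}}) ||
           maps == infer({{m, k}, {n, k}, {n, m}}) ||
           maps == infer({{k, m}, {k, n}, {n, m}}) ||
           maps == infer({{k, m}, {n, k}, {n, m}});
  }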

In the process, a bug in the vector.contract verifier was uncovered.
The vector shape verification part of the contract op is rewritten to use AffineMap composition
(a small sketch of the idea follows below).
One bug in the vector `ops.mlir` test is fixed, and a new, previously uncaptured case is added
to the vector `invalid.mlir` test.
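For illustration only, a minimal sketch of the composition-based shape inference now performed by the verifier. The helper name and includes are assumptions for a self-contained example; the AffineMap::get / compose / simplifyAffineMap calls match those in VectorOps.cpp below.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "llvm/ADT/SmallVector.h"

  using namespace mlir;

  // Hypothetical helper: given the result indexing map and one constant extent
  // expression per loop dimension (collected from the lhs/rhs vector shapes
  // through their indexing maps), compute the expected result shape.
  static SmallVector<int64_t, 4> inferExpectedShape(AffineMap resMap,
                                                    ArrayRef<AffineExpr> extents,
                                                    MLIRContext *ctx) {
    // Constant map (d0, d1, ...) -> (ext0, ext1, ...).
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Composing the result map with the constant map substitutes each loop
    // dimension with its extent, leaving one constant per result dimension.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    SmallVector<int64_t, 4> shape;
    for (AffineExpr e : expectedMap.getResults())
      shape.push_back(e.cast<AffineConstantExpr>().getValue());
    return shape;
  }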

Differential Revision: https://reviews.llvm.org/D80393

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 5a36aabfab75..02d276256076 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -81,12 +81,24 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
 
 /// Use to encode that a particular iterator type has parallel semantics.
 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
+constexpr bool isParallelIterator(Attribute attr) {
+  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+}
 
 /// Use to encode that a particular iterator type has reduction semantics.
 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
+constexpr bool isReductionIterator(Attribute attr) {
+  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+}
 
 /// Use to encode that a particular iterator type has window semantics.
 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
+constexpr bool isWindowIterator(Attribute attr) {
+  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+  return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
+}
 
 /// Use to encode that a particular iterator type has window semantics.
 inline ArrayRef<StringRef> getAllIteratorTypeNames() {

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1574edb34494..63891d1004d4 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -223,8 +223,9 @@ static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
   return true;
 }
 
-static bool verifyOutputShape(
-    VectorType lhsType, VectorType rhsType, Type accType, Type resType,
+static LogicalResult verifyOutputShape(
+    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
+    Type resType,
     const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
     const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
   DenseSet<int64_t> lhsContractingDimSet;
@@ -256,26 +257,56 @@ static bool verifyOutputShape(
   if (expectedResultDims.size() == 0) {
     // No batch or free dimension implies a scalar result.
     if (resType.isa<VectorType>() || accType.isa<VectorType>())
-      return false;
-
+      return op.emitOpError("invalid accumulator/result vector shape");
   } else {
     // At least one batch or free dimension implies a vector result.
     auto resVectorType = resType.dyn_cast<VectorType>();
     auto accVectorType = accType.dyn_cast<VectorType>();
     if (!resVectorType || !accVectorType)
-      return false;
-
-    // Verify dimension from 'resType' against 'expectedResultDims'.
-    if (resVectorType.getShape().size() != expectedResultDims.size() ||
-        accVectorType.getShape().size() != expectedResultDims.size())
-      return false;
-    for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) {
-      if (resVectorType.getDimSize(i) != expectedResultDims[i] ||
-          accVectorType.getDimSize(i) != expectedResultDims[i])
-        return false;
+      return op.emitOpError("invalid accumulator/result vector shape");
+
+    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
+    // types fully define the result vector type. This assumes the affine maps
+    // are well-formed, which must have been verified already.
+    MLIRContext *ctx = op.getContext();
+    AffineMap lhsMap = op.getIndexingMaps()[0];
+    AffineMap rhsMap = op.getIndexingMaps()[1];
+    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
+    for (auto pair :
+         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
+      VectorType v = pair.first;
+      auto map = pair.second;
+      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
+        unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
+        if (!extents[pos])
+          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
+      }
     }
+    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
+           "expected extent along all dimensions.");
+
+    AffineMap resMap = op.getIndexingMaps()[2];
+    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
+                                     /*symCount=*/0, extents, ctx);
+    // Compose the resMap with the extentsMap, which is a constant map.
+    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
+    assert(llvm::all_of(
+               expectedMap.getResults(),
+               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
+           "expected constant extent along all dimensions.");
+    // Extract the expected shape and build the type.
+    auto expectedShape = llvm::to_vector<4>(
+        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
+          return e.cast<AffineConstantExpr>().getValue();
+        }));
+    auto expected =
+        VectorType::get(expectedShape, resVectorType.getElementType());
+    if (resVectorType != expected || accVectorType != expected)
+      return op.emitOpError(
+                 "invalid accumulator/result vector shape, expected: ")
+             << expected;
   }
-  return true;
+  return success();
 }
 
 static LogicalResult verify(ContractionOp op) {
@@ -329,9 +360,9 @@ static LogicalResult verify(ContractionOp op) {
     return op.emitOpError("invalid batch dimension map");
 
   // Verify 'accType' and 'resType' shape.
-  if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
-                         batchDimMap))
-    return op.emitOpError("invalid accumulator/result vector shape");
+  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
+                               contractingDimMap, batchDimMap)))
+    return failure();
 
   // Verify that either two vector masks are set or none are set.
   auto lhsMaskType = op.getLHSVectorMaskType();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 44ff03a04f22..491ad62affcb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1454,10 +1454,17 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
   if (llvm::size(op.masks()) != 0)
     return failure();
 
+  auto iteratorTypes = op.iterator_types().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+
   if (vectorTransformsOptions.vectorContractLowering !=
           vector::VectorContractLowering::Matmul ||
       !isRowMajorMatmul(op.indexing_maps()))
     return failure();
+
   return success();
 }
 
@@ -1503,34 +1510,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
 /// ```
 ///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-void ContractionOpToOuterProductOpLowering::rewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  VectorType lhsType = op.getLhsType();
-  // TODO(ntv) other modes.
-  // We know we are in row-major.
-  bool transposeLhs = false;
-  unsigned reductionSize =
-      transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
-
-  // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
-  // transpose it to extract the proper vector<m x f32>. Otherwise, just take
-  // the lhs.
-  Value lhs = transposeLhs
-                  ? op.lhs()
-                  : rewriter.create<vector::TransposeOp>(
-                        op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
-  Value res = op.acc();
-  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
-  for (unsigned k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
-    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
-  }
-  rewriter.replaceOp(op, res);
-}
-
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult
 ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
   // TODO(ajcbik): implement masks
@@ -1538,12 +1519,104 @@ ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
     return failure();
 
   if (vectorTransformsOptions.vectorContractLowering !=
-          vector::VectorContractLowering::OuterProduct ||
-      !isRowMajorMatmul(op.indexing_maps()))
+      vector::VectorContractLowering::OuterProduct)
+    return failure();
+
+  // Transpose arguments to make them ready for lowering to OuterProduct. The
+  // constraint to match is that we must load full rows at a time with
+  // vector::ExtractOp.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(op.getContext(), m, n, k);
+  auto iteratorTypes = op.iterator_types().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  // When lowering to outerproduct we can support all permutations.
+  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
+      maps != infer({{m, k}, {n, k}, {m, n}}) &&
+      maps != infer({{k, m}, {k, n}, {m, n}}) &&
+      maps != infer({{k, m}, {n, k}, {m, n}}) &&
+      maps != infer({{m, k}, {k, n}, {n, m}}) &&
+      maps != infer({{m, k}, {n, k}, {n, m}}) &&
+      maps != infer({{k, m}, {k, n}, {n, m}}) &&
+      maps != infer({{k, m}, {n, k}, {n, m}}))
     return failure();
   return success();
 }
 
+void ContractionOpToOuterProductOpLowering::rewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  unsigned reductionSize = 0;
+  VectorType lhsType = op.getLhsType();
+  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+  // Transpose arguments to make them ready for lowering to OuterProduct. The
+  // constraint to match is that we must load full rows at a time with
+  // vector::ExtractOp.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(rewriter.getContext(), m, n, k);
+  SmallVector<int64_t, 2> perm{1, 0};
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  // First batch of cases, no need to output permute.
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    // This is the classical row-major matmul. Just permute the lhs.
+    reductionSize = lhsType.getShape()[1];
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    reductionSize = lhsType.getShape()[1];
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    // No need to permute anything.
+    reductionSize = lhsType.getShape()[0];
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    // Just permute the rhs.
+    reductionSize = lhsType.getShape()[0];
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  }
+  // Second batch of cases, reshuffle to avoid output permute.
+  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    // This is the classical row-major matmul. Just permute the lhs.
+    reductionSize = lhsType.getShape()[1];
+    Value tmp = rhs;
+    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    lhs = tmp;
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    reductionSize = lhsType.getShape()[1];
+    Value tmp = rhs;
+    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    // No need to permute anything, but still swap lhs and rhs.
+    reductionSize = lhsType.getShape()[0];
+    std::swap(lhs, rhs);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    // Just permute the rhs.
+    reductionSize = lhsType.getShape()[0];
+    Value tmp = lhs;
+    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    rhs = tmp;
+  }
+  assert(reductionSize > 0);
+
+  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+  for (unsigned k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
+    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
+    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
+  }
+  rewriter.replaceOp(op, res);
+}
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c18cf38edfc9..cc72511a6e78 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -767,6 +767,26 @@ func @contraction(%arg0: vector<4x3xi32>,
 
 // -----
 
+#contraction_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#contraction_trait = {
+  indexing_maps = #contraction_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<3x2xf32>
+{
+// expected-error at +1 {{invalid accumulator/result vector shape, expected: 'vector<3x2xf32>'}}
+  %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// -----
+
 func @create_mask() {
   %c2 = constant 2 : index
   %c3 = constant 3 : index

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c194cbe23811..57c03c903fe8 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -160,9 +160,11 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
   indexing_maps = #contraction_accesses0,
   iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
 }
-#contraction_accesses1 = [
+#contraction_accesses1 = [              // 7,  8, 16, 15
   affine_map<(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2)>,
+                                        // 8, 16,  7,  5
   affine_map<(f0, f1, f2, f3, c0, c1) -> (f1, c1, c0, f3)>,
+                                        // 8,  8, 15,  5
   affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)>
 ]
 #contraction_trait1 = {
@@ -172,7 +174,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
 }
 // CHECK-LABEL: contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
-                  %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
+                  %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
                   %arg4 : index) {
   // Test contraction with batch and contracting dims.
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
@@ -181,16 +183,16 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
   // Test contraction with only contracting dims. In this case the lhs/rhs
   // dimension of size 8 will be considered a parallel dim for lhs/rhs and will
   // appear twice in the output.
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
-      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   // Test contraction with optional vector mask arguments.
   %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
   %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
                                            %rhs_mask
-      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   return
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 7eea3baa8d87..1dd2f377a29c 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -681,3 +681,219 @@ func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
   %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
   return %0 : vector<2x3xi1>
 }
+
+#matmat_accesses_0 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_0 = {
+  indexing_maps = #matmat_accesses_0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_0
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_1
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_2
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_3
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_4
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_5 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_5 = {
+  indexing_maps = #matmat_accesses_5,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_5
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_6 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_6 = {
+  indexing_maps = #matmat_accesses_6,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_6
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_7 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_7 = {
+  indexing_maps = #matmat_accesses_7,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_7
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}


        

