[Mlir-commits] [mlir] ad9b5a4 - [mlir][vector] Add pattern to drop lead unit dim for Contraction Op

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 10 09:52:17 PST 2022

Author: Nirvedh
Date: 2022-02-10T09:51:07-08:00
New Revision: ad9b5a4b8e47489e4ae952b21484b9a98c3e6e0d

URL: https://github.com/llvm/llvm-project/commit/ad9b5a4b8e47489e4ae952b21484b9a98c3e6e0d
DIFF: https://github.com/llvm/llvm-project/commit/ad9b5a4b8e47489e4ae952b21484b9a98c3e6e0d.diff

LOG: [mlir][vector] Add pattern to drop lead unit dim for Contraction Op

If the result operand has a unit leading dimension, it is removed from all operands.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D119206




diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e6568b8ed2706..009df114ec2c2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -200,7 +200,10 @@ def Vector_ContractionOp :
       "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
-      "ArrayRef<StringRef>":$iteratorTypes)>
+      "ArrayRef<StringRef>":$iteratorTypes)>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
+      "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
+      "CombiningKind":$kind)>
   let extraClassDeclaration = [{
     VectorType getLhsType() {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f7a89389afde3..2d504cb0029c4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -502,13 +502,20 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                   Value lhs, Value rhs, Value acc,
                                   ArrayAttr indexingMaps,
                                   ArrayAttr iteratorTypes) {
+  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
+        ContractionOp::getDefaultKind());
+void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
+                                  Value lhs, Value rhs, Value acc,
+                                  ArrayAttr indexingMaps,
+                                  ArrayAttr iteratorTypes, CombiningKind kind) {
   result.addOperands({lhs, rhs, acc});
   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
-                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
-                                             builder.getContext()));
+                      CombiningKindAttr::get(kind, builder.getContext()));
 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 22fda05e0e7db..8db685e802336 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,7 @@
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
@@ -220,6 +221,128 @@ struct CastAwayTransferWriteLeadingOneDim
+/// Turns vector.contract on vectors with leading 1 dimensions into
+/// vector.extract followed by vector.contract on vectors without leading
+/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
+/// prior to extract.
+struct CastAwayContractionLeadingOneDim
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
+    if (oldAccType == nullptr)
+      return failure();
+    if (oldAccType.getRank() < 2)
+      return failure();
+    // TODO: implement masks.
+    if (llvm::size(contractOp.masks()) != 0)
+      return failure();
+    if (oldAccType.getShape()[0] != 1)
+      return failure();
+    // Currently only one dim is dropped per application, but the pattern can be
+    // applied greedily to drop more.
+    int64_t dropDim = 1;
+    auto oldIndexingMaps = contractOp.getIndexingMaps();
+    SmallVector<AffineMap> newIndexingMaps;
+    auto oldIteratorTypes = contractOp.iterator_types();
+    SmallVector<Attribute> newIteratorTypes;
+    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
+    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
+      // Only parallel iterators can be dropped.
+      return failure();
+    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
+      int64_t currDim = it.index();
+      if (currDim == dimToDrop)
+        continue;
+      newIteratorTypes.push_back(it.value());
+    }
+    SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(),
+                                   contractOp.acc()};
+    SmallVector<Value> newOperands;
+    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
+      // Check if the dim to be dropped exists as a leading dim in the operand;
+      // if it does, then we use vector.extract to drop it.
+      bool validExtract = false;
+      SmallVector<AffineExpr> results;
+      auto map = it.value();
+      int64_t orginalZeroDim = it.value().getDimPosition(0);
+      if (orginalZeroDim != dimToDrop) {
+        // There are two reasons to be in this path: 1. We need to
+        // transpose the operand to make the dim to be dropped
+        // leading. 2. The dim to be dropped does not exist, and in
+        // that case we don't want to add a unit transpose, but we must
+        // check all the indices to make sure this is the case.
+        bool tranposeNeeded = false;
+        SmallVector<int64_t> perm;
+        SmallVector<AffineExpr> transposeResults;
+        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+          int64_t currDim = map.getDimPosition(i);
+          if (currDim == dimToDrop) {
+            tranposeNeeded = true;
+            perm.insert(perm.begin(), i);
+            auto targetExpr = rewriter.getAffineDimExpr(currDim);
+            transposeResults.insert(transposeResults.begin(), targetExpr);
+          } else {
+            perm.push_back(i);
+            auto targetExpr = rewriter.getAffineDimExpr(currDim);
+            transposeResults.push_back(targetExpr);
+          }
+        }
+        // Do the transpose now, if needed, so that we can drop the
+        // correct dim using extract later.
+        if (tranposeNeeded) {
+          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
+                               contractOp.getContext());
+          operands[it.index()] = rewriter.create<vector::TransposeOp>(
+              contractOp.getLoc(), operands[it.index()], perm);
+        }
+      }
+      // We have taken care to have the dim to be dropped be
+      // the leading dim. If it's still not leading, that means it
+      // does not exist in this operand and hence we do not need
+      // an extract.
+      if (map.getDimPosition(0) == dimToDrop)
+        validExtract = true;
+      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+        int64_t currDim = map.getDimPosition(i);
+        if (currDim == dimToDrop)
+          // This is the dim we are dropping.
+          continue;
+        auto targetExpr = rewriter.getAffineDimExpr(
+            currDim < dimToDrop ? currDim : currDim - 1);
+        results.push_back(targetExpr);
+      }
+      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
+                                               contractOp.getContext()));
+      // Extract if it's a valid extraction; otherwise use the operand
+      // without extraction.
+      newOperands.push_back(validExtract
+                                ? rewriter.create<vector::ExtractOp>(
+                                      contractOp.getLoc(), operands[it.index()],
+                                      splatZero(dropDim))
+                                : operands[it.index()]);
+    }
+    auto newContractOp = rewriter.create<vector::ContractionOp>(
+        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        rewriter.getArrayAttr(newIteratorTypes), contractOp.kind());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        contractOp, contractOp->getResultTypes()[0], newContractOp);
+    return success();
+  }
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -260,10 +383,11 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
-               CastAwayInsertStridedSliceLeadingOneDim,
-               CastAwayTransferReadLeadingOneDim,
-               CastAwayTransferWriteLeadingOneDim,
-               CastAwayElementwiseLeadingOneDim>(patterns.getContext());
+  patterns
+      .add<CastAwayExtractStridedSliceLeadingOneDim,
+           CastAwayInsertStridedSliceLeadingOneDim,
+           CastAwayTransferReadLeadingOneDim,
+           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
+           CastAwayContractionLeadingOneDim>(patterns.getContext());

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
new file mode 100644
index 0000000000000..70beb0fe43c5f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -0,0 +1,267 @@
+// RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file| FileCheck %s
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: cast_away_contraction_leading_one_dims
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<1x16x8xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0] : vector<1x16x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
+//  CHECK-NEXT:  return %[[R4]] : vector<1x16x16xf32>
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  return %0: vector<1x16x16xf32>
+// -----
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0, 0] : vector<1x1x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
+//  CHECK-SAME:   %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
+//  CHECK-NEXT:   %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
+//  CHECK-NEXT:  return %[[R5]] : vector<1x1x16xf32>
+#contraction_accesses1 = [
+  affine_map<(l, i, j, k) -> (i, l, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+#contraction_trait1 = {
+  indexing_maps = #contraction_accesses1,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+  kind = #vector.kind<mul>
+func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<1x1x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x1x16xf32>) -> vector<1x1x16xf32> {
+  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2  : vector<1x1x8xf32>, vector<1x8x16xf32> into vector<1x1x16xf32>
+  return %0: vector<1x1x16xf32>
+// -----
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %[[R0]][0] : vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0] : vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
+//  CHECK-NEXT:   %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:  return %[[R6]] : vector<1x2x16xf32>
+#contraction_accesses2 = [
+  affine_map<(l, i, j, k) -> (k, l, j)>,
+  affine_map<(l, i, j, k) -> (i, k, l)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+#contraction_trait2 = {
+  indexing_maps = #contraction_accesses2,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector<8x1x16xf32>, %arg1: vector<2x8x1xf32>, %arg2: vector<1x2x16xf32>) -> vector<1x2x16xf32> {
+  %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2  : vector<8x1x16xf32>, vector<2x8x1xf32> into vector<1x2x16xf32>
+  return %0: vector<1x2x16xf32>
+// -----
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<1x8x1x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<1x2x8x1xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R5:.+]] =  vector.extract %[[R4]][0] : vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R6:.+]] =  vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32>
+//  CHECK-NEXT:   %[[R7:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
+//  CHECK-NEXT:   %[[R8:.+]] =  vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R9:.+]] =  vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+//  CHECK-NEXT:  return %[[R9]] : vector<1x1x2x16xf32>
+#contraction_accesses2 = [
+  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
+  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
+  affine_map<(m, l, i, j, k) -> (m, l, i, j)>
+#contraction_trait2 = {
+  indexing_maps = #contraction_accesses2,
+  iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
+func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
+  %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2  : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
+  return %0: vector<1x1x2x16xf32>
+// -----
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %[[R0]][0, 0] : vector<1x1x8x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R1]][0, 0] : vector<1x1x2x8xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32>
+//  CHECK-NEXT:   %[[R5:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
+//  CHECK-NEXT:   %[[R6:.+]] =  vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R7:.+]] =  vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+//  CHECK-NEXT:  return %[[R7]] : vector<1x1x2x16xf32>
+#contraction_accesses3 = [
+  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
+  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
+  affine_map<(m, l, i, j, k) -> (l, m, i, j)>
+#contraction_trait3 = {
+  indexing_maps = #contraction_accesses3,
+  iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
+func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
+  %0 = vector.contract #contraction_trait3 %arg0, %arg1, %arg2  : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
+  return %0: vector<1x1x2x16xf32>
+// -----
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
+func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x1x8xf16>
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
+func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
+  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x8x8xf16>
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
+func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
+  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
+  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
+  // CHECK: return %[[B]]
+  return %0: vector<1x1x1xf16>
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
+func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x4xf16>
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
+func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0. : f16
+  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
+  return %0: vector<1x1xf16>
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
+func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
+func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
+  return
+// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
+func @cast_away_elementwise_leading_one_dims(
+  %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
+  %arg3: vector<1x4xf32>, %arg4: i1) ->
+  (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+  // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+  %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+  %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  select %arg4, %12, %{{.*}} : vector<4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
+  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 97d7316a55721..907252b98f0ac 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -419,106 +419,6 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
   return %r : tensor<4x4xf32>
-// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
-func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
-  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
-  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
-  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
-  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
-  // CHECK: return %[[RET]]
-  return %0: vector<1x1x8xf16>
-// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
-func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
-  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
-  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
-  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
-  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
-  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
-  // CHECK: return %[[RET]]
-  return %0: vector<1x8x8xf16>
-// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
-//  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
-func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
-  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
-  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
-  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
-  // CHECK: return %[[B]]
-  return %0: vector<1x1x1xf16>
-// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
-func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
-  // CHECK: %[[C0:.+]] = arith.constant 0 : index
-  %c0 = arith.constant 0 : index
-  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
-  %f0 = arith.constant 0. : f16
-  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
-  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
-  // CHECK: return %[[CAST]]
-  return %0: vector<1x4xf16>
-// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
-func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
-  %c0 = arith.constant 0 : index
-  %f0 = arith.constant 0. : f16
-  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
-  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
-  return %0: vector<1x1xf16>
-// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
-func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
-  // CHECK: %[[C0:.+]] = arith.constant 0 : index
-  %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
-  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
-  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
-  return
-// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
-func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
-  %c0 = arith.constant 0 : index
-  // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
-  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
-  return
-// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
-func @cast_away_elementwise_leading_one_dims(
-  %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
-  %arg3: vector<1x4xf32>, %arg4: i1) ->
-  (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
-  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
-  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
-  // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
-  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
-  %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
-  %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
-  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
-  %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
-  // CHECK:  select %arg4, %12, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
-  %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
-  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 // CHECK-LABEL: func @bubble_down_bitcast_in_extract
 //  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
 func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {


More information about the Mlir-commits mailing list