[Mlir-commits] [mlir] 0ba9ee9 - [mlir] [VectorOps] Framework for progressive lowering of vector.contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 11:36:36 PST 2020


Author: aartbik
Date: 2020-02-19T11:36:11-08:00
New Revision: 0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2

URL: https://github.com/llvm/llvm-project/commit/0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2
DIFF: https://github.com/llvm/llvm-project/commit/0ba9ee9f0e55b1f9155a5bb5739550860fa5fec2.diff

LOG: [mlir] [VectorOps] Framework for progressive lowering of vector.contract

Summary:
Lowers all free/batch dimensions in a vector.contract progressively
into simpler vector.contract operations until a direct vector.reduction
operation is reached. Then lowers 1-D reductions into vector.reduce.

Still TBD:
multi-dimensional contractions that remain after removing all the parallel dims

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74797

Added: 
    

Modified: 
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index fe62666a2838..61a3d556a70b 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/Dialect/VectorOps/VectorTransforms.h"
 #include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -864,6 +865,19 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
 };
 
 /// Progressive lowering of ContractionOp.
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a fma/reduction op.
+///
+/// TODO(ajcbik): break down into transpose/reshape/cast ops
+///               when they become available to avoid code dup
+/// TODO(ajcbik): investigate lowering order impact on performance
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -874,16 +888,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     if (llvm::size(op.masks()) != 0)
       return matchFailure();
 
-    auto loc = op.getLoc();
-    VectorType lhsType = op.getLhsType();
-    VectorType rhsType = op.getRhsType();
-    Type resType = op.getResultType();
-
-    // Find first batch dimension in lhs/rhs, and lower when found.
+    // Find first batch dimension in LHS/RHS, and lower when found.
     std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
     if (!batchDimMap.empty()) {
-      // TODO(ajcbik): implement batch
-      return matchFailure();
+      int64_t lhsIndex = batchDimMap[0].first;
+      int64_t rhsIndex = batchDimMap[0].second;
+      rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+      return matchSuccess();
     }
 
     // Collect contracting dimensions.
@@ -896,24 +907,35 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
       rhsContractingDimSet.insert(dimPair.second);
     }
 
-    // Find free dimension in lhs/rhs, and lower first when found.
-    for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
-      if (lhsContractingDimSet.count(i) == 0) {
-        // TODO(ajcbik): implement free
-        return matchFailure();
+    // Find first free dimension in LHS, and lower when found.
+    VectorType lhsType = op.getLhsType();
+    for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
+         ++lhsIndex) {
+      if (lhsContractingDimSet.count(lhsIndex) == 0) {
+        rewriter.replaceOp(
+            op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+        return matchSuccess();
       }
     }
-    for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
-      if (rhsContractingDimSet.count(i) == 0) {
-        // TODO(ajcbik): implement free
-        return matchFailure();
+
+    // Find first free dimension in RHS, and lower when found.
+    VectorType rhsType = op.getRhsType();
+    for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
+         ++rhsIndex) {
+      if (rhsContractingDimSet.count(rhsIndex) == 0) {
+        rewriter.replaceOp(
+            op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+        return matchSuccess();
       }
     }
 
-    // Only contraction dimensions remain.
+    // Lower the only remaining contraction dimensions.
+    // TODO(ajcbik): handle multi-dim reductions
+    auto loc = op.getLoc();
+    Type resType = op.getResultType();
     if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
         rhsType.getRank() == 1) {
-      // Handle reduction into scalar.
+
       Value zero = rewriter.create<ConstantOp>(loc, resType,
                                                rewriter.getZeroAttr(resType));
       Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
@@ -924,9 +946,191 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
                                                          op.acc());
       return matchSuccess();
     }
-    // TODO(ajcbik): implement more contraction
+
     return matchFailure();
   }
+
+private:
+  // Lower one parallel dimension.
+  // TODO(ajcbik): consider reusing existing contract unrolling
+  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                      int64_t rhsIndex, PatternRewriter &rewriter) const {
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    VectorType resType = op.getResultType().cast<VectorType>();
+    // Find the iterator type index and result index.
+    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+    int64_t iterIndex = -1;
+    int64_t dimSize = -1;
+    if (lhsIndex >= 0) {
+      iterIndex =
+          iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
+      assert((rhsIndex < 0 || iterIndex == iMap[1]
+                                               .getResult(rhsIndex)
+                                               .cast<AffineDimExpr>()
+                                               .getPosition()) &&
+             "parallel index should be free in LHS or batch in LHS/RHS");
+      dimSize = lhsType.getDimSize(lhsIndex);
+    } else {
+      assert(rhsIndex >= 0 && "missing parallel index");
+      iterIndex =
+          iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
+      dimSize = rhsType.getDimSize(rhsIndex);
+    }
+    assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
+    Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
+    assert(lookup.hasValue() && "parallel index not listed in reduction");
+    int64_t resIndex = lookup.getValue();
+    // Construct new iterator types.
+    ArrayAttr iteratorTypes = op.iterator_types();
+    SmallVector<Attribute, 4> lowIterTypes;
+    for (auto it : llvm::enumerate(iteratorTypes)) {
+      int64_t idx = it.index();
+      if (idx == iterIndex) {
+        assert(it.value().cast<StringAttr>().getValue() ==
+                   getParallelIteratorTypeName() &&
+               "parallel index not marked as such");
+        continue;
+      }
+      lowIterTypes.push_back(it.value());
+    }
+    // Construct new affine map array attribute.
+    SmallVector<AffineMap, 4> lowIndexingMaps;
+    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+    // Construct new iterator types array attribute.
+    auto lowIter = rewriter.getArrayAttr(lowIterTypes);
+    // Unroll into a series of lower dimensional vector.contract ops.
+    Location loc = op.getLoc();
+    Value result = zeroVector(loc, resType, rewriter);
+    for (int64_t d = 0; d < dimSize; ++d) {
+      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+      auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
+      Value lowContract = rewriter.create<vector::ContractionOp>(
+          loc, lhs, rhs, acc, lowAffine, lowIter);
+      result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
+                            rewriter);
+    }
+    return result;
+  }
+
+  // Helper method to construct a zero vector.
+  static Value zeroVector(Location loc, VectorType vType,
+                          PatternRewriter &rewriter) {
+    Type eltType = vType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, eltType,
+                                             rewriter.getZeroAttr(eltType));
+    return rewriter.create<SplatOp>(loc, vType, zero);
+  }
+
+  // Helper to find an index in an affine map.
+  static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+      if (idx == index)
+        return i;
+    }
+    return None;
+  }
+
+  // Helper to construct an affine map with one index removed.
+  static AffineMap adjustMap(AffineMap map, int64_t index,
+                             PatternRewriter &rewriter) {
+    SmallVector<AffineExpr, 4> results;
+    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+      if (idx == index)
+        continue;
+      // Re-insert remaining indices, but renamed when occurring
+      // after the removed index.
+      auto targetExpr =
+          getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext());
+      results.push_back(targetExpr);
+    }
+    // Since (...) -> () cannot be represented properly,
+    // we resort to an empty map when this situation happens.
+    return results.empty() ? AffineMap::get(rewriter.getContext())
+                           : AffineMap::get(map.getNumDims() - 1, 0, results);
+  }
+
+  // Helper to drop dimension from vector type.
+  static Type adjustType(VectorType tp, int64_t index) {
+    int64_t rank = tp.getRank();
+    Type eltType = tp.getElementType();
+    if (rank == 1) {
+      assert(index == 0 && "index for scalar result out of bounds");
+      return eltType;
+    }
+    SmallVector<int64_t, 4> adjustedShape;
+    for (int64_t i = 0; i < rank; ++i) {
+      // Omit dimension at the given index.
+      if (i == index)
+        continue;
+      // Otherwise, add dimension back.
+      adjustedShape.push_back(tp.getDimSize(i));
+    }
+    return VectorType::get(adjustedShape, eltType);
+  }
+
+  // Helper method to possibly drop a dimension in a load.
+  // TODO(ajcbik): use a reshaping vector load (and share lowering code)
+  static Value reshapeLoad(Location loc, Value val, VectorType type,
+                           int64_t index, int64_t pos,
+                           PatternRewriter &rewriter) {
+    if (index == -1)
+      return val;
+    Type lowType = adjustType(type, 0);
+    // At extraction dimension?
+    if (index == 0) {
+      auto posAttr = rewriter.getI64ArrayAttr(pos);
+      return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
+    }
+    // Unroll leading dimensions.
+    VectorType vType = lowType.cast<VectorType>();
+    VectorType resType = adjustType(type, index).cast<VectorType>();
+    Value result = zeroVector(loc, resType, rewriter);
+    for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
+      auto posAttr = rewriter.getI64ArrayAttr(d);
+      Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+      Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
+      result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
+                                                 posAttr);
+    }
+    return result;
+  }
+
+  // Helper method to possibly drop a dimension in a store.
+  // TODO(ajcbik): use a reshaping vector store (and share lowering code)
+  static Value reshapeStore(Location loc, Value val, Value result,
+                            VectorType type, int64_t index, int64_t pos,
+                            PatternRewriter &rewriter) {
+    // Unmodified?
+    if (index == -1)
+      return val;
+    // At insertion dimension?
+    if (index == 0) {
+      auto posAttr = rewriter.getI64ArrayAttr(pos);
+      return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
+    }
+    // Unroll leading dimensions.
+    Type lowType = adjustType(type, 0);
+    VectorType vType = lowType.cast<VectorType>();
+    Type insType = adjustType(vType, 0);
+    for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
+      auto posAttr = rewriter.getI64ArrayAttr(d);
+      Value ext =
+          rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
+      Value ins =
+          rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+      Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+      result =
+          rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+    }
+    return result;
+  }
 };
 
 } // namespace

diff  --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 6c4cb5f4bfd0..f781e37d586b 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -14,7 +14,7 @@
 // CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
-// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
 // CHECK:      %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
 // CHECK:      %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
 // CHECK:      return %[[R]] : f32
@@ -24,3 +24,148 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
     : vector<4xf32>, vector<4xf32> into f32
   return %0 : f32
 }
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract2
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
+// CHECK:      %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
+// CHECK:      %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:      return %[[T9]] : vector<2xf32>
+
+func @extract_contract2(%arg0: vector<2x3xf32>,
+                        %arg1: vector<3xf32>,
+			%arg2: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+#vecmat_accesses = [
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i)>
+]
+#vecmat_trait = {
+  indexing_maps = #vecmat_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract3
+// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK:      %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
+// CHECK:      %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
+// CHECK:      %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:      return %[[T9]] : vector<2xf32>
+
+func @extract_contract3(%arg0: vector<3xf32>,
+                        %arg1: vector<2x3xf32>,
+                        %arg2: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
+    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract4
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
+// CHECK:    %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK:    %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
+// CHECK:    %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
+// CHECK:    %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK:    %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
+// CHECK:    %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK:    %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
+// CHECK:    %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK:    %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32
+// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK:    %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
+// CHECK:    %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK:    %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
+// CHECK:    %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK:    %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32
+// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
+// CHECK:    %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
+// CHECK:    %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK:    %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
+// CHECK:    %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK:    %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
+// CHECK:    %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
+// CHECK:    %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32
+// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
+// CHECK:    %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
+// CHECK:    %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
+// CHECK:    %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
+// CHECK:    %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
+// CHECK:    %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
+// CHECK:    %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32
+// CHECK:    %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK:    return %[[T45]] : vector<2x2xf32>
+
+func @extract_contract4(%arg0: vector<2x2xf32>,
+                        %arg1: vector<2x2xf32>,
+                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+


        


More information about the Mlir-commits mailing list