[Mlir-commits] [mlir] 22c8a08 - [mlir][Vector] Fold chains of ExtractOp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 10 06:34:48 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-10T09:32:02-04:00
New Revision: 22c8a08fd8a1487159564f74f24561964f6a6c97

URL: https://github.com/llvm/llvm-project/commit/22c8a08fd8a1487159564f74f24561964f6a6c97
DIFF: https://github.com/llvm/llvm-project/commit/22c8a08fd8a1487159564f74f24561964f6a6c97.diff

LOG: [mlir][Vector] Fold chains of ExtractOp

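This revision adds a folder for chains of ExtractOps: when the source vector of an ExtractOp is itself produced by an ExtractOp, the chain is collapsed into a single ExtractOp that reads directly from the chain's source, with a position attribute formed by concatenating the positions along the chain.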

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 019b73a0e00b..0aae97f24d20 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -571,16 +571,43 @@ static LogicalResult verify(vector::ExtractOp op) {
   return success();
 }
 
-static SmallVector<unsigned, 4> extractUnsignedVector(ArrayAttr arrayAttr) {
+template <typename IntType> // Each entry is static_cast from getInt().
+static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(llvm::map_range(
       arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<unsigned>(attr.getInt()); }));
+      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
 }
 
-static Value foldExtractOp(ExtractOp extractOp) {
+/// Fold a chain of ExtractOps into `extractOp` in place by concatenating the
+/// positions; fails if the source of `extractOp` is not an ExtractOp.
+static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
+  if (!extractOp.vector().getDefiningOp<ExtractOp>())
+    return failure();
+  // Collect each op's positions reversed, walking from `extractOp` toward the
+  // source; a final reverse puts the positions nearest the source first.
+  SmallVector<int64_t, 4> globalPosition;
+  Value source;
+  for (ExtractOp op = extractOp; op;
+       op = op.vector().getDefiningOp<ExtractOp>()) {
+    auto opPosition = extractVector<int64_t>(op.position());
+    globalPosition.append(opPosition.rbegin(), opPosition.rend());
+    source = op.vector();
+  }
+  extractOp.setOperand(source);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  std::reverse(globalPosition.begin(), globalPosition.end());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(globalPosition));
+  return success();
+}
+
+/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
+/// result is always the input to some InsertOp.
+static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
   MLIRContext *context = extractOp.getContext();
   AffineMap permutationMap;
-  auto extractedPos = extractUnsignedVector(extractOp.position());
+  auto extractedPos = extractVector<unsigned>(extractOp.position());
   // Walk back a chain of InsertOp/TransposeOp until we hit a match.
   // Compose TransposeOp permutations as we walk back.
   auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
@@ -588,7 +615,7 @@ static Value foldExtractOp(ExtractOp extractOp) {
   while (insertOp || transposeOp) {
     if (transposeOp) {
       // If it is transposed, compose the map and iterate.
-      auto permutation = extractUnsignedVector(transposeOp.transp());
+      auto permutation = extractVector<unsigned>(transposeOp.transp());
       AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
       if (!permutationMap)
         permutationMap = newMap;
@@ -610,7 +637,7 @@ static Value foldExtractOp(ExtractOp extractOp) {
     // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
     // produces a new vector with 1 modified value/slice in exactly the static
     // position we need to match.
-    auto insertedPos = extractUnsignedVector(insertOp.position());
+    auto insertedPos = extractVector<unsigned>(insertOp.position());
     // Trivial permutations are solved with position equality checks.
     if (!permutationMap || permutationMap.isIdentity()) {
       if (extractedPos == insertedPos)
@@ -660,7 +687,9 @@ static Value foldExtractOp(ExtractOp extractOp) {
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
-  if (auto val = foldExtractOp(*this))
+  if (succeeded(foldExtractOpFromExtractChain(*this)))
+    return getResult(); // Folded in place (operand and position updated).
+  if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
   return OpFoldResult();
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2b2adf0cca64..09162aa0236b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -172,29 +172,25 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
 // CHECK-SAME:  %[[A:.*]]: !llvm<"[4 x [1 x <2 x float>]]">)
 // CHECK:       %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]">
 // CHECK:       %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK:       %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T2:.*]] = llvm.extractvalue %[[A]][0, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[T2]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T2]], %[[T4]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T6:.*]] = llvm.insertvalue %[[T2]], %[[T5]][2] : !llvm<"[3 x <2 x float>]">
 // CHECK:       %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK:       %[[T8:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK:       %[[T9:.*]] = llvm.extractvalue %[[T8]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK:       %[[T10:.*]] = llvm.insertvalue %[[T9]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T12:.*]] = llvm.insertvalue %[[T9]], %[[T11]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T8:.*]] = llvm.extractvalue %[[A]][1, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK:       %[[T10:.*]] = llvm.insertvalue %[[T8]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T11:.*]] = llvm.insertvalue %[[T8]], %[[T10]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T12:.*]] = llvm.insertvalue %[[T8]], %[[T11]][2] : !llvm<"[3 x <2 x float>]">
 // CHECK:       %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T7]][1] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK:       %[[T14:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK:       %[[T15:.*]] = llvm.extractvalue %[[T14]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK:       %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T17:.*]] = llvm.insertvalue %[[T15]], %[[T16]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T18:.*]] = llvm.insertvalue %[[T15]], %[[T17]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T14:.*]] = llvm.extractvalue %[[A]][2, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK:       %[[T16:.*]] = llvm.insertvalue %[[T14]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T17:.*]] = llvm.insertvalue %[[T14]], %[[T16]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T18:.*]] = llvm.insertvalue %[[T14]], %[[T17]][2] : !llvm<"[3 x <2 x float>]">
 // CHECK:       %[[T19:.*]] = llvm.insertvalue %[[T18]], %[[T13]][2] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK:       %[[T20:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK:       %[[T21:.*]] = llvm.extractvalue %[[T20]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK:       %[[T22:.*]] = llvm.insertvalue %[[T21]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T23:.*]] = llvm.insertvalue %[[T21]], %[[T22]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK:       %[[T24:.*]] = llvm.insertvalue %[[T21]], %[[T23]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T20:.*]] = llvm.extractvalue %[[A]][3, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK:       %[[T22:.*]] = llvm.insertvalue %[[T20]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T23:.*]] = llvm.insertvalue %[[T20]], %[[T22]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK:       %[[T24:.*]] = llvm.insertvalue %[[T20]], %[[T23]][2] : !llvm<"[3 x <2 x float>]">
 // CHECK:       %[[T25:.*]] = llvm.insertvalue %[[T24]], %[[T19]][3] : !llvm<"[4 x [3 x <2 x float>]]">
 // CHECK:       llvm.return %[[T25]] : !llvm<"[4 x [3 x <2 x float>]]">
 
@@ -630,7 +626,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
 // CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">)
 //      CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]">
-//      CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]">
+//      CHECK: %[[s2:.*]] = llvm.extractvalue %[[B]][0, 0] : !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s3:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //      CHECK: %[[s4:.*]] = llvm.extractelement %[[s1]][%[[s3]] : !llvm.i64] : !llvm<"<4 x float>">
 //      CHECK: %[[s5:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
@@ -649,7 +645,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
 //      CHECK: %[[s18:.*]] = llvm.insertelement %[[s16]], %[[s14]][%[[s17]] : !llvm.i64] : !llvm<"<8 x float>">
 //      CHECK: %[[s19:.*]] = llvm.insertvalue %[[s18]], %[[s0]][0] : !llvm<"[4 x <8 x float>]">
 //      CHECK: %[[s20:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <4 x float>]">
-//      CHECK: %[[s21:.*]] = llvm.extractvalue %[[s0]][1] : !llvm<"[4 x <8 x float>]">
+//      CHECK: %[[s21:.*]] = llvm.extractvalue %[[B]][0, 1] : !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s22:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //      CHECK: %[[s23:.*]] = llvm.extractelement %[[s20]][%[[s22]] : !llvm.i64] : !llvm<"<4 x float>">
 //      CHECK: %[[s24:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7ba79528aee6..94f3f627e777 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -295,3 +295,18 @@ func @insert_extract_transpose_3d_2d(
   // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
   return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }
+
+// -----
+// A chain of vector.extract ops folds into a single extract.
+// CHECK-LABEL: fold_extracts
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
+//  CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
+//  CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
+//  CHECK-NEXT: return
+func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
+  %b = vector.extract %a[0] : vector<3x4x5x6xf32>
+  %c = vector.extract %b[1, 2] : vector<4x5x6xf32>
+  %d = vector.extract %c[3] : vector<6xf32>
+  %e = vector.extract %a[0] : vector<3x4x5x6xf32>
+  return %d, %e : f32, vector<4x5x6xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 6a933d5e24b5..f6f215a50616 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -113,52 +113,46 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK:    %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
 // CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
-// CHECK:    %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
-// CHECK:    %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK:    %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
-// CHECK:    %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK:    %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
-// CHECK:    %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK:    %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
+// CHECK:    %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
+// CHECK:    %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32>
 // CHECK:    %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
 // CHECK:    %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
 // CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK:    %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
-// CHECK:    %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK:    %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
-// CHECK:    %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+//
+// CHECK:    %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
+// CHECK:    %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
+// CHECK:    %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32>
 // CHECK:    %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
 // CHECK:    %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
 // CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
 // CHECK:    %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+//
 // CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
-// CHECK:    %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
-// CHECK:    %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK:    %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
-// CHECK:    %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK:    %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
-// CHECK:    %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
-// CHECK:    %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32>
-// CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
+// CHECK:    %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
+// CHECK:    %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
+// CHECK:    %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32>
+// CHECK:    %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32>
+// CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32
 // CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK:    %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
-// CHECK:    %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK:    %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK:    %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
-// CHECK:    %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
-// CHECK:    %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
-// CHECK:    %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
-// CHECK:    %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
-// CHECK:    %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
-// CHECK:    return %[[T45]] : vector<2x2xf32>
+//
+// CHECK:    %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
+// CHECK:    %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK:    %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
+// CHECK:    %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32>
+// CHECK:    %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32>
+// CHECK:    %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32
+//
+// CHECK:    %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32>
+// CHECK:    %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK:    return %[[T52]] : vector<2x2xf32>
 
 func @extract_contract4(%arg0: vector<2x2xf32>,
                         %arg1: vector<2x2xf32>,
@@ -216,27 +210,22 @@ func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
 // CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
-// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
-// CHECK:      %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32>
-// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
-// CHECK:      %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32>
-// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32>
-// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
-// CHECK:      %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
-// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32>
+// CHECK:      %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
+//
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
-// CHECK:      %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
-// CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32>
-// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
-// CHECK:      %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32>
-// CHECK:      %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32>
-// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
-// CHECK:      %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
-// CHECK:      %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
+// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32>
+// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32>
+// CHECK:      %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
+// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
+// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
 // CHECK:      %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
 // CHECK:      return %[[T23]] : f32
@@ -657,21 +646,17 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
 // CHECK-LABEL: func @broadcast_stretch_at_end
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK:      %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3xf32>
-// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1xf32>
-// CHECK:      %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1xf32>
-// CHECK:      %[[T2:.*]] = splat %[[T1]] : vector<3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
+// CHECK:      %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
-// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<4x1xf32>
-// CHECK:      %[[T5:.*]] = vector.extract %[[T4]][0] : vector<1xf32>
-// CHECK:      %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
+// CHECK:      %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
-// CHECK:      %[[T8:.*]] = vector.extract %[[A]][2] : vector<4x1xf32>
-// CHECK:      %[[T9:.*]] = vector.extract %[[T8]][0] : vector<1xf32>
-// CHECK:      %[[T10:.*]] = splat %[[T9]] : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
+// CHECK:      %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
-// CHECK:      %[[T12:.*]] = vector.extract %[[A]][3] : vector<4x1xf32>
-// CHECK:      %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1xf32>
-// CHECK:      %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
+// CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
+// CHECK:      %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
 // CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      return %[[T15]] : vector<4x3xf32>
 
@@ -684,29 +669,25 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32>
 // CHECK:      %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32>
 // CHECK:      %[[C1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1x2xf32>
-// CHECK:      %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1x2xf32>
-// CHECK:      %[[T2:.*]] = vector.insert %[[T1]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T4:.*]] = vector.insert %[[T1]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32>
+// CHECK:      %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
 // CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : vector<4x1x2xf32>
-// CHECK:      %[[T7:.*]] = vector.extract %[[T6]][0] : vector<1x2xf32>
-// CHECK:      %[[T8:.*]] = vector.insert %[[T7]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T10:.*]] = vector.insert %[[T7]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32>
+// CHECK:      %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK:      %[[T12:.*]] = vector.extract %[[A]][2] : vector<4x1x2xf32>
-// CHECK:      %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1x2xf32>
-// CHECK:      %[[T14:.*]] = vector.insert %[[T13]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T16:.*]] = vector.insert %[[T13]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32>
+// CHECK:      %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
 // CHECK:      %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK:      %[[T18:.*]] = vector.extract %[[A]][3] : vector<4x1x2xf32>
-// CHECK:      %[[T19:.*]] = vector.extract %[[T18]][0] : vector<1x2xf32>
-// CHECK:      %[[T20:.*]] = vector.insert %[[T19]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK:      %[[T22:.*]] = vector.insert %[[T19]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32>
+// CHECK:      %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK:      %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
 // CHECK:      %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
 // CHECK:      return %[[T23]] : vector<4x3x2xf32>
 


        


More information about the Mlir-commits mailing list