[Mlir-commits] [mlir] 22c8a08 - [mlir][Vector] Fold chains of ExtractOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 10 06:34:48 PDT 2020
Author: Nicolas Vasilache
Date: 2020-07-10T09:32:02-04:00
New Revision: 22c8a08fd8a1487159564f74f24561964f6a6c97
URL: https://github.com/llvm/llvm-project/commit/22c8a08fd8a1487159564f74f24561964f6a6c97
DIFF: https://github.com/llvm/llvm-project/commit/22c8a08fd8a1487159564f74f24561964f6a6c97.diff
LOG: [mlir][Vector] Fold chains of ExtractOp
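This revision adds folding of chains of ExtractOps. An ExtractOp whose source vector is itself produced by an ExtractOp is rewritten in place. The rewritten op extracts directly from the chain's original vector, using the concatenation of all position attributes along the chain.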
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 019b73a0e00b..0aae97f24d20 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -571,16 +571,43 @@ static LogicalResult verify(vector::ExtractOp op) {
return success();
}
-static SmallVector<unsigned, 4> extractUnsignedVector(ArrayAttr arrayAttr) {
+template <typename IntType>
+static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(llvm::map_range(
arrayAttr.getAsRange<IntegerAttr>(),
- [](IntegerAttr attr) { return static_cast<unsigned>(attr.getInt()); }));
+ [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
-static Value foldExtractOp(ExtractOp extractOp) {
+/// Fold the result of chains of ExtractOp in place by simply concatenating the
+/// positions.
+static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
+ if (!extractOp.vector().getDefiningOp<ExtractOp>())
+ return failure();
+
+ SmallVector<int64_t, 4> globalPosition;
+ ExtractOp currentOp = extractOp;
+ auto extractedPos = extractVector<int64_t>(currentOp.position());
+ globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+ while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
+ currentOp = nextOp;
+ auto extractedPos = extractVector<int64_t>(currentOp.position());
+ globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+ }
+ extractOp.setOperand(currentOp.vector());
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(extractOp.getContext());
+ std::reverse(globalPosition.begin(), globalPosition.end());
+ extractOp.setAttr(ExtractOp::getPositionAttrName(),
+ b.getI64ArrayAttr(globalPosition));
+ return success();
+}
+
+/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
+/// result is always the input to some InsertOp.
+static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
MLIRContext *context = extractOp.getContext();
AffineMap permutationMap;
- auto extractedPos = extractUnsignedVector(extractOp.position());
+ auto extractedPos = extractVector<unsigned>(extractOp.position());
// Walk back a chain of InsertOp/TransposeOp until we hit a match.
// Compose TransposeOp permutations as we walk back.
auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
@@ -588,7 +615,7 @@ static Value foldExtractOp(ExtractOp extractOp) {
while (insertOp || transposeOp) {
if (transposeOp) {
// If it is transposed, compose the map and iterate.
- auto permutation = extractUnsignedVector(transposeOp.transp());
+ auto permutation = extractVector<unsigned>(transposeOp.transp());
AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
if (!permutationMap)
permutationMap = newMap;
@@ -610,7 +637,7 @@ static Value foldExtractOp(ExtractOp extractOp) {
// InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
// produces a new vector with 1 modified value/slice in exactly the static
// position we need to match.
- auto insertedPos = extractUnsignedVector(insertOp.position());
+ auto insertedPos = extractVector<unsigned>(insertOp.position());
// Trivial permutations are solved with position equality checks.
if (!permutationMap || permutationMap.isIdentity()) {
if (extractedPos == insertedPos)
@@ -660,7 +687,9 @@ static Value foldExtractOp(ExtractOp extractOp) {
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
- if (auto val = foldExtractOp(*this))
+ if (succeeded(foldExtractOpFromExtractChain(*this)))
+ return getResult();
+ if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
return val;
return OpFoldResult();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2b2adf0cca64..09162aa0236b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -172,29 +172,25 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
// CHECK-SAME: %[[A:.*]]: !llvm<"[4 x [1 x <2 x float>]]">)
// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]">
// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T2]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T2]], %[[T4]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T2]], %[[T5]][2] : !llvm<"[3 x <2 x float>]">
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: %[[T9:.*]] = llvm.extractvalue %[[T8]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T9]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T9]], %[[T11]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T8]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T8]], %[[T10]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T8]], %[[T11]][2] : !llvm<"[3 x <2 x float>]">
// CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T7]][1] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: %[[T15:.*]] = llvm.extractvalue %[[T14]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T15]], %[[T16]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T15]], %[[T17]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T14]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T14]], %[[T16]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T14]], %[[T17]][2] : !llvm<"[3 x <2 x float>]">
// CHECK: %[[T19:.*]] = llvm.insertvalue %[[T18]], %[[T13]][2] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: %[[T21:.*]] = llvm.extractvalue %[[T20]][0] : !llvm<"[1 x <2 x float>]">
-// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T21]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T21]], %[[T22]][1] : !llvm<"[3 x <2 x float>]">
-// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T21]], %[[T23]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3, 0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T20]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T20]], %[[T22]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T20]], %[[T23]][2] : !llvm<"[3 x <2 x float>]">
// CHECK: %[[T25:.*]] = llvm.insertvalue %[[T24]], %[[T19]][3] : !llvm<"[4 x [3 x <2 x float>]]">
// CHECK: llvm.return %[[T25]] : !llvm<"[4 x [3 x <2 x float>]]">
@@ -630,7 +626,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">)
// CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]">
-// CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]">
+// CHECK: %[[s2:.*]] = llvm.extractvalue %[[B]][0, 0] : !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: %[[s3:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[s4:.*]] = llvm.extractelement %[[s1]][%[[s3]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s5:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
@@ -649,7 +645,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// CHECK: %[[s18:.*]] = llvm.insertelement %[[s16]], %[[s14]][%[[s17]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s19:.*]] = llvm.insertvalue %[[s18]], %[[s0]][0] : !llvm<"[4 x <8 x float>]">
// CHECK: %[[s20:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <4 x float>]">
-// CHECK: %[[s21:.*]] = llvm.extractvalue %[[s0]][1] : !llvm<"[4 x <8 x float>]">
+// CHECK: %[[s21:.*]] = llvm.extractvalue %[[B]][0, 1] : !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: %[[s22:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[s23:.*]] = llvm.extractelement %[[s20]][%[[s22]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s24:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7ba79528aee6..94f3f627e777 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -295,3 +295,18 @@ func @insert_extract_transpose_3d_2d(
// CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: fold_extracts
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
+// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
+// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
+// CHECK-NEXT: return
+func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
+ %b = vector.extract %a[0] : vector<3x4x5x6xf32>
+ %c = vector.extract %b[1, 2] : vector<4x5x6xf32>
+ %d = vector.extract %c[3] : vector<6xf32>
+ %e = vector.extract %a[0] : vector<3x4x5x6xf32>
+ return %d, %e : f32, vector<4x5x6xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 6a933d5e24b5..f6f215a50616 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -113,52 +113,46 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32>
// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
-// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
-// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+//
+// CHECK: %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32>
// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
+//
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
-// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32>
-// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32>
-// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
-// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
-// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32>
-// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
+// CHECK: %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32>
+// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
-// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
-// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32>
-// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32>
-// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
-// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
-// CHECK: %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
-// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
-// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
-// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
-// CHECK: return %[[T45]] : vector<2x2xf32>
+//
+// CHECK: %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32>
+// CHECK: %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32
+//
+// CHECK: %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK: return %[[T52]] : vector<2x2xf32>
func @extract_contract4(%arg0: vector<2x2xf32>,
%arg1: vector<2x2xf32>,
@@ -216,27 +210,22 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
+//
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
-// CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
-// CHECK: %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32>
-// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
-// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
-// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32>
+// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
+// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK: return %[[T23]] : f32
@@ -657,21 +646,17 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
// CHECK-LABEL: func @broadcast_stretch_at_end
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1xf32>
-// CHECK: %[[T2:.*]] = splat %[[T1]] : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
+// CHECK: %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<4x1xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<1xf32>
-// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
+// CHECK: %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2] : vector<4x1xf32>
-// CHECK: %[[T9:.*]] = vector.extract %[[T8]][0] : vector<1xf32>
-// CHECK: %[[T10:.*]] = splat %[[T9]] : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
+// CHECK: %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[A]][3] : vector<4x1xf32>
-// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1xf32>
-// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
+// CHECK: %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
@@ -684,29 +669,25 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1x2xf32>
-// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1x2xf32>
-// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T1]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<4x1x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T6]][0] : vector<1x2xf32>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T10:.*]] = vector.insert %[[T7]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[A]][2] : vector<4x1x2xf32>
-// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1x2xf32>
-// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T16:.*]] = vector.insert %[[T13]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32>
+// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
-// CHECK: %[[T18:.*]] = vector.extract %[[A]][3] : vector<4x1x2xf32>
-// CHECK: %[[T19:.*]] = vector.extract %[[T18]][0] : vector<1x2xf32>
-// CHECK: %[[T20:.*]] = vector.insert %[[T19]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
-// CHECK: %[[T22:.*]] = vector.insert %[[T19]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32>
+// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: return %[[T23]] : vector<4x3x2xf32>
More information about the Mlir-commits
mailing list