[Mlir-commits] [mlir] 00ac874 - [mlir][Vector] Add InsertStridedSliceOp -> ShuffleOp for the rank-1 cases.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Oct 27 00:58:22 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-27T07:57:17Z
New Revision: 00ac874ff605573b1a9c7c5daa707f10a96ff26c
URL: https://github.com/llvm/llvm-project/commit/00ac874ff605573b1a9c7c5daa707f10a96ff26c
DIFF: https://github.com/llvm/llvm-project/commit/00ac874ff605573b1a9c7c5daa707f10a96ff26c.diff
LOG: [mlir][Vector] Add InsertStridedSliceOp -> ShuffleOp for the rank-1 cases.
This also fixes the vector.shuffle C++ builder, which incorrectly assumed that the result type is always the type of the first operand; the new rewrite triggers that bug.
The vector.shuffle semantics themselves were already correct.
Differential revision: https://reviews.llvm.org/D112578
Added:
Modified:
mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 1dc04027266a5..166f09eaa8762 100644
--- a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -121,9 +121,10 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
if (op.offsets().getValue().empty())
return failure();
- int64_t rankDiff = dstType.getRank() - srcType.getRank();
- assert(rankDiff >= 0);
- if (rankDiff != 0)
+ int64_t srcRank = srcType.getRank();
+ int64_t dstRank = dstType.getRank();
+ assert(dstRank >= srcRank);
+ if (dstRank != srcRank)
return failure();
if (srcType == dstType) {
@@ -139,6 +140,34 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
auto loc = op.getLoc();
Value res = op.dest();
+
+ if (srcRank == 1) {
+ int nSrc = srcType.getShape().front();
+ int nDest = dstType.getShape().front();
+ // 1. Scale source to destType so we can shufflevector them together.
+ SmallVector<int64_t> offsets(nDest, 0);
+ for (int64_t i = 0; i < nSrc; ++i)
+ offsets[i] = i;
+ Value scaledSource =
+ rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
+
+ // 2. Create a mask where we take the value from scaledSource or dest
+ // depending on the offset.
+ offsets.clear();
+ for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
+ if (i < offset || i >= e || (i - offset) % stride != 0)
+ offsets.push_back(nDest + i);
+ else
+ offsets.push_back((i - offset) / stride);
+ }
+
+ // 3. Replace with a ShuffleOp.
+ rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
+ offsets);
+
+ return success();
+ }
+
// For each slice of the source vector along the most major dimension.
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 149662b124e78..bda4ee7899c2b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1451,7 +1451,10 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
Value v2, ArrayRef<int64_t> mask) {
result.addOperands({v1, v2});
auto maskAttr = getVectorSubscriptAttr(builder, mask);
- result.addTypes(v1.getType());
+ auto v1Type = v1.getType().cast<VectorType>();
+ auto shape = llvm::to_vector<4>(v1Type.getShape());
+ shape[0] = mask.size();
+ result.addTypes(VectorType::get(shape, v1Type.getElementType()));
result.addAttribute(getMaskAttrName(), maskAttr);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1951cf8339b76..7f619e7e8db8c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -900,46 +900,24 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
return %0 : vector<4x4xf32>
}
+
// CHECK-LABEL: @insert_strided_slice2
//
// Subvector vector<2xf32> @0 into vector<4xf32> @2
-// CHECK: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>>
-// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
+// CHECK: %[[V2_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V4_0:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
// Element @0 -> element @2
-// CHECK-NEXT: arith.constant 0 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-// CHECK-NEXT: arith.constant 2 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// Element @1 -> element @3
-// CHECK-NEXT: arith.constant 1 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-// CHECK-NEXT: arith.constant 3 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
+// CHECK: %[[R4_0:.*]] = llvm.shufflevector %[[V2_0]], %[[V2_0]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[R4_1:.*]] = llvm.shufflevector %[[R4_0]], %[[V4_0]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32>
+// CHECK: llvm.insertvalue %[[R4_1]], {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
//
// Subvector vector<2xf32> @1 into vector<4xf32> @3
-// CHECK: llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>>
-// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
+// CHECK: %[[V2_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V4_3:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
// Element @0 -> element @2
-// CHECK-NEXT: arith.constant 0 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-// CHECK-NEXT: arith.constant 2 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// Element @1 -> element @3
-// CHECK-NEXT: arith.constant 1 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-// CHECK-NEXT: arith.constant 3 : index
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
-// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
+// CHECK: %[[R4_2:.*]] = llvm.shufflevector %[[V2_1]], %[[V2_1]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[R4_3:.*]] = llvm.shufflevector %[[R4_2]], %[[V4_3]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32>
+// CHECK: llvm.insertvalue %[[R4_3]], {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
// -----
@@ -948,69 +926,18 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
vector<2x4xf32> into vector<16x4x8xf32>
return %0 : vector<16x4x8xf32>
}
-// CHECK-LABEL: @insert_strided_slice3(
-// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:.*]]: vector<16x4x8xf32>)
-// CHECK-DAG: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK-DAG: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s3:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s5:.*]] = llvm.extractvalue %[[s4]][0] : !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s2]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s8:.*]] = arith.constant 0 : index
-// CHECK: %[[s9:.*]] = builtin.unrealized_conversion_cast %[[s8]] : index to i64
-// CHECK: %[[s10:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s9]] : i64] : vector<4xf32>
-// CHECK: %[[s11:.*]] = arith.constant 2 : index
-// CHECK: %[[s12:.*]] = builtin.unrealized_conversion_cast %[[s11]] : index to i64
-// CHECK: %[[s13:.*]] = llvm.insertelement %[[s10]], %[[s7]]{{\[}}%[[s12]] : i64] : vector<8xf32>
-// CHECK: %[[s14:.*]] = arith.constant 1 : index
-// CHECK: %[[s15:.*]] = builtin.unrealized_conversion_cast %[[s14]] : index to i64
-// CHECK: %[[s16:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s15]] : i64] : vector<4xf32>
-// CHECK: %[[s17:.*]] = arith.constant 3 : index
-// CHECK: %[[s18:.*]] = builtin.unrealized_conversion_cast %[[s17]] : index to i64
-// CHECK: %[[s19:.*]] = llvm.insertelement %[[s16]], %[[s13]]{{\[}}%[[s18]] : i64] : vector<8xf32>
-// CHECK: %[[s20:.*]] = arith.constant 2 : index
-// CHECK: %[[s21:.*]] = builtin.unrealized_conversion_cast %[[s20]] : index to i64
-// CHECK: %[[s22:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s21]] : i64] : vector<4xf32>
-// CHECK: %[[s23:.*]] = arith.constant 4 : index
-// CHECK: %[[s24:.*]] = builtin.unrealized_conversion_cast %[[s23]] : index to i64
-// CHECK: %[[s25:.*]] = llvm.insertelement %[[s22]], %[[s19]]{{\[}}%[[s24]] : i64] : vector<8xf32>
-// CHECK: %[[s26:.*]] = arith.constant 3 : index
-// CHECK: %[[s27:.*]] = builtin.unrealized_conversion_cast %[[s26]] : index to i64
-// CHECK: %[[s28:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s27]] : i64] : vector<4xf32>
-// CHECK: %[[s29:.*]] = arith.constant 5 : index
-// CHECK: %[[s30:.*]] = builtin.unrealized_conversion_cast %[[s29]] : index to i64
-// CHECK: %[[s31:.*]] = llvm.insertelement %[[s28]], %[[s25]]{{\[}}%[[s30]] : i64] : vector<8xf32>
-// CHECK: %[[s32:.*]] = llvm.insertvalue %[[s31]], %[[s3]][0] : !llvm.array<4 x vector<8xf32>>
-// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s4]][1] : !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s2]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s37:.*]] = arith.constant 0 : index
-// CHECK: %[[s38:.*]] = builtin.unrealized_conversion_cast %[[s37]] : index to i64
-// CHECK: %[[s39:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s38]] : i64] : vector<4xf32>
-// CHECK: %[[s40:.*]] = arith.constant 2 : index
-// CHECK: %[[s41:.*]] = builtin.unrealized_conversion_cast %[[s40]] : index to i64
-// CHECK: %[[s42:.*]] = llvm.insertelement %[[s39]], %[[s36]]{{\[}}%[[s41]] : i64] : vector<8xf32>
-// CHECK: %[[s43:.*]] = arith.constant 1 : index
-// CHECK: %[[s44:.*]] = builtin.unrealized_conversion_cast %[[s43]] : index to i64
-// CHECK: %[[s45:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s44]] : i64] : vector<4xf32>
-// CHECK: %[[s46:.*]] = arith.constant 3 : index
-// CHECK: %[[s47:.*]] = builtin.unrealized_conversion_cast %[[s46]] : index to i64
-// CHECK: %[[s48:.*]] = llvm.insertelement %[[s45]], %[[s42]]{{\[}}%[[s47]] : i64] : vector<8xf32>
-// CHECK: %[[s49:.*]] = arith.constant 2 : index
-// CHECK: %[[s50:.*]] = builtin.unrealized_conversion_cast %[[s49]] : index to i64
-// CHECK: %[[s51:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s50]] : i64] : vector<4xf32>
-// CHECK: %[[s52:.*]] = arith.constant 4 : index
-// CHECK: %[[s53:.*]] = builtin.unrealized_conversion_cast %[[s52]] : index to i64
-// CHECK: %[[s54:.*]] = llvm.insertelement %[[s51]], %[[s48]]{{\[}}%[[s53]] : i64] : vector<8xf32>
-// CHECK: %[[s55:.*]] = arith.constant 3 : index
-// CHECK: %[[s56:.*]] = builtin.unrealized_conversion_cast %[[s55]] : index to i64
-// CHECK: %[[s57:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s56]] : i64] : vector<4xf32>
-// CHECK: %[[s58:.*]] = arith.constant 5 : index
-// CHECK: %[[s59:.*]] = builtin.unrealized_conversion_cast %[[s58]] : index to i64
-// CHECK: %[[s60:.*]] = llvm.insertelement %[[s57]], %[[s54]]{{\[}}%[[s59]] : i64] : vector<8xf32>
-// CHECK: %[[s61:.*]] = llvm.insertvalue %[[s60]], %[[s32]][1] : !llvm.array<4 x vector<8xf32>>
-// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s64:.*]] = builtin.unrealized_conversion_cast %[[s63]] : !llvm.array<16 x array<4 x vector<8xf32>>> to vector<16x4x8xf32>
-// CHECK: return %[[s64]] : vector<16x4x8xf32>
+// CHECK-LABEL: func @insert_strided_slice3
+// CHECK: %[[V4_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+// CHECK: %[[V4_0_0:.*]] = llvm.extractvalue {{.*}}[0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK: %[[R8_0:.*]] = llvm.shufflevector %[[V4_0]], %[[V4_0]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[R8_1:.*]] = llvm.shufflevector %[[R8_0]], %[[V4_0_0]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: llvm.insertvalue %[[R8_1]], {{.*}}[0] : !llvm.array<4 x vector<8xf32>>
+
+// CHECK: %[[V4_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
+// CHECK: %[[V4_0_1:.*]] = llvm.extractvalue {{.*}}[0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK: %[[R8_2:.*]] = llvm.shufflevector %[[V4_1]], %[[V4_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[R8_3:.*]] = llvm.shufflevector %[[R8_2]], %[[V4_0_1]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: llvm.insertvalue %[[R8_3]], {{.*}}[1] : !llvm.array<4 x vector<8xf32>>
// -----
More information about the Mlir-commits
mailing list