[Mlir-commits] [mlir] 9ddbcee - [mlir][vector] Extend vector.{insert|extract}_strided_slice (#79052)
llvmlistbot at llvm.org
Thu Jan 25 11:01:32 PST 2024
Author: Andrzej Warzyński
Date: 2024-01-25T19:01:28Z
New Revision: 9ddbcee25ea4b436ab478874221553f4a7bc2289
URL: https://github.com/llvm/llvm-project/commit/9ddbcee25ea4b436ab478874221553f4a7bc2289
DIFF: https://github.com/llvm/llvm-project/commit/9ddbcee25ea4b436ab478874221553f4a7bc2289.diff
LOG: [mlir][vector] Extend vector.{insert|extract}_strided_slice (#79052)
Extends `vector.insert_strided_slice` and `vector.extract_strided_slice`
to allow scalable input and output vectors. For scalable dimensions, the
corresponding slice size has to match the size of the corresponding
dimension in the destination vector (for insert) or the source vector
(for extract).
This is supported:
```mlir
vector.extract_strided_slice %1 {
    offsets = [0, 3, 0],
    sizes = [1, 1, 4],
    strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
```
This is not supported:
```mlir
vector.extract_strided_slice %1 {
    offsets = [0, 3, 0],
    sizes = [1, 1, 2],
    strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
```
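The same rule applies to `vector.insert_strided_slice`: for a scalable
dimension, the slice size must match the corresponding destination
dimension. A sketch mirroring the tests added in this patch (`%0` and `%1`
stand for arbitrary values of the annotated types):
```mlir
// Supported: the scalable dim of the slice ([4]) matches the destination dim.
vector.insert_strided_slice %0, %1 {
    offsets = [0, 3, 0],
    strides = [1, 1, 1] } : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>

// Not supported: the scalable dim of the slice ([2]) does not match the
// destination dim ([4]).
vector.insert_strided_slice %0, %1 {
    offsets = [0, 3, 0],
    strides = [1, 1, 1] } : vector<1x1x[2]xi32> into vector<1x4x[4]xi32>
```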
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5f9860ef2311e7..452354413e8833 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2857,6 +2857,26 @@ LogicalResult InsertStridedSliceOp::verify() {
/*halfOpen=*/false, /*min=*/1)))
return failure();
+ unsigned rankDiff = destShape.size() - sourceShape.size();
+ for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
+ if (sourceVectorType.getScalableDims()[idx] !=
+ destVectorType.getScalableDims()[idx + rankDiff]) {
+ return emitOpError("mismatching scalable flags (at source vector idx=")
+ << idx << ")";
+ }
+ if (sourceVectorType.getScalableDims()[idx]) {
+ auto sourceSize = sourceShape[idx];
+ auto destSize = destShape[idx + rankDiff];
+ if (sourceSize != destSize) {
+ return emitOpError("expected size at idx=")
+ << idx
+ << (" to match the corresponding base size from the input "
+ "vector (")
+ << sourceSize << (" vs ") << destSize << (")");
+ }
+ }
+ }
+
return success();
}
@@ -3194,6 +3214,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
// Inference works as follows:
// 1. Add 'sizes' from prefix of dims in 'offsets'.
// 2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
ArrayAttr offsets, ArrayAttr sizes,
ArrayAttr strides) {
@@ -3206,7 +3227,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);
- return VectorType::get(shape, vectorType.getElementType());
+ return VectorType::get(shape, vectorType.getElementType(),
+ vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3287,19 @@ LogicalResult ExtractStridedSliceOp::verify() {
if (getResult().getType() != resultType)
return emitOpError("expected result type to be ") << resultType;
+ for (unsigned idx = 0; idx < sizes.size(); ++idx) {
+ if (type.getScalableDims()[idx]) {
+ auto inputDim = type.getShape()[idx];
+ auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+ if (inputDim != inputSize)
+ return emitOpError("expected size at idx=")
+ << idx
+ << (" to match the corresponding base size from the input "
+ "vector (")
+ << inputSize << (" vs ") << inputDim << (")");
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab3179998..1c13b16dfd9af3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,28 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+ return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @extract_strided_slice_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+
+// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+// CHECK: %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+
+// CHECK: %[[EXT:.*]] = llvm.extractvalue %[[CAST_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT]], %[[CAST_3]][0] : !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+
+// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+
+// -----
+
func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
return %0 : vector<4x4x4xf32>
@@ -1207,6 +1229,27 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
// -----
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+ return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL: func.func @insert_strided_slice_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+
+// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+
+// CHECK: %[[EXT_1:.*]] = llvm.extractvalue %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[EXT_2:.*]] = llvm.extractvalue %[[CAST_1]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+
+// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT_2]], %[[EXT_1]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+
+// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+
+// -----
+
func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973..c16f1cb2876dbd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -652,6 +652,22 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// -----
+func.func @insert_strided_slice_scalable(%a : vector<1x1x[2]xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+ // expected-error @+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+ %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[2]xi32> into vector<1x4x[4]xi32>
+ return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
+func.func @insert_strided_slice_scalable(%a : vector<1x1x4xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+ // expected-error @+1 {{op mismatching scalable flags (at source vector idx=2)}}
+ %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x4xi32> into vector<1x4x[4]xi32>
+ return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error @+1 {{expected offsets, sizes and strides attributes of same size}}
%1 = vector.extract_strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
@@ -687,6 +703,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+ // expected-error @+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+ return %1 : vector<1x1x[2]xi32>
+ }
+
+// -----
+
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error @+1 {{op expected strides to be confined to [1, 2)}}
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18..2f8530e7c171aa 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -319,6 +319,13 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
return
}
+// CHECK-LABEL: @insert_strided_slice_scalable
+func.func @insert_strided_slice_scalable(%a: vector<4x[16]xf32>, %b: vector<4x8x[16]xf32>) {
+ // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+ %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+ return
+}
+
// CHECK-LABEL: @extract_strided_slice
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
// CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
@@ -326,6 +333,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
return %1: vector<2x2x16xf32>
}
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+ // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+ return %1: vector<2x[8]x16xf32>
+}
+
#contraction_to_scalar_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,