[Mlir-commits] [mlir] [mlir][vector] Extend vector.{insert|extract}_strided_slice (PR #79052)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jan 23 04:03:08 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/79052

>From 7792e1dae6f22802b0293f0053bc9d7e4621e379 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 16 Jan 2024 08:21:08 +0000
Subject: [PATCH 1/3] [mlir][vector] Extend
 vector.{insert|extract}_strided_slice

Extends `vector.insert_strided_slice` and `vector.extract_strided_slice`
to allow scalable input and output vectors. For scalable dimensions, the
corresponding slice size has to match the full size of that dimension in
the output/input vector (insert/extract, respectively).

This is supported:
```mlir
vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 4],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
```

This is not supported:
```mlir
vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 2],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
```
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 18 +++++++-
 .../VectorToLLVM/vector-to-llvm.mlir          | 43 +++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir         |  8 ++++
 mlir/test/Dialect/Vector/ops.mlir             |  7 +++
 4 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad4..b168b7d7afe9db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
 // Inference works as follows:
 //   1. Add 'sizes' from prefix of dims in 'offsets'.
 //   2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
                                           ArrayAttr offsets, ArrayAttr sizes,
                                           ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
-  return VectorType::get(shape, vectorType.getElementType());
+  return VectorType::get(shape, vectorType.getElementType(),
+                         vectorType.getScalableDims());
 }
 
 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
+  unsigned idx = 0;
+  for (unsigned ub = sizes.size(); idx < ub; ++idx) {
+    if (type.getScalableDims()[idx]) {
+      auto inputDim = type.getShape()[idx];
+      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      if (inputDim != inputSize)
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << inputSize << (" vs ") << inputDim << (")");
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab3179998..394b4dea3dab25 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+  return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL:   func.func @extract_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+//      CHECK:      %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_2:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+//      CHECK:      %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+//      CHECK:      %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+//      CHECK:      return %[[VAL_12]] : vector<1x1x[4]xi32>
+
+// -----
+
 func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
   return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
 
 // -----
 
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL:   func.func @insert_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+//      CHECK:      %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+//      CHECK:      return %[[VAL_10]] : vector<1x4x[4]xi32>
+
+// -----
+
 func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973..2072262864c4cd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+    // expected-error at +1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+    %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+    return %1 : vector<1x1x[2]xi32>
+  }
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error at +1 {{op expected strides to be confined to [1, 2)}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18..c95d0dfba69ada 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
   return %1: vector<2x2x16xf32>
 }
 
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+  // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+  return %1: vector<2x[8]x16xf32>
+}
+
 #contraction_to_scalar_accesses = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> (i)>,

>From f0c46076bcab7d8d826cebbc3f71d9e722c8892f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 23 Jan 2024 10:50:02 +0000
Subject: [PATCH 2/3] fixup! [mlir][vector] Extend
 vector.{insert|extract}_strided_slice

Update the verifier for insert_strided_slice
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 21 +++++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir    | 16 ++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b168b7d7afe9db..3f6ccfaae5b51e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2857,6 +2857,27 @@ LogicalResult InsertStridedSliceOp::verify() {
           /*halfOpen=*/false, /*min=*/1)))
     return failure();
 
+  unsigned idx = 0;
+  unsigned rankDiff = destShape.size() - sourceShape.size();
+  for (unsigned ub = sourceShape.size(); idx < ub; ++idx) {
+    if (sourceVectorType.getScalableDims()[idx] !=
+        destVectorType.getScalableDims()[idx + rankDiff]) {
+      return emitOpError("mismatching scalable flags (at source vector idx=")
+             << idx << ")";
+    }
+    if (sourceVectorType.getScalableDims()[idx]) {
+      auto sourceSize = sourceShape[idx];
+      auto destSize = destShape[idx + rankDiff];
+      if (sourceSize != destSize) {
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << sourceSize << (" vs ") << destSize << (")");
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 2072262864c4cd..c16f1cb2876dbd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -652,6 +652,22 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @insert_strided_slice_scalable(%a : vector<1x1x[2]xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  // expected-error at +1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[2]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
+func.func @insert_strided_slice_scalable(%a : vector<1x1x4xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  // expected-error at +1 {{op mismatching scalable flags (at source vector idx=2)}}
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x4xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected offsets, sizes and strides attributes of same size}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>

>From 6040af9fa9985184ab3e3fb06d8bd9c4e5d79288 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 23 Jan 2024 12:02:36 +0000
Subject: [PATCH 3/3] fixup! [mlir][vector] Extend
 vector.{insert|extract}_strided_slice

Add a test in ops.mlir
---
 mlir/test/Dialect/Vector/ops.mlir | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c95d0dfba69ada..2f8530e7c171aa 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -319,6 +319,13 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   return
 }
 
+// CHECK-LABEL: @insert_strided_slice_scalable
+func.func @insert_strided_slice_scalable(%a: vector<4x[16]xf32>, %b: vector<4x8x[16]xf32>) {
+  // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+  return
+}
+
 // CHECK-LABEL: @extract_strided_slice
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
   // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>



More information about the Mlir-commits mailing list