[Mlir-commits] [mlir] 00ac874 - [mlir][Vector] Add InsertStridedSliceOp -> ShuffleOp for the rank-1 cases.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Oct 27 00:58:22 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-27T07:57:17Z
New Revision: 00ac874ff605573b1a9c7c5daa707f10a96ff26c

URL: https://github.com/llvm/llvm-project/commit/00ac874ff605573b1a9c7c5daa707f10a96ff26c
DIFF: https://github.com/llvm/llvm-project/commit/00ac874ff605573b1a9c7c5daa707f10a96ff26c.diff

LOG: [mlir][Vector] Add InsertStridedSliceOp -> ShuffleOp for the rank-1 cases.

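This also fixes the vector.shuffle C++ builder, which incorrectly assumed that the result type is the type of the first operand; the new rewrite triggers that assumption because the result's leading dimension is determined by the mask size.
The vector.shuffle op semantics themselves were already correct.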

Differential revision: https://reviews.llvm.org/D112578

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 1dc04027266a5..166f09eaa8762 100644
--- a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -121,9 +121,10 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
     if (op.offsets().getValue().empty())
       return failure();
 
-    int64_t rankDiff = dstType.getRank() - srcType.getRank();
-    assert(rankDiff >= 0);
-    if (rankDiff != 0)
+    int64_t srcRank = srcType.getRank();
+    int64_t dstRank = dstType.getRank();
+    assert(dstRank >= srcRank);
+    if (dstRank != srcRank)
       return failure();
 
     if (srcType == dstType) {
@@ -139,6 +140,34 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
 
     auto loc = op.getLoc();
     Value res = op.dest();
+
+    if (srcRank == 1) {
+      int nSrc = srcType.getShape().front();
+      int nDest = dstType.getShape().front();
+      // 1. Scale source to destType so we can shufflevector them together.
+      SmallVector<int64_t> offsets(nDest, 0);
+      for (int64_t i = 0; i < nSrc; ++i)
+        offsets[i] = i;
+      Value scaledSource =
+          rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
+
+      // 2. Create a mask where we take the value from scaledSource or dest
+      // depending on the offset.
+      offsets.clear();
+      for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
+        if (i < offset || i >= e || (i - offset) % stride != 0)
+          offsets.push_back(nDest + i);
+        else
+          offsets.push_back((i - offset) / stride);
+      }
+
+      // 3. Replace with a ShuffleOp.
+      rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
+                                             offsets);
+
+      return success();
+    }
+
     // For each slice of the source vector along the most major dimension.
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 149662b124e78..bda4ee7899c2b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1451,7 +1451,10 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
                       Value v2, ArrayRef<int64_t> mask) {
   result.addOperands({v1, v2});
   auto maskAttr = getVectorSubscriptAttr(builder, mask);
-  result.addTypes(v1.getType());
+  auto v1Type = v1.getType().cast<VectorType>();
+  auto shape = llvm::to_vector<4>(v1Type.getShape());
+  shape[0] = mask.size();
+  result.addTypes(VectorType::get(shape, v1Type.getElementType()));
   result.addAttribute(getMaskAttrName(), maskAttr);
 }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1951cf8339b76..7f619e7e8db8c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -900,46 +900,24 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
   return %0 : vector<4x4xf32>
 }
+
 // CHECK-LABEL: @insert_strided_slice2
 //
 // Subvector vector<2xf32> @0 into vector<4xf32> @2
-//  CHECK:    unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
-//       CHECK:    llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>>
-//  CHECK-NEXT:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
+//       CHECK:    %[[V2_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>>
+//       CHECK:    %[[V4_0:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
 // Element @0 -> element @2
-//  CHECK-NEXT:    arith.constant 0 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-//  CHECK-NEXT:    arith.constant 2 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// Element @1 -> element @3
-//  CHECK-NEXT:    arith.constant 1 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-//  CHECK-NEXT:    arith.constant 3 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-//  CHECK-NEXT:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
+//       CHECK:    %[[R4_0:.*]] = llvm.shufflevector %[[V2_0]], %[[V2_0]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32>
+//       CHECK:    %[[R4_1:.*]] = llvm.shufflevector %[[R4_0]], %[[V4_0]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32>
+//       CHECK:    llvm.insertvalue %[[R4_1]], {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
 //
 // Subvector vector<2xf32> @1 into vector<4xf32> @3
-//       CHECK:    llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>>
-//  CHECK-NEXT:    llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
+//       CHECK:    %[[V2_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>>
+//       CHECK:    %[[V4_3:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
 // Element @0 -> element @2
-//  CHECK-NEXT:    arith.constant 0 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-//  CHECK-NEXT:    arith.constant 2 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// Element @1 -> element @3
-//  CHECK-NEXT:    arith.constant 1 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32>
-//  CHECK-NEXT:    arith.constant 3 : index
-//  CHECK-NEXT:    unrealized_conversion_cast %{{.*}} : index to i64
-//  CHECK-NEXT:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-//  CHECK-NEXT:    llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
+//       CHECK:    %[[R4_2:.*]] = llvm.shufflevector %[[V2_1]], %[[V2_1]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32>
+//       CHECK:    %[[R4_3:.*]] = llvm.shufflevector %[[R4_2]], %[[V4_3]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32>
+//       CHECK:    llvm.insertvalue %[[R4_3]], {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
 
 // -----
 
@@ -948,69 +926,18 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
         vector<2x4xf32> into vector<16x4x8xf32>
   return %0 : vector<16x4x8xf32>
 }
-// CHECK-LABEL: @insert_strided_slice3(
-// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:.*]]: vector<16x4x8xf32>)
-//      CHECK-DAG: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
-//      CHECK-DAG: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
-//      CHECK: %[[s3:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-//      CHECK: %[[s5:.*]] = llvm.extractvalue %[[s4]][0] : !llvm.array<2 x vector<4xf32>>
-//      CHECK: %[[s7:.*]] = llvm.extractvalue %[[s2]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-//      CHECK: %[[s8:.*]] = arith.constant 0 : index
-//      CHECK: %[[s9:.*]] = builtin.unrealized_conversion_cast %[[s8]] : index to i64
-//      CHECK: %[[s10:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s9]] : i64] : vector<4xf32>
-//      CHECK: %[[s11:.*]] = arith.constant 2 : index
-//      CHECK: %[[s12:.*]] = builtin.unrealized_conversion_cast %[[s11]] : index to i64
-//      CHECK: %[[s13:.*]] = llvm.insertelement %[[s10]], %[[s7]]{{\[}}%[[s12]] : i64] : vector<8xf32>
-//      CHECK: %[[s14:.*]] = arith.constant 1 : index
-//      CHECK: %[[s15:.*]] = builtin.unrealized_conversion_cast %[[s14]] : index to i64
-//      CHECK: %[[s16:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s15]] : i64] : vector<4xf32>
-//      CHECK: %[[s17:.*]] = arith.constant 3 : index
-//      CHECK: %[[s18:.*]] = builtin.unrealized_conversion_cast %[[s17]] : index to i64
-//      CHECK: %[[s19:.*]] = llvm.insertelement %[[s16]], %[[s13]]{{\[}}%[[s18]] : i64] : vector<8xf32>
-//      CHECK: %[[s20:.*]] = arith.constant 2 : index
-//      CHECK: %[[s21:.*]] = builtin.unrealized_conversion_cast %[[s20]] : index to i64
-//      CHECK: %[[s22:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s21]] : i64] : vector<4xf32>
-//      CHECK: %[[s23:.*]] = arith.constant 4 : index
-//      CHECK: %[[s24:.*]] = builtin.unrealized_conversion_cast %[[s23]] : index to i64
-//      CHECK: %[[s25:.*]] = llvm.insertelement %[[s22]], %[[s19]]{{\[}}%[[s24]] : i64] : vector<8xf32>
-//      CHECK: %[[s26:.*]] = arith.constant 3 : index
-//      CHECK: %[[s27:.*]] = builtin.unrealized_conversion_cast %[[s26]] : index to i64
-//      CHECK: %[[s28:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s27]] : i64] : vector<4xf32>
-//      CHECK: %[[s29:.*]] = arith.constant 5 : index
-//      CHECK: %[[s30:.*]] = builtin.unrealized_conversion_cast %[[s29]] : index to i64
-//      CHECK: %[[s31:.*]] = llvm.insertelement %[[s28]], %[[s25]]{{\[}}%[[s30]] : i64] : vector<8xf32>
-//      CHECK: %[[s32:.*]] = llvm.insertvalue %[[s31]], %[[s3]][0] : !llvm.array<4 x vector<8xf32>>
-//      CHECK: %[[s34:.*]] = llvm.extractvalue %[[s4]][1] : !llvm.array<2 x vector<4xf32>>
-//      CHECK: %[[s36:.*]] = llvm.extractvalue %[[s2]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
-//      CHECK: %[[s37:.*]] = arith.constant 0 : index
-//      CHECK: %[[s38:.*]] = builtin.unrealized_conversion_cast %[[s37]] : index to i64
-//      CHECK: %[[s39:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s38]] : i64] : vector<4xf32>
-//      CHECK: %[[s40:.*]] = arith.constant 2 : index
-//      CHECK: %[[s41:.*]] = builtin.unrealized_conversion_cast %[[s40]] : index to i64
-//      CHECK: %[[s42:.*]] = llvm.insertelement %[[s39]], %[[s36]]{{\[}}%[[s41]] : i64] : vector<8xf32>
-//      CHECK: %[[s43:.*]] = arith.constant 1 : index
-//      CHECK: %[[s44:.*]] = builtin.unrealized_conversion_cast %[[s43]] : index to i64
-//      CHECK: %[[s45:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s44]] : i64] : vector<4xf32>
-//      CHECK: %[[s46:.*]] = arith.constant 3 : index
-//      CHECK: %[[s47:.*]] = builtin.unrealized_conversion_cast %[[s46]] : index to i64
-//      CHECK: %[[s48:.*]] = llvm.insertelement %[[s45]], %[[s42]]{{\[}}%[[s47]] : i64] : vector<8xf32>
-//      CHECK: %[[s49:.*]] = arith.constant 2 : index
-//      CHECK: %[[s50:.*]] = builtin.unrealized_conversion_cast %[[s49]] : index to i64
-//      CHECK: %[[s51:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s50]] : i64] : vector<4xf32>
-//      CHECK: %[[s52:.*]] = arith.constant 4 : index
-//      CHECK: %[[s53:.*]] = builtin.unrealized_conversion_cast %[[s52]] : index to i64
-//      CHECK: %[[s54:.*]] = llvm.insertelement %[[s51]], %[[s48]]{{\[}}%[[s53]] : i64] : vector<8xf32>
-//      CHECK: %[[s55:.*]] = arith.constant 3 : index
-//      CHECK: %[[s56:.*]] = builtin.unrealized_conversion_cast %[[s55]] : index to i64
-//      CHECK: %[[s57:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s56]] : i64] : vector<4xf32>
-//      CHECK: %[[s58:.*]] = arith.constant 5 : index
-//      CHECK: %[[s59:.*]] = builtin.unrealized_conversion_cast %[[s58]] : index to i64
-//      CHECK: %[[s60:.*]] = llvm.insertelement %[[s57]], %[[s54]]{{\[}}%[[s59]] : i64] : vector<8xf32>
-//      CHECK: %[[s61:.*]] = llvm.insertvalue %[[s60]], %[[s32]][1] : !llvm.array<4 x vector<8xf32>>
-//      CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-//      CHECK: %[[s64:.*]] = builtin.unrealized_conversion_cast %[[s63]] : !llvm.array<16 x array<4 x vector<8xf32>>> to vector<16x4x8xf32>
-//      CHECK: return %[[s64]] : vector<16x4x8xf32>
+// CHECK-LABEL: func @insert_strided_slice3
+//       CHECK:    %[[V4_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+//       CHECK:    %[[V4_0_0:.*]] = llvm.extractvalue {{.*}}[0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
+//       CHECK:    %[[R8_0:.*]] = llvm.shufflevector %[[V4_0]], %[[V4_0]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+//       CHECK:    %[[R8_1:.*]] = llvm.shufflevector %[[R8_0]], %[[V4_0_0]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:    llvm.insertvalue %[[R8_1]], {{.*}}[0] : !llvm.array<4 x vector<8xf32>>
+
+//       CHECK:    %[[V4_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
+//       CHECK:    %[[V4_0_1:.*]] = llvm.extractvalue {{.*}}[0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
+//       CHECK:    %[[R8_2:.*]] = llvm.shufflevector %[[V4_1]], %[[V4_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+//       CHECK:    %[[R8_3:.*]] = llvm.shufflevector %[[R8_2]], %[[V4_0_1]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:    llvm.insertvalue %[[R8_3]], {{.*}}[1] : !llvm.array<4 x vector<8xf32>>
 
 // -----
 


        


More information about the Mlir-commits mailing list