[Mlir-commits] [mlir] c3c95b9 - [mlir] [VectorOps] Improve lowering of extract_strided_slice (and friends like shape_cast)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 7 09:21:24 PDT 2020


Author: aartbik
Date: 2020-08-07T09:21:05-07:00
New Revision: c3c95b9c808519662afe8b9053aa88b5be451d1d

URL: https://github.com/llvm/llvm-project/commit/c3c95b9c808519662afe8b9053aa88b5be451d1d
DIFF: https://github.com/llvm/llvm-project/commit/c3c95b9c808519662afe8b9053aa88b5be451d1d.diff

LOG: [mlir] [VectorOps] Improve lowering of extract_strided_slice (and friends like shape_cast)

Using a shuffle for the last recursive step in progressive lowering not only
results in much more compact IR, but also more efficient code (since the
backend is no longer confused on subvector aliasing for longer vectors).

E.g. the following

  %f = vector.shape_cast %v0: vector<1024xf32> to vector<32x32xf32>

yields much better x86-64 code that runs 3x faster than the original.

Reviewed By: bkramer, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D85482

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 12d3f5042bcd..1e92b80d830f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1349,9 +1349,9 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
 };
 
 /// Progressive lowering of ExtractStridedSliceOp to either:
-///   1. extractelement + insertelement for the 1-D case
-///   2. extract + optional strided_slice + insert for the n-D case.
-class VectorStridedSliceOpConversion
+///   1. express single offset extract as a direct shuffle.
+///   2. extract + lower rank strided_slice + insert for the n-D case.
+class VectorExtractStridedSliceOpConversion
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
@@ -1371,21 +1371,34 @@ class VectorStridedSliceOpConversion
     auto loc = op.getLoc();
     auto elemType = dstType.getElementType();
     assert(elemType.isSignlessIntOrIndexOrFloat());
+
+    // Single offset can be more efficiently shuffled.
+    if (op.offsets().getValue().size() == 1) {
+      SmallVector<int64_t, 4> offsets;
+      offsets.reserve(size);
+      for (int64_t off = offset, e = offset + size * stride; off < e;
+           off += stride)
+        offsets.push_back(off);
+      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
+                                             op.vector(),
+                                             rewriter.getI64ArrayAttr(offsets));
+      return success();
+    }
+
+    // Extract/insert on a lower ranked extract strided slice op.
     Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                              rewriter.getZeroAttr(elemType));
     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {
-      Value extracted = extractOne(rewriter, loc, op.vector(), off);
-      if (op.offsets().getValue().size() > 1) {
-        extracted = rewriter.create<ExtractStridedSliceOp>(
-            loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
-            getI64SubArray(op.sizes(), /* dropFront=*/1),
-            getI64SubArray(op.strides(), /* dropFront=*/1));
-      }
+      Value one = extractOne(rewriter, loc, op.vector(), off);
+      Value extracted = rewriter.create<ExtractStridedSliceOp>(
+          loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
+          getI64SubArray(op.sizes(), /* dropFront=*/1),
+          getI64SubArray(op.strides(), /* dropFront=*/1));
       res = insertOne(rewriter, loc, extracted, res, idx);
     }
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOp(op, res);
     return success();
   }
   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
@@ -1404,7 +1417,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
   patterns.insert<VectorFMAOpNDRewritePattern,
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
-                  VectorStridedSliceOpConversion>(ctx);
+                  VectorExtractStridedSliceOpConversion>(ctx);
   patterns.insert<VectorReductionOpConversion>(
       ctx, converter, reassociateFPReductions);
   patterns

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5254d2eef4bf..d91d4db06106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -512,65 +512,38 @@ func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
 }
-// CHECK-LABEL: llvm.func @extract_strided_slice1
-//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
-//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float>
-//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float>
-//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
+// CHECK-LABEL: llvm.func @extract_strided_slice1(
+//  CHECK-SAME:    %[[A:.*]]: !llvm.vec<4 x float>)
+//       CHECK:    %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : !llvm.vec<4 x float>, !llvm.vec<4 x float>
+//       CHECK:    llvm.return %[[T0]] : !llvm.vec<2 x float>
 
 func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
   return %0 : vector<2x8xf32>
 }
-// CHECK-LABEL: llvm.func @extract_strided_slice2
-//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
-//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm.array<2 x vec<8 x float>>
-//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vec<8 x float>>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.array<2 x vec<8 x float>>
-//       CHECK:    llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vec<8 x float>>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.array<2 x vec<8 x float>>
+// CHECK-LABEL: llvm.func @extract_strided_slice2(
+//  CHECK-SAME:    %[[A:.*]]: !llvm.array<4 x vec<8 x float>>)
+//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vec<8 x float>>
+//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>>
+//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vec<8 x float>>
+//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>>
+//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vec<8 x float>>
+//       CHECK:    llvm.return %[[T4]] : !llvm.array<2 x vec<8 x float>>
 
 func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
-// CHECK-LABEL: llvm.func @extract_strided_slice3
-//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
-//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>>
-//
-// Subvector vector<8xf32> @2
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vec<8 x float>>
-//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
-//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
-//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
-//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
-//       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.array<2 x vec<2 x float>>
-//
-// Subvector vector<8xf32> @3
-//       CHECK:    llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vec<8 x float>>
-//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
-//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
-//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
-//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
-//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
-//       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.array<2 x vec<2 x float>>
+// CHECK-LABEL: llvm.func @extract_strided_slice3(
+//  CHECK-SAME:    %[[A:.*]]: !llvm.array<4 x vec<8 x float>>)
+//       CHECK:    %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>>
+//       CHECK:    %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>>
+//       CHECK:    %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float>
+//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<2 x vec<2 x float>>
+//       CHECK:    %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>>
+//       CHECK:    %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float>
+//       CHECK:    %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vec<2 x float>>
+//       CHECK:    llvm.return %[[T7]] : !llvm.array<2 x vec<2 x float>>
 
 func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
@@ -674,15 +647,11 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
 }
 // CHECK-LABEL: llvm.func @extract_strides(
 // CHECK-SAME: %[[A:.*]]: !llvm.array<3 x vec<3 x float>>)
-//      CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>>
-//      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>>
-//      CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm.vec<1 x float>
-//      CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
-//      CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm.vec<3 x float>
-//      CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-//      CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm.vec<1 x float>
-//      CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm.array<1 x vec<1 x float>>
-//      CHECK: llvm.return %[[s8]] : !llvm.array<1 x vec<1 x float>>
+//      CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>>
+//      CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>>
+//      CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2] : !llvm.vec<3 x float>, !llvm.vec<3 x float>
+//      CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<1 x vec<1 x float>>
+//      CHECK: llvm.return %[[T4]] : !llvm.array<1 x vec<1 x float>>
 
 // CHECK-LABEL: llvm.func @vector_fma(
 //  CHECK-SAME: %[[A:.*]]: !llvm.vec<8 x float>, %[[B:.*]]: !llvm.array<2 x vec<4 x float>>)


        


More information about the Mlir-commits mailing list