[Mlir-commits] [mlir] 5b1b710 - [mlir][vector] Add unrolling pattern for TransposeOp

Thomas Raoux llvmlistbot at llvm.org
Wed Apr 13 12:44:41 PDT 2022


Author: Thomas Raoux
Date: 2022-04-13T19:44:16Z
New Revision: 5b1b7108c8975159c1112ceea1cd7e213e1be97a

URL: https://github.com/llvm/llvm-project/commit/5b1b7108c8975159c1112ceea1cd7e213e1be97a
DIFF: https://github.com/llvm/llvm-project/commit/5b1b7108c8975159c1112ceea1cd7e213e1be97a.diff

LOG: [mlir][vector] Add unrolling pattern for TransposeOp

Add support for unrolling vector.transpose, following the same interface as
the other vector unrolling patterns.

Differential Revision: https://reviews.llvm.org/D123688

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 289b50256d454..3e9ad30cd8bc3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2217,6 +2217,7 @@ def Vector_CreateMaskOp :
 
 def Vector_TransposeOp :
   Vector_Op<"transpose", [NoSideEffect,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
     Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 758478f8d7ff8..f326217af299c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4320,6 +4320,13 @@ LogicalResult vector::TransposeOp::verify() {
   return success();
 }
 
+/// A transpose is unrolled along the dimensions of its result vector, so the
+/// unroll shape is simply the result shape.
+Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
+  ArrayRef<int64_t> resultShape = getResultType().getShape();
+  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
 namespace {
 
 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 2b730182d2088..7f00788d888b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -681,14 +681,82 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
   const vector::UnrollVectorOptions options;
 };
 
+/// Unrolls a vector.transpose into smaller transposes whose result shape is
+/// the native shape provided by the unroll options.
+///
+/// The native shape is expressed in the result space of the transpose (see
+/// TransposeOp::getShapeForUnroll). Result dimension `i` reads from source
+/// dimension `permutation[i]`, so for every result tile the matching source
+/// tile is found by scattering the tile offsets and sizes through the
+/// permutation. Each source tile is extracted, transposed with the original
+/// permutation, and inserted into the result at the tile offsets.
+struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
+  UnrollTransposePattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // A 0-D vector has no dimension to split.
+    if (transposeOp.getResultType().getRank() == 0)
+      return failure();
+    // Bail out when no native shape applies to this op.
+    auto targetShape = getTargetShape(options, transposeOp);
+    if (!targetShape)
+      return failure();
+    auto originalVectorType = transposeOp.getResultType();
+    Location loc = transposeOp.getLoc();
+    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
+    // shapeRatio yields an Optional; check it before creating any IR instead
+    // of dereferencing it unconditionally.
+    auto maybeRatio = shapeRatio(originalSize, *targetShape);
+    if (!maybeRatio)
+      return failure();
+    int64_t sliceCount = computeMaxLinearIndex(*maybeRatio);
+    // Every tile is read and written with unit strides.
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    SmallVector<int64_t> permutation;
+    transposeOp.getTransp(permutation);
+    // Start from a zero-filled result and insert one transposed tile per
+    // iteration.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
+    for (int64_t i = 0; i < sliceCount; i++) {
+      // Offsets of tile `i` in the result vector.
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      // Map the result-space tile back to source-space offsets and sizes.
+      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
+      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
+      for (const auto &indices : llvm::enumerate(permutation)) {
+        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
+        permutedShape[indices.value()] = (*targetShape)[indices.index()];
+      }
+      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+          strides);
+      Value transposedSlice =
+          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, transposedSlice, result, elementOffsets, strides);
+    }
+    rewriter.replaceOp(transposeOp, result);
+    return success();
+  }
+
+private:
+  const vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern>(
-      patterns.getContext(), options);
+               UnrollReductionPattern, UnrollMultiReductionPattern,
+               UnrollTransposePattern>(patterns.getContext(), options);
 }
 
 void mlir::vector::populatePropagateVectorDistributionPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 5a0014451b2c4..55132d63f2899 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -107,6 +107,11 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
 //       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
 
+
+func @vector_reduction(%v : vector<8xf32>) -> f32 {
+  %sum = vector.reduction <add>, %v : vector<8xf32> into f32
+  return %sum : f32
+}
 // CHECK-LABEL: func @vector_reduction(
 //  CHECK-SAME:     %[[v:.*]]: vector<8xf32>
 //       CHECK:   %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
@@ -121,8 +126,36 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
 //       CHECK:   %[[r3:.*]] = vector.reduction <add>, %[[s3]]
 //       CHECK:   %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
 //       CHECK:   return %[[add3]]
-func @vector_reduction(%v : vector<8xf32>) -> f32 {
-  %0 = vector.reduction <add>, %v : vector<8xf32> into f32
-  return %0 : f32
-}
 
+// The test pass unrolls vector.transpose to the native shape 1x3x4x2.
+func @vector_transpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
+  %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
+  return %t : vector<2x3x8x4xf32>
+}
+// CHECK-LABEL: func @vector_transpose
+//       CHECK:   %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+//       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+//       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+//       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 73ff660014b75..12f74e91c29e4 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -282,6 +282,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::TransposeOp>(op));
+                      }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =


        


More information about the Mlir-commits mailing list