[Mlir-commits] [mlir] 5b1b710 - [mlir][vector] Add unrolling pattern for TransposeOp
Thomas Raoux
llvmlistbot at llvm.org
Wed Apr 13 12:44:41 PDT 2022
Author: Thomas Raoux
Date: 2022-04-13T19:44:16Z
New Revision: 5b1b7108c8975159c1112ceea1cd7e213e1be97a
URL: https://github.com/llvm/llvm-project/commit/5b1b7108c8975159c1112ceea1cd7e213e1be97a
DIFF: https://github.com/llvm/llvm-project/commit/5b1b7108c8975159c1112ceea1cd7e213e1be97a.diff
LOG: [mlir][vector] Add unrolling pattern for TransposeOp
Support unrolling for vector.transpose following the same interface as
other vector unrolling ops.
Differential Revision: https://reviews.llvm.org/D123688
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 289b50256d454..3e9ad30cd8bc3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2217,6 +2217,7 @@ def Vector_CreateMaskOp :
def Vector_TransposeOp :
Vector_Op<"transpose", [NoSideEffect,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 758478f8d7ff8..f326217af299c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4320,6 +4320,10 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}
+/// Unrolling of `vector.transpose` is expressed in terms of the shape of the
+/// transposed result vector, so that shape is reported as the unroll shape.
+Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
+  ArrayRef<int64_t> resultShape = getResultType().getShape();
+  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 2b730182d2088..7f00788d888b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -681,14 +681,62 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
const vector::UnrollVectorOptions options;
};
+/// Unrolls a `vector.transpose` into a sequence of smaller transposes.
+///
+/// The transposed result is tiled with the native shape returned by
+/// `getTargetShape`. For each result tile, the matching source tile is located
+/// by mapping the tile's offsets and shape back through the permutation. That
+/// source tile is extracted with `vector.extract_strided_slice`, transposed
+/// with the original permutation, and written into the result with
+/// `vector.insert_strided_slice`. The result starts as an all-zero constant
+/// that the tiles are expected to overwrite completely.
+struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
+  UnrollTransposePattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // A 0-D transpose has nothing to unroll.
+    if (transposeOp.getResultType().getRank() == 0)
+      return failure();
+    auto targetShape = getTargetShape(options, transposeOp);
+    if (!targetShape)
+      return failure();
+    auto originalVectorType = transposeOp.getResultType();
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    Location loc = transposeOp.getLoc();
+    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    // Prepare the result vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
+    SmallVector<int64_t> permutation;
+    transposeOp.getTransp(permutation);
+    for (int64_t i = 0; i < sliceCount; i++) {
+      // Offsets of this tile inside the *result* vector.
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
+      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
+      // Result dimension d is source dimension permutation[d], so scatter the
+      // tile's result-space offsets/shape into source-space positions.
+      for (const auto &indices : llvm::enumerate(permutation)) {
+        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
+        permutedShape[indices.value()] = (*targetShape)[indices.index()];
+      }
+      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+          strides);
+      Value transposedSlice =
+          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, transposedSlice, result, elementOffsets, strides);
+    }
+    rewriter.replaceOp(transposeOp, result);
+    return success();
+  }
+
+private:
+  const vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
 RewritePatternSet &patterns, const UnrollVectorOptions &options) {
 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
 UnrollContractionPattern, UnrollElementwisePattern,
- UnrollReductionPattern, UnrollMultiReductionPattern>(
- patterns.getContext(), options);
+               UnrollReductionPattern, UnrollMultiReductionPattern,
+               UnrollTransposePattern>(patterns.getContext(), options);
}
void mlir::vector::populatePropagateVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 5a0014451b2c4..55132d63f2899 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -107,6 +107,11 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>
+
+// A 1-D add reduction unrolled into 2-element slices whose partial
+// reductions are chained together with arith.addf.
+func @vector_reduction(%v : vector<8xf32>) -> f32 {
+ %0 = vector.reduction <add>, %v : vector<8xf32> into f32
+ return %0 : f32
+}
// CHECK-LABEL: func @vector_reduction(
// CHECK-SAME: %[[v:.*]]: vector<8xf32>
// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
@@ -121,8 +126,35 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
// CHECK: return %[[add3]]
-func @vector_reduction(%v : vector<8xf32>) -> f32 {
- %0 = vector.reduction <add>, %v : vector<8xf32> into f32
- return %0 : f32
-}
+// A 4-D transpose unrolled with native shape [1, 3, 4, 2] (expressed in the
+// result's index space). This yields 8 extract/transpose/insert sequences
+// whose source offsets are the result offsets mapped through the permutation.
+func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
+ %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
+ return %t : vector<2x3x8x4xf32>
+}
+// CHECK-LABEL: func @vector_tranpose
+// CHECK: %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
+// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
+// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
+// CHECK: return %[[V7]] : vector<2x3x8x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 73ff660014b75..12f74e91c29e4 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -282,6 +282,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::TransposeOp>(op));
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
More information about the Mlir-commits
mailing list