[Mlir-commits] [mlir] 1469ebf - [mlir][vector] Allow unroll of contraction in arbitrary order
Christopher Bate
llvmlistbot at llvm.org
Mon Jun 6 13:31:27 PDT 2022
Author: Christopher Bate
Date: 2022-06-06T14:31:04-06:00
New Revision: 1469ebf8382107e0344173f362b690d19e24029d
URL: https://github.com/llvm/llvm-project/commit/1469ebf8382107e0344173f362b690d19e24029d
DIFF: https://github.com/llvm/llvm-project/commit/1469ebf8382107e0344173f362b690d19e24029d.diff
LOG: [mlir][vector] Allow unroll of contraction in arbitrary order
Adds support for vector unroll transformations to unroll in different
orders. For example, `vector.contract` can be unrolled into a smaller
set of contractions, and there is a choice of how to traverse that
decomposition in terms of the order of (dim0, dim1, dim2).
The traversal order can now be specified by a callback provided by the
caller of the transform. For now, only the `vector.contract` and
`vector.transfer_read`/`vector.transfer_write` operations support the
callback.
Differential Revision: https://reviews.llvm.org/D127004
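
For illustration, a minimal sketch of how a caller could opt into a custom
traversal order via the new option. The 4x4x2 native shape, the {2, 0, 1}
order, and the helper function name are illustrative choices, not part of
this patch; `UnrollVectorOptions`, `setUnrollTraversalOrderFn`, and
`populateVectorUnrollPatterns` are the APIs touched here, with header paths
as of this revision of the tree:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: unroll vector.contract to a 4x4x2 native shape,
// traversing the decomposition with dim2 outermost, then dim0, then dim1.
static void addContractUnrollPatterns(RewritePatternSet &patterns) {
  vector::UnrollVectorOptions options;
  options.setNativeShape(ArrayRef<int64_t>{4, 4, 2})
      .setFilterConstraint(
          [](Operation *op) { return success(isa<vector::ContractionOp>(op)); })
      .setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            if (isa<vector::ContractionOp>(op))
              return SmallVector<int64_t>{2, 0, 1};
            // Returning None keeps the default (identity) traversal order.
            return llvm::None;
          });
  vector::populateVectorUnrollPatterns(patterns, options);
}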
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e7226f4a6ac0f..e215be49b74ef 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,6 +128,19 @@ struct UnrollVectorOptions {
};
return *this;
}
+
+ /// Function that returns the traversal order (in terms of "for loop order",
+ /// i.e. slowest varying dimension to fastest varying dimension) that should
+ /// be used when unrolling the given operation into units of the native vector
+ /// size.
+ using UnrollTraversalOrderFnType =
+ std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
+ UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
+ UnrollVectorOptions &
+ setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
+ traversalOrderCallback = std::move(traversalOrderFn);
+ return *this;
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 7f00788d888b2..b6c4114b391c1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,8 +15,11 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include <numeric>
#define DEBUG_TYPE "vector-unrolling"
@@ -36,20 +39,78 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
return elementOffsets;
}
+/// A functor that accomplishes the same thing as `getVectorOffset` but allows
+/// for reordering the traversal of the dimensions. The order of traversal is
+/// given in "for loop order" (outer to inner).
+namespace {
+class DecomposeShapeIterator {
+private:
+ SmallVector<int64_t, 4> vectorShape;
+ SmallVector<int64_t> loopOrder;
+ SmallVector<int64_t> sliceStrides;
+ int64_t maxIndexVal{1};
+
+public:
+ DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> loopOrder)
+ : vectorShape(targetShape.begin(), targetShape.end()),
+ loopOrder(loopOrder.begin(), loopOrder.end()),
+ sliceStrides(originalShape.size()) {
+ // Compute the count for each dimension.
+ SmallVector<int64_t> sliceDimCounts(originalShape.size());
+ for (unsigned r = 0; r < originalShape.size(); ++r) {
+ sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
+ maxIndexVal *= sliceDimCounts[r];
+ }
+
+ // Reversing "loop order" gives dimensions from fastest varying to slowest
+ // varying (smallest stride to largest stride).
+ int64_t accum = 1;
+ for (auto idx : llvm::reverse(loopOrder)) {
+ sliceStrides[idx] = accum;
+ accum *= sliceDimCounts[idx];
+ }
+ }
+
+ // Turn the linear index into a d-tuple based on units of vectors of size
+ // `vectorShape`. The linear index is assumed to represent traversal of the
+ // dimensions based on `loopOrder`.
+ SmallVector<int64_t> delinearize(int64_t index) const {
+ // Traverse in for loop order (largest stride to smallest stride).
+ SmallVector<int64_t, 4> vectorOffsets(sliceStrides.size());
+ for (auto idx : loopOrder) {
+ vectorOffsets[idx] = index / sliceStrides[idx];
+ index %= sliceStrides[idx];
+ }
+ return vectorOffsets;
+ }
+
+ int64_t maxIndex() const { return maxIndexVal; }
+
+ /// Return the offset within d-tuple based on the ordering given by
+ /// `loopOrder`.
+ SmallVector<int64_t> getVectorOffset(int64_t index) const {
+ SmallVector<int64_t> vectorOffsets = delinearize(index);
+ SmallVector<int64_t> elementOffsets =
+ computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+ return elementOffsets;
+ }
+};
+} // namespace
+
/// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value>
-sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
- AffineMap permutationMap, Location loc,
- OpBuilder &builder) {
+static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+ ArrayRef<Value> indices,
+ AffineMap permutationMap,
+ Location loc,
+ OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
auto isBroadcast = [](AffineExpr expr) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
return constExpr.getValue() == 0;
return false;
};
- SmallVector<int64_t, 4> elementOffsets =
- getVectorOffset(originalShape, targetShape, index);
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
SmallVector<Value> slicedIndices(indices.begin(), indices.end());
for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -99,6 +160,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
return targetShape;
}
+static SmallVector<int64_t>
+getUnrollOrder(unsigned numLoops, Operation *op,
+ const vector::UnrollVectorOptions &options) {
+ SmallVector<int64_t> loopOrder =
+ llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
+ if (options.traversalOrderCallback != nullptr) {
+ Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
+ if (order.hasValue()) {
+ loopOrder = std::move(*order);
+ }
+ }
+ return loopOrder;
+}
+
namespace {
struct UnrollTransferReadPattern
@@ -122,8 +197,7 @@ struct UnrollTransferReadPattern
Location loc = readOp.getLoc();
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
- // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio);
+
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
@@ -131,17 +205,22 @@ struct UnrollTransferReadPattern
VectorType::get(*targetShape, sourceVectorType.getElementType());
SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
readOp.getIndices().end());
- for (int64_t i = 0; i < sliceCount; i++) {
+
+ SmallVector<int64_t> loopOrder =
+ getUnrollOrder(ratio.size(), readOp, options);
+ DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+ loopOrder);
+ for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+ SmallVector<int64_t, 4> elementOffsets =
+ indexToOffsets.getVectorOffset(i);
SmallVector<Value, 4> indices =
- sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+ sliceTransferIndices(elementOffsets, originalIndices,
readOp.getPermutationMap(), loc, rewriter);
auto slicedRead = rewriter.create<vector::TransferReadOp>(
loc, targetType, readOp.getSource(), indices,
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- SmallVector<int64_t, 4> elementOffsets =
- getVectorOffset(originalSize, *targetShape, i);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
@@ -174,20 +253,21 @@ struct UnrollTransferWritePattern
SmallVector<int64_t, 4> strides(targetShape->size(), 1);
Location loc = writeOp.getLoc();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
- SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
- // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio);
SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
+
+ SmallVector<int64_t> loopOrder =
+ getUnrollOrder(originalIndices.size(), writeOp, options);
+ DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+ loopOrder);
Value resultTensor;
- for (int64_t i = 0; i < sliceCount; i++) {
+ for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
SmallVector<int64_t, 4> elementOffsets =
- getVectorOffset(originalSize, *targetShape, i);
+ indexToOffsets.getVectorOffset(i);
Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
-
SmallVector<Value, 4> indices =
- sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+ sliceTransferIndices(elementOffsets, originalIndices,
writeOp.getPermutationMap(), loc, rewriter);
Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -238,8 +318,6 @@ struct UnrollContractionPattern
SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
- // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio);
Location loc = contractOp.getLoc();
unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -247,9 +325,14 @@ struct UnrollContractionPattern
SmallVector<int64_t>, Value,
llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
accCache;
+
+ SmallVector<int64_t> loopOrder = getUnrollOrder(
+ contractOp.getIndexingMaps().size(), contractOp, options);
+ DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+ loopOrder);
+ const int64_t sliceCount = indexToOffsets.maxIndex();
for (int64_t i = 0; i < sliceCount; i++) {
- SmallVector<int64_t, 4> offsets =
- getVectorOffset(originalSize, *targetShape, i);
+ SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
// Helper to compute the new shape of each operand and extract the slice.
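
As an aside, the index arithmetic of the DecomposeShapeIterator added above
can be illustrated with a standalone sketch (a hypothetical re-implementation
for intuition, not code from this patch). With originalShape = {4, 4},
targetShape = {2, 2}, and loopOrder = {1, 0}, the four slices are visited at
element offsets [0, 0], [2, 0], [0, 2], [2, 2], i.e. column-by-column, which
is what the reverse-order transfer tests below check:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone version of the stride/delinearize logic above.
static std::vector<int64_t>
sliceElementOffsets(int64_t index, const std::vector<int64_t> &originalShape,
                    const std::vector<int64_t> &targetShape,
                    const std::vector<int64_t> &loopOrder) {
  size_t rank = originalShape.size();
  std::vector<int64_t> counts(rank), strides(rank);
  for (size_t r = 0; r < rank; ++r)
    counts[r] = (originalShape[r] + targetShape[r] - 1) / targetShape[r];
  // The last dimension in loopOrder varies fastest, so it gets stride 1.
  int64_t accum = 1;
  for (auto it = loopOrder.rbegin(); it != loopOrder.rend(); ++it) {
    strides[*it] = accum;
    accum *= counts[*it];
  }
  // Delinearize in loop order, then scale by the target (native) shape.
  std::vector<int64_t> offsets(rank);
  for (int64_t dim : loopOrder) {
    offsets[dim] = (index / strides[dim]) * targetShape[dim];
    index %= strides[dim];
  }
  return offsets;
}

int main() {
  for (int64_t i = 0; i < 4; ++i) {
    std::vector<int64_t> off = sliceElementOffsets(i, {4, 4}, {2, 2}, {1, 0});
    std::cout << "slice " << i << " -> [" << off[0] << ", " << off[1] << "]\n";
  }
  // Prints: [0, 0], [2, 0], [0, 2], [2, 2].
}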
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 3d0affd2a4be0..f6f218b6e39eb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER
// CHECK-LABEL: func @transfer_read_unroll
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -13,6 +14,19 @@
// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32>
+// ORDER-LABEL: func @transfer_read_unroll
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// ORDER-NEXT: return %[[VEC3]] : vector<4x4xf32>
+
func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -33,6 +47,19 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
// CHECK-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return
+// ORDER-LABEL: func @transfer_write_unroll
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// ORDER-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// ORDER-NEXT: return
+
func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -222,6 +249,25 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32>
+
+// ORDER-LABEL: func @transfer_read_unroll_different_rank
+// ORDER-DAG: %[[C4:.*]] = arith.constant 4 : index
+// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
+// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
+// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// ORDER-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+// ORDER-NEXT: return %[[VEC5]] : vector<6x4xf32>
+
#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 3b0aeb48665f4..272c04a34e8ea 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,50 +1,156 @@
// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" --split-input-file | FileCheck %s --check-prefix=ORDER
-func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
%init : vector<8x8xf32>) -> vector<8x8xf32> {
%0 = vector.contract
{indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (j, k)>,
affine_map<(i, j, k) -> (i, j)>],
iterator_types = ["parallel", "parallel", "reduction"]}
- %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+ %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
return %0 : vector<8x8xf32>
}
// CHECK-LABEL: func @vector_contract_f32
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [0, 4]
+// CHECK: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [4, 4]
+// CHECK: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
// CHECK: return
+// ORDER-LABEL: func @vector_contract_f32
+// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [0, 4]
+// ORDER: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [4, 4]
+// ORDER: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: return
+
+
+
func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
%init : vector<8x8xf16>) -> vector<8x8xf16> {
%0 = vector.contract
@@ -158,3 +264,4 @@ func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: return %[[V7]] : vector<2x3x8x4xf32>
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a81aa536df4ad..cbec267734eba 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
@@ -322,12 +323,18 @@ struct TestVectorUnrollingPatterns
}
return nativeShape;
};
- populateVectorUnrollPatterns(patterns,
- UnrollVectorOptions()
- .setNativeShapeFn(nativeShapeFn)
- .setFilterConstraint([](Operation *op) {
- return success(isa<ContractionOp>(op));
- }));
+
+ UnrollVectorOptions opts;
+ opts.setNativeShapeFn(nativeShapeFn)
+ .setFilterConstraint(
+ [](Operation *op) { return success(isa<ContractionOp>(op)); });
+ if (!unrollOrder.empty()) {
+ opts.setUnrollTraversalOrderFn([this](Operation *op)
+ -> Optional<SmallVector<int64_t>> {
+ return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
+ });
+ }
+ populateVectorUnrollPatterns(patterns, opts);
} else {
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
@@ -340,6 +347,10 @@ struct TestVectorUnrollingPatterns
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
+ ListOption<int64_t> unrollOrder{*this, "unroll-order",
+ llvm::cl::desc("set the unroll order"),
+ llvm::cl::ZeroOrMore};
+
Option<bool> unrollBasedOnType{
*this, "unroll-based-on-type",
llvm::cl::desc("Set the unroll factor based on type of the operation"),
@@ -472,6 +483,11 @@ struct TestVectorTransferUnrollingPatterns
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferUnrollingPatterns)
+ TestVectorTransferUnrollingPatterns() = default;
+ TestVectorTransferUnrollingPatterns(
+ const TestVectorTransferUnrollingPatterns &pass)
+ : PassWrapper(pass) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
@@ -485,17 +501,36 @@ struct TestVectorTransferUnrollingPatterns
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateVectorUnrollPatterns(
- patterns,
- UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{2, 2})
- .setFilterConstraint([](Operation *op) {
- return success(
- isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
- }));
+ UnrollVectorOptions opts;
+ opts.setNativeShape(ArrayRef<int64_t>{2, 2})
+ .setFilterConstraint([](Operation *op) {
+ return success(
+ isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+ });
+ if (reverseUnrollOrder.getValue()) {
+ opts.setUnrollTraversalOrderFn(
+ [](Operation *op) -> Optional<SmallVector<int64_t>> {
+ int64_t numLoops = 0;
+ if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
+ numLoops = readOp.getVectorType().getRank();
+ else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
+ numLoops = writeOp.getVectorType().getRank();
+ else
+ return None;
+ auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
+ return llvm::to_vector(order);
+ });
+ }
+ populateVectorUnrollPatterns(patterns, opts);
populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
+
+ Option<bool> reverseUnrollOrder{
+ *this, "reverse-unroll-order",
+ llvm::cl::desc(
+ "reverse the order of unrolling of vector transfer operations"),
+ llvm::cl::init(false)};
};
struct TestVectorTransferFullPartialSplitPatterns