[Mlir-commits] [mlir] 627fa0b - [mlir] MathApproximations: unroll virtual vectors into hardware vectors for ISA-specific operations
Eugene Zhulenev
llvmlistbot at llvm.org
Thu Oct 28 12:52:10 PDT 2021
Author: Eugene Zhulenev
Date: 2021-10-28T12:52:04-07:00
New Revision: 627fa0b9a897b8d73b481e475abc4518efe55db7
URL: https://github.com/llvm/llvm-project/commit/627fa0b9a897b8d73b481e475abc4518efe55db7
DIFF: https://github.com/llvm/llvm-project/commit/627fa0b9a897b8d73b481e475abc4518efe55db7.diff
LOG: [mlir] MathApproximations: unroll virtual vectors into hardware vectors for ISA-specific operations
Reviewed By: cota
Differential Revision: https://reviews.llvm.org/D112736
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index df66c1582645..8c9ff7d5407d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -92,6 +93,101 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
: value;
}
+//----------------------------------------------------------------------------//
+// Helper function to handle n-D vectors with 1-D operations.
+//----------------------------------------------------------------------------//
+
+// Expands and unrolls n-D vector operands into multiple fixed-size 1-D vectors
+// and calls the compute function with 1-D vector operands. Stitches back all
+// results into the original n-D vector result.
+//
+// Examples: vectorWidth = 8
+// - vector<4x8xf32> unrolled 4 times
+// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
+// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
+//
+// Some math approximations rely on ISA-specific operations that only accept
+// fixed-size 1-D vectors (e.g. AVX expects vectors of width 8).
+//
+// It is the caller's responsibility to verify that the inner dimension is
+// divisible by the vectorWidth, and that all operands have the same vector
+// shape.
+static Value
+handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
+ ValueRange operands, int64_t vectorWidth,
+ std::function<Value(ValueRange)> compute) {
+ assert(!operands.empty() && "operands must not be empty");
+ assert(vectorWidth > 0 && "vector width must be larger than 0");
+
+ VectorType inputType = operands[0].getType().cast<VectorType>();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+
+ // If input shape matches target vector width, we can just call the
+ // user-provided compute function with the operands.
+ if (inputShape == llvm::makeArrayRef(vectorWidth))
+ return compute(operands);
+
+ // Check if the inner dimension has to be expanded, or we can directly iterate
+ // over the outer dimensions of the vector.
+ int64_t innerDim = inputShape.back();
+ int64_t expansionDim = innerDim / vectorWidth;
+ assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
+
+ // Maybe expand operands to the higher rank vector shape that we'll use to
+ // iterate over and extract one dimensional vectors.
+ SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
+ SmallVector<Value> expandedOperands(operands);
+
+ if (expansionDim > 1) {
+ // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
+ expandedShape.insert(expandedShape.end() - 1, expansionDim);
+ expandedShape.back() = vectorWidth;
+
+ for (unsigned i = 0; i < operands.size(); ++i) {
+ auto operand = operands[i];
+ auto eltType = operand.getType().cast<VectorType>().getElementType();
+ auto expandedType = VectorType::get(expandedShape, eltType);
+ expandedOperands[i] =
+ builder.create<vector::ShapeCastOp>(expandedType, operand);
+ }
+ }
+
+ // Iterate over all outer dimensions of the compute shape vector type.
+ auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
+ int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);
+
+ SmallVector<int64_t> ones(iterationDims.size(), 1);
+ auto strides = computeStrides(iterationDims, ones);
+
+ // Compute results for each one dimensional vector.
+ SmallVector<Value> results(maxLinearIndex);
+
+ for (int64_t i = 0; i < maxLinearIndex; ++i) {
+ auto offsets = delinearize(strides, i);
+
+ SmallVector<Value> extracted(expandedOperands.size());
+ for (auto tuple : llvm::enumerate(expandedOperands))
+ extracted[tuple.index()] =
+ builder.create<vector::ExtractOp>(tuple.value(), offsets);
+
+ results[i] = compute(extracted);
+ }
+
+ // Stitch results together into one large vector.
+ Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+ Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
+ Value result = builder.create<ConstantOp>(
+ resultExpandedType, builder.getZeroAttr(resultExpandedType));
+
+ for (int64_t i = 0; i < maxLinearIndex; ++i)
+ result = builder.create<vector::InsertOp>(results[i], result,
+ delinearize(strides, i));
+
+ // Reshape back to the original vector shape.
+ return builder.create<vector::ShapeCastOp>(
+ VectorType::get(inputShape, resultEltType), result);
+}
+
//----------------------------------------------------------------------------//
// Helper functions to create constants.
//----------------------------------------------------------------------------//
@@ -943,7 +1039,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const {
auto shape = vectorShape(op.operand().getType(), isF32);
// Only support already-vectorized rsqrt's.
- if (!shape.hasValue() || (*shape)[0] != 8)
+ if (!shape.hasValue() || shape->back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -967,7 +1063,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
// Compute an approximate result.
- Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
+ Value yApprox = handleMultidimensionalVectors(
+ builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
+ return builder.create<x86vector::RsqrtOp>(operands);
+ });
// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
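
The helper above expands the trailing dimension and then walks the outer dimensions by mapping a linear index to per-dimension extract/insert offsets. Below is a minimal standalone C++ sketch of that bookkeeping (not part of the patch; the real code uses MLIR's computeMaxLinearIndex / computeStrides / delinearize utilities), using the vector<2x16xf32> case from the tests further down:

// Standalone sketch (not part of the patch) of the shape expansion and the
// linear-index iteration performed by handleMultidimensionalVectors.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Mirrors the vector<2x16xf32> test with an AVX2 vector width of 8.
  std::vector<int64_t> inputShape = {2, 16};
  int64_t vectorWidth = 8;

  // Expand [..., innerDim] to [..., innerDim / vectorWidth, vectorWidth].
  std::vector<int64_t> expandedShape(inputShape.begin(), inputShape.end() - 1);
  expandedShape.push_back(inputShape.back() / vectorWidth);
  expandedShape.push_back(vectorWidth);
  // expandedShape is now {2, 2, 8}.

  // Iterate over the outer dimensions {2, 2} with row-major strides.
  std::vector<int64_t> outerDims(expandedShape.begin(), expandedShape.end() - 1);
  std::vector<int64_t> strides(outerDims.size(), 1);
  for (int64_t i = static_cast<int64_t>(outerDims.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * outerDims[i + 1];

  int64_t maxLinearIndex = 1;
  for (int64_t d : outerDims)
    maxLinearIndex *= d;

  // Delinearize each linear index into the extract/insert offsets:
  // prints [0, 0], [0, 1], [1, 0], [1, 1], matching the test checks below.
  for (int64_t idx = 0; idx < maxLinearIndex; ++idx) {
    int64_t rem = idx;
    std::cout << "[";
    for (size_t d = 0; d < strides.size(); ++d) {
      std::cout << rem / strides[d] << (d + 1 < strides.size() ? ", " : "");
      rem %= strides[d];
    }
    std::cout << "]\n";
  }
}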
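
For the Newton-Raphson step referenced in the comment just above, a short standalone numeric sketch (the initial estimate is a hypothetical stand-in for the low-precision hardware rsqrt result, not an actual AVX value):

// Standalone sketch (not part of the patch) of a single Newton-Raphson
// refinement step for rsqrt: y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
#include <cmath>
#include <cstdio>

int main() {
  float x = 2.0f;
  // Hypothetical low-precision estimate of 1/sqrt(2); stands in for the
  // result of the hardware rsqrt instruction.
  float y = 0.70654296875f;

  // One refinement step roughly doubles the number of accurate bits.
  float refined = y * (1.5f - y * (0.5f * x) * y);

  std::printf("estimate: %.9f\nrefined:  %.9f\nexact:    %.9f\n", y, refined,
              1.0f / std::sqrt(x));
}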
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f3b2cdfc16ef..f388b84d83c8 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -402,9 +402,9 @@ func @rsqrt_scalar(%arg0: f32) -> f32 {
return %0 : f32
}
-// CHECK-LABEL: func @rsqrt_vector
+// CHECK-LABEL: func @rsqrt_vector_8xf32
// CHECK: math.rsqrt
-// AVX2-LABEL: func @rsqrt_vector(
+// AVX2-LABEL: func @rsqrt_vector_8xf32(
// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
@@ -421,7 +421,89 @@ func @rsqrt_scalar(%arg0: f32) -> f32 {
// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
// AVX2: return %[[VAL_13]] : vector<8xf32>
// AVX2: }
-func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+func @rsqrt_vector_8xf32(%arg0: vector<8xf32>) -> vector<8xf32> {
%0 = math.rsqrt %arg0 : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// Virtual vector width is not a multiple of the AVX2 vector width.
+//
+// CHECK-LABEL: func @rsqrt_vector_5xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_5xf32
+// AVX2: math.rsqrt
+func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> {
+ %0 = math.rsqrt %arg0 : vector<5xf32>
+ return %0 : vector<5xf32>
+}
+
+// One dimensional virtual vector expanded and unrolled into multiple AVX2-sized
+// vectors.
+//
+// CHECK-LABEL: func @rsqrt_vector_16xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_16xf32(
+// AVX2-SAME: %[[ARG:.*]]: vector<16xf32>
+// AVX2-SAME: ) -> vector<16xf32>
+// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
+// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
+// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
+// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
+// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
+// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
+// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
+// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
+// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
+func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = math.rsqrt %arg0 : vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// Two dimensional virtual vector unrolled into multiple AVX2-sized vectors.
+//
+// CHECK-LABEL: func @rsqrt_vector_2x8xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_2x8xf32(
+// AVX2-SAME: %[[ARG:.*]]: vector<2x8xf32>
+// AVX2-SAME: ) -> vector<2x8xf32>
+// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
+// AVX2-NOT: vector.shape_cast
+// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0]
+// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
+// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1]
+// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
+// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
+// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
+// AVX2-NOT: vector.shape_cast
+func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> {
+ %0 = math.rsqrt %arg0 : vector<2x8xf32>
+ return %0 : vector<2x8xf32>
+}
+
+// Two dimensional virtual vector expanded and unrolled into multiple AVX2-sized
+// vectors.
+//
+// CHECK-LABEL: func @rsqrt_vector_2x16xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_2x16xf32(
+// AVX2-SAME: %[[ARG:.*]]: vector<2x16xf32>
+// AVX2-SAME: ) -> vector<2x16xf32>
+// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
+// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
+// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
+// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]]
+// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
+// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]]
+// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
+// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]]
+// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
+// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]]
+// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
+// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
+// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]
+// AVX2: %[[RESULT3:.*]] = vector.insert %[[RSQRT11]], %[[RESULT2]] [1, 1]
+// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT3]] : vector<2x2x8xf32> to vector<2x16xf32>
+func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
+ %0 = math.rsqrt %arg0 : vector<2x16xf32>
+ return %0 : vector<2x16xf32>
+}