[Mlir-commits] [mlir] 5c58172 - [mlir][Vector] Add support for scalable vectors in multi_reduction
Andrzej Warzynski
llvmlistbot at llvm.org
Tue Aug 8 10:02:46 PDT 2023
Author: Andrzej Warzynski
Date: 2023-08-08T17:01:59Z
New Revision: 5c581720b9bc28e933de1aed99c79557803f22ac
URL: https://github.com/llvm/llvm-project/commit/5c581720b9bc28e933de1aed99c79557803f22ac
DIFF: https://github.com/llvm/llvm-project/commit/5c581720b9bc28e933de1aed99c79557803f22ac.diff
LOG: [mlir][Vector] Add support for scalable vectors in multi_reduction
Support for scalable vectors in vector.multi_reduction is added by
updating MultiDimReductionOp::verify so that the inferred return type
preserves the scalable dims of the source vector.
Also, the conversion pattern that lowers an n-D vector.multi_reduction
to a 2-D vector.multi_reduction is updated accordingly.
Differential Revision: https://reviews.llvm.org/D157092
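Illustration (editorial, not part of the original commit message; the
value names %src and %acc are placeholders): after this change, IR like
the following verifies, with the scalable dim `[4]` of the source
carried over into the result type:

  %0 = vector.multi_reduction <add>, %src, %acc [2]
      : vector<8x[4]x2xf32> to vector<8x[4]xf32>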
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 094d555a0e8fc9..5b416e4a69996f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -313,18 +313,22 @@ MultiDimReductionOp::getShapeForUnroll() {
LogicalResult MultiDimReductionOp::verify() {
SmallVector<int64_t> targetShape;
+ SmallVector<bool> scalableDims;
Type inferredReturnType;
+ auto sourceScalableDims = getSourceVectorType().getScalableDims();
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
- }))
+ })) {
targetShape.push_back(it.value());
+ scalableDims.push_back(sourceScalableDims[it.index()]);
+ }
// TODO: update to also allow 0-d vectors when available.
if (targetShape.empty())
inferredReturnType = getSourceVectorType().getElementType();
else
- inferredReturnType =
- VectorType::get(targetShape, getSourceVectorType().getElementType());
+ inferredReturnType = VectorType::get(
+ targetShape, getSourceVectorType().getElementType(), scalableDims);
if (getType() != inferredReturnType)
return emitOpError() << "destination type " << getType()
<< " is incompatible with source type "
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5d2abf7c03680f..bed2c2496719dd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -154,12 +154,19 @@ class ReduceMultiDimReductionRank
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+ auto srcScalableDims =
+ multiReductionOp.getSourceVectorType().getScalableDims();
auto loc = multiReductionOp.getLoc();
// If rank less than 2, nothing to do.
if (srcRank < 2)
return failure();
+ // Allow only 1 scalable dimension. Otherwise we could end up with e.g.
+ // `vscale * vscale`, which is currently not modelled.
+ if (llvm::count(srcScalableDims, true) > 1)
+ return failure();
+
// If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
if (srcRank == 2 && reductionMask.front() != reductionMask.back())
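Why the early bail-out (an illustrative case, editorial; %src and %acc
are placeholders): with two scalable dims, flattening would need a dim of
`[4] * [4]` lanes, i.e. a multiple of `vscale * vscale`, which VectorType
cannot represent, so an op such as the following is left untouched for
other patterns:

  %0 = vector.multi_reduction <add>, %src, %acc [2]
      : vector<[4]x[4]x2xf32> to vector<[4]x[4]xf32>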
@@ -167,16 +174,20 @@ class ReduceMultiDimReductionRank
// 1. Separate reduction and parallel dims.
SmallVector<int64_t, 4> parallelDims, parallelShapes;
+ SmallVector<bool, 4> parallelScalableDims;
SmallVector<int64_t, 4> reductionDims, reductionShapes;
+ bool isReductionDimScalable = false;
for (const auto &it : llvm::enumerate(reductionMask)) {
int64_t i = it.index();
bool isReduction = it.value();
if (isReduction) {
reductionDims.push_back(i);
reductionShapes.push_back(srcShape[i]);
+ isReductionDimScalable |= srcScalableDims[i];
} else {
parallelDims.push_back(i);
parallelShapes.push_back(srcShape[i]);
+ parallelScalableDims.push_back(srcScalableDims[i]);
}
}
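A worked trace of this bookkeeping for the test input added below
(editorial; values inferred from the code above):

  // Source vector<8x[4]x2xf32>, reduction dims = {2}:
  //   parallelDims           = {0, 1}
  //   parallelShapes         = {8, 4}
  //   parallelScalableDims   = {false, true}
  //   reductionShapes        = {2}
  //   isReductionDimScalable = false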
@@ -212,18 +223,23 @@ class ReduceMultiDimReductionRank
// 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
// a single parallel (resp. reduction) dim.
SmallVector<bool, 2> mask;
+ SmallVector<bool, 2> scalableDims;
SmallVector<int64_t, 2> vectorShape;
+ bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
if (flattenedParallelDim) {
mask.push_back(false);
vectorShape.push_back(flattenedParallelDim);
+ scalableDims.push_back(isParallelDimScalable);
}
if (flattenedReductionDim) {
mask.push_back(true);
vectorShape.push_back(flattenedReductionDim);
+ scalableDims.push_back(isReductionDimScalable);
}
if (!useInnerDimsForReduction && vectorShape.size() == 2) {
std::swap(mask.front(), mask.back());
std::swap(vectorShape.front(), vectorShape.back());
+ std::swap(scalableDims.front(), scalableDims.back());
}
Value newVectorMask;
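Continuing the same trace (editorial): the parallel dims {8, [4]}
collapse into a single scalable dim of 8 * 4 = 32 lanes, so for the
inner-reduction layout:

  //   vectorShape  = {32, 2}     (flattened parallel, flattened reduction)
  //   mask         = {false, true}
  //   scalableDims = {true, false}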
@@ -237,7 +253,8 @@ class ReduceMultiDimReductionRank
}
auto castedType = VectorType::get(
- vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+ vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
+ scalableDims);
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
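In IR terms (editorial; %src is a placeholder), the shape cast built here
collapses the running example into the flattened 2-D form:

  %cast = vector.shape_cast %src : vector<8x[4]x2xf32> to vector<[32]x2xf32>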
@@ -245,7 +262,8 @@ class ReduceMultiDimReductionRank
if (flattenedParallelDim) {
auto accType = VectorType::get(
{flattenedParallelDim},
- multiReductionOp.getSourceVectorType().getElementType());
+ multiReductionOp.getSourceVectorType().getElementType(),
+ /*scalableDims=*/{isParallelDimScalable});
acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
}
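The accumulator is flattened the same way (editorial; %acc is a
placeholder):

  %acc_cast = vector.shape_cast %acc : vector<8x[4]xf32> to vector<[32]xf32>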
// 6. Creates the flattened form of vector.multi_reduction with inner/outer
@@ -264,8 +282,8 @@ class ReduceMultiDimReductionRank
// 8. Creates shape cast for the output n-D -> 2-D.
VectorType outputCastedType = VectorType::get(
- parallelShapes,
- multiReductionOp.getSourceVectorType().getElementType());
+ parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
+ parallelScalableDims);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
return success();
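Finally, the reduced result is expanded back to the original parallel
shape, with scalability restored from parallelScalableDims (editorial;
%red is a placeholder):

  %res = vector.shape_cast %red : vector<[32]xf32> to vector<8x[4]xf32>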
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index bbb461c56e5401..cf770670c57528 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -249,6 +249,38 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
+func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+ %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
+ return %0 : vector<8x[4]xf32>
+}
+// CHECK-LABEL: func.func private @scalable_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_34:.*]] = arith.constant 31 : index
+
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<8x[4]x2xf32>
+// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<8x[4]xf32>
+// CHECK: %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
+// CHECK: %[[VAL_38:.*]] = vector.insertelement %[[VAL_37]], %[[VAL_2]]{{\[}}%[[VAL_3]] : index] : vector<[32]xf32>
+
+// CHECK: %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<8x[4]x2xf32>
+// CHECK: %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<8x[4]xf32>
+// CHECK: %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
+// CHECK: %[[VAL_42:.*]] = vector.insertelement %[[VAL_41]], %[[VAL_38]]{{\[}}%[[VAL_4]] : index] : vector<[32]xf32>
+
+// (...)
+
+// CHECK: %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<8x[4]x2xf32>
+// CHECK: %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : vector<8x[4]xf32>
+// CHECK: %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
+// CHECK: %[[VAL_162:.*]] = vector.insertelement %[[VAL_161]], %{{.*}}{{\[}}%[[VAL_34]] : index] : vector<[32]xf32>
+
+// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
+// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
+
transform.sequence failures(propagate) {
^bb1(%func_op: !transform.op<"func.func">):
transform.apply_patterns to %func_op {