[Mlir-commits] [mlir] b1d8205 - [mlir][Vector] Add lowering support for 1-D masked multi-reductions
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 7 12:05:24 PST 2023
Author: Diego Caballero
Date: 2023-02-07T20:03:38Z
New Revision: b1d82057ed7b5374e9e8ebf5d7c53104e555ee53
URL: https://github.com/llvm/llvm-project/commit/b1d82057ed7b5374e9e8ebf5d7c53104e555ee53
DIFF: https://github.com/llvm/llvm-project/commit/b1d82057ed7b5374e9e8ebf5d7c53104e555ee53.diff
LOG: [mlir][Vector] Add lowering support for 1-D masked multi-reductions
1-D multi-reductions follow a different lowering path (they are
converted to 2-D multi-reductions), so masked variants need to be
supported explicitly.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D143453
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 117fdcb84c809..b790d141415aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -385,17 +385,25 @@ struct OneDimMultiReductionToTwoDim
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
- auto maskableOp =
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- if (maskableOp.isMasked())
- // TODO: Support masking.
- return failure();
-
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-1 or bail.
if (srcRank != 1)
return failure();
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ Value mask;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
auto loc = multiReductionOp.getLoc();
auto srcVectorType = multiReductionOp.getSourceVectorType();
auto srcShape = srcVectorType.getShape();
@@ -408,16 +416,27 @@ struct OneDimMultiReductionToTwoDim
// If the unique dim is reduced and we insert a parallel in front, we need a
// {false, true} mask.
- SmallVector<bool, 2> mask{false, true};
+ SmallVector<bool, 2> reductionMask{false, true};
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
Value castAcc = rewriter.create<vector::BroadcastOp>(
loc, accType, multiReductionOp.getAcc());
- Value reduced = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, castAcc, mask, multiReductionOp.getKind());
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+ Value castMask;
+ if (maskableOp.isMasked()) {
+ auto maskType = mask.getType().cast<ShapedType>();
+ auto castMaskType =
+ VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
+ maskType.getElementType());
+ castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
+ }
+
+ Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
+ loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
+ newOp = vector::maskOperation(rewriter, newOp, castMask);
+
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
ArrayRef<int64_t>{0});
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 5647089d2ed5b..1a9aaf38ac3b0 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -189,6 +189,26 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
// -----
+func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %c0_1 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.create_mask %dim : vector<8xi1>
+ %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
+ return %4 : f32
+}
+
+// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction.
+// This transform expands 1-D vectors into 2-D.
+
+// CHECK-LABEL: func.func @vectorize_1d_dynamic_reduction(
+// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
+// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+
+// -----
+
func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
More information about the Mlir-commits
mailing list