[Mlir-commits] [mlir] 51f235c - [mlir][Vector] Add folding for masked reductions and vector.mask
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 21 22:51:44 PST 2023
Author: Diego Caballero
Date: 2023-02-22T06:37:38Z
New Revision: 51f235c4445794e9cae25d0d29b75f030a029ceb
URL: https://github.com/llvm/llvm-project/commit/51f235c4445794e9cae25d0d29b75f030a029ceb
DIFF: https://github.com/llvm/llvm-project/commit/51f235c4445794e9cae25d0d29b75f030a029ceb.diff
LOG: [mlir][Vector] Add folding for masked reductions and vector.mask
This patch adds support for folding trivial masked reductions and
multi-reductions (e.g., multi-reductions with only parallel dims,
reductions of a single element, etc.). To support those foldings in
a composable way, we also add support for folding different flavors of
empty vector.mask operations.
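For example (a sketch based on the new canonicalization tests below; SSA
value names are illustrative), an empty vector.mask that only yields its
operand now folds away:

  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>

canonicalizes to a direct use of %a, and an empty vector.mask with no
results is erased entirely.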
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D144414
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 04fb36a520265..6a5913e512388 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1172,7 +1172,7 @@ def Vector_ExtractStridedSliceOp :
static StringRef getSizesAttrStrName() { return "sizes"; }
static StringRef getStridesAttrStrName() { return "strides"; }
VectorType getSourceVectorType() {
- return getVector().getType().cast<VectorType>();
+ return getVector().getType().cast<VectorType>();
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
@@ -2382,9 +2382,11 @@ def Vector_MaskOp : Vector_Op<"mask", [
];
let extraClassDeclaration = [{
+ Block *getMaskBlock() { return &getMaskRegion().front(); }
static void ensureTerminator(Region &region, Builder &builder, Location loc);
}];
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb58f904462a6..2a81cf673dc40 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -361,34 +361,50 @@ struct ElideUnitDimsInMultiDimReduction
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
- // Masked reductions can't be folded until we can propagate the mask to the
- // resulting operation.
- auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
for (const auto &dim : enumerate(shape)) {
if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
return failure();
}
+
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Operation *rootOp;
+ Value mask;
+ if (reductionOp.isMasked()) {
+ rewriter.setInsertionPoint(reductionOp.getMaskingOp());
+ rootOp = reductionOp.getMaskingOp();
+ mask = reductionOp.getMaskingOp().getMask();
+ } else {
+ rootOp = reductionOp;
+ }
+
Location loc = reductionOp.getLoc();
Value acc = reductionOp.getAcc();
Value cast;
- if (reductionOp.getDestType().isa<VectorType>()) {
+ if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
+ if (mask) {
+ VectorType newMaskType =
+ VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
+ mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+ }
cast = rewriter.create<vector::ShapeCastOp>(
loc, reductionOp.getDestType(), reductionOp.getSource());
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
+ auto zeroAttr =
+ rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
+ if (mask)
+ mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
+ mask, zeroAttr);
cast = rewriter.create<vector::ExtractOp>(
- loc, reductionOp.getDestType(), reductionOp.getSource(),
- rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
+ loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
}
- Value result = vector::makeArithReduction(rewriter, loc,
- reductionOp.getKind(), acc, cast);
- rewriter.replaceOp(reductionOp, result);
+ Value result = vector::makeArithReduction(
+ rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
@@ -524,11 +540,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
LogicalResult matchAndRewrite(ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
- // Masked reductions can't be folded until we can propagate the mask to the
- // resulting operation.
- auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(reductionOp.getOperation());
+ Operation *rootOp;
+ Value mask;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = reductionOp;
+ }
auto vectorType = reductionOp.getSourceVectorType();
if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
@@ -537,8 +561,14 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Location loc = reductionOp.getLoc();
Value result;
if (vectorType.getRank() == 0) {
+ if (mask)
+ mask = rewriter.create<ExtractElementOp>(loc, mask);
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
} else {
+ if (mask) {
+ mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
+ rewriter.getI64ArrayAttr(0));
+ }
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
@@ -546,9 +576,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, acc);
+ result, acc, mask);
- rewriter.replaceOp(reductionOp, result);
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
@@ -5465,7 +5495,7 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
// Print single masked operation and skip terminator.
p << " { ";
Block *singleBlock = &getMaskRegion().getBlocks().front();
- if (singleBlock && singleBlock->getOperations().size() > 1)
+ if (singleBlock && singleBlock->getOperations().size() >= 1)
p.printCustomOrGenericOp(&singleBlock->front());
p << " }";
@@ -5481,33 +5511,49 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
MaskOp>::ensureTerminator(region, builder, loc);
// Keep the default yield terminator if the number of masked operations is not
// the expected. This case will trigger a verification failure.
- if (region.front().getOperations().size() != 2)
+ Block &block = region.front();
+ if (block.getOperations().size() != 2)
return;
// Replace default yield terminator with a new one that returns the results
// from the masked operation.
OpBuilder opBuilder(builder.getContext());
- Operation *maskedOp = &region.front().front();
- Operation *oldYieldOp = &region.front().back();
+ Operation *maskedOp = &block.front();
+ Operation *oldYieldOp = &block.back();
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+ // Empty vector.mask op.
+ if (maskedOp == oldYieldOp)
+ return;
+
opBuilder.setInsertionPoint(oldYieldOp);
- opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
+ opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
oldYieldOp->dropAllReferences();
oldYieldOp->erase();
+ return;
}
LogicalResult MaskOp::verify() {
// Structural checks.
Block &block = getMaskRegion().getBlocks().front();
- if (block.getOperations().size() < 2)
- return emitOpError("expects an operation to mask");
+ if (block.getOperations().size() < 1)
+ return emitOpError("expects a terminator within the mask region");
if (block.getOperations().size() > 2)
return emitOpError("expects only one operation to mask");
+ // Terminator checks.
+ auto terminator = dyn_cast<vector::YieldOp>(block.back());
+ if (!terminator)
+ return emitOpError("expects a terminator within the mask region");
+
+ if (terminator->getNumOperands() != getNumResults())
+ return emitOpError(
+ "expects number of results to match mask region yielded values");
+
auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+ // Empty vector.mask. Nothing else to check.
if (!maskableOp)
- return emitOpError("expects a maskable operation");
+ return success();
// Result checks.
if (maskableOp->getNumResults() != getNumResults())
@@ -5545,10 +5591,47 @@ LogicalResult MaskOp::verify() {
return success();
}
+// Elides empty vector.mask operations with or without return values. Propagates
+// the yielded values by the vector.yield terminator, if any, or erases the op,
+// otherwise.
+class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MaskOp maskOp,
+ PatternRewriter &rewriter) const override {
+ auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
+ if (maskingOp.getMaskableOp())
+ return failure();
+
+ Block *block = maskOp.getMaskBlock();
+ if (block->getOperations().size() > 1)
+ return failure();
+
+ auto terminator = cast<vector::YieldOp>(block->front());
+ if (terminator.getNumOperands() == 0)
+ rewriter.eraseOp(maskOp);
+ else
+ rewriter.replaceOp(maskOp, terminator.getOperands());
+
+ return success();
+ }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ElideEmptyMaskOp>(context);
+}
+
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
-Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
+Operation *MaskOp::getMaskableOp() {
+ Block *block = getMaskBlock();
+ if (block->getOperations().size() < 2)
+ return nullptr;
+
+ return &block->front();
+}
/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index cac24b3961363..053e3620cab2e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1372,6 +1372,16 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
// -----
+// CHECK-LABEL: func @masked_vector_multi_reduction_single_parallel(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %{{.*}}: vector<2xf32>,
+func.func @masked_vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>, %mask: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32> } : vector<2xi1> -> vector<2xf32>
+// CHECK: return %[[VAL_0]] : vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
// CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
@@ -1385,14 +1395,17 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
// -----
-// Masked reduction can't be folded.
-
// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
+// CHECK-SAME: %[[VAL_0:.*]]: vector<5x1x4x1x20xf32>, %[[VAL_1:.*]]: vector<5x4x20xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<5x1x4x1x20xi1>)
func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
%acc: vector<5x4x20xf32>,
%mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
-// CHECK: vector.mask %{{.*}} { vector.multi_reduction <mul>
- %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<5x1x4x1x20xi1> to vector<5x4x20xi1>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : vector<5x4x20xf32>
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<5x4x20xi1>, vector<5x4x20xf32>
+%0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
return %0 : vector<5x4x20xf32>
}
@@ -1424,6 +1437,20 @@ func.func @vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x
// -----
+// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions_single_elem(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x1xf32>, %[[VAL_1:.*]]: f32,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x1x1xi1>)
+func.func @masked_vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x1x1xf32>, %acc: f32, %mask: vector<1x1x1xi1>) -> f32 {
+ // CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0, 0, 0] : vector<1x1x1xi1>
+ // CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0, 0, 0] : vector<1x1x1xf32>
+ // CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : f32
+ // CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : f32
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [0,1,2] : vector<1x1x1xf32> to f32 } : vector<1x1x1xi1> -> f32
+ return %0 : f32
+}
+
+// -----
+
// CHECK-LABEL: func @insert_strided_slice_full_range
// CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
@@ -1937,6 +1964,17 @@ func.func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
// -----
+// CHECK-LABEL: func @masked_reduce_one_element_vector_extract
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: vector<1xi1>)
+func.func @masked_reduce_one_element_vector_extract(%a : vector<1xf32>, %mask : vector<1xi1>) -> f32 {
+// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+ %s = vector.mask %mask { vector.reduction <add>, %a : vector<1xf32> into f32 }
+ : vector<1xi1> -> f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduce_one_element_vector_addf
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
@@ -1950,10 +1988,15 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
-// CHECK: vector.mask %{{.*}} { vector.reduction <add>
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
%b: f32,
%mask: vector<1xi1>) -> f32 {
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0] : vector<1xi1>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_1]] : f32
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_1]] : f32
%s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
: vector<1xi1> -> f32
return %s : f32
@@ -2167,3 +2210,25 @@ func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask
+func.func @empty_vector_mask(%mask : vector<8xi1>) {
+// CHECK-NOT: vector.mask
+ vector.mask %mask { } : vector<8xi1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: return %[[IN]] : vector<8xf32>
+ %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index adb524e3b7e0d..75ccde168c994 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1604,13 +1604,6 @@ func.func @warp_mismatch_rank(%laneid: index) {
// -----
-func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 {
- // expected-error at +1 {{'vector.mask' op expects an operation to mask}}
- vector.mask %m0 { } : vector<16xi1>
-}
-
-// -----
-
func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) {
%ft0 = arith.constant 0.0 : f32
// expected-error at +1 {{'vector.mask' op expects only one operation to mask}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 53a836466b5ad..60e1507293f7e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -860,6 +860,27 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
return
}
+// CHECK-LABEL: func @vector_mask_empty
+func.func @vector_mask_empty(%m0: vector<16xi1>) {
+// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
+ vector.mask %m0 { } : vector<16xi1>
+ return
+}
+
+// CHECK-LABEL: func @vector_mask_empty_with_yield
+func.func @vector_mask_empty_with_yield(%m0: vector<16xi1>) {
+// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
+ vector.mask %m0 { vector.yield } : vector<16xi1>
+ return
+}
+
+// CHECK-LABEL: func @vector_mask_empty_return
+func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -> vector<16xf32> {
+// CHECK: vector.mask %{{.*}} { vector.yield {{.*}} : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+ %0 = vector.mask %m0 { vector.yield %arg0 : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
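As a worked example of the masked reduction folding (reconstructed from the
masked_reduce_one_element_vector_addf CHECK lines above; SSA value names are
illustrative), the masked single-element reduction

  %s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 } : vector<1xi1> -> f32

now canonicalizes to an extract/addf/select sequence that applies the mask to
the scalar combine:

  %m = vector.extract %mask[0] : vector<1xi1>
  %e = vector.extract %a[0] : vector<1xf32>
  %add = arith.addf %e, %b : f32
  %s = arith.select %m, %add, %b : f32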