[Mlir-commits] [mlir] 4626bd0 - [mlir][vector] Disable folding for masked reductions
Diego Caballero
llvmlistbot at llvm.org
Thu Jan 19 15:07:30 PST 2023
Author: Diego Caballero
Date: 2023-01-19T23:06:53Z
New Revision: 4626bd0b91102af234125f7f8ff0daaffb7a1fa4
URL: https://github.com/llvm/llvm-project/commit/4626bd0b91102af234125f7f8ff0daaffb7a1fa4
DIFF: https://github.com/llvm/llvm-project/commit/4626bd0b91102af234125f7f8ff0daaffb7a1fa4.diff
LOG: [mlir][vector] Disable folding for masked reductions
Masked reductions can't be folded into plain arith ops until we can mask
those arith ops.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D141645
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 777133df07e39..974854e5637f5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -361,6 +361,12 @@ struct ElideUnitDimsInMultiDimReduction
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
+ // Masked reductions can't be folded until we can propagate the mask to the
+ // resulting operation.
+ auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
for (const auto &dim : enumerate(shape)) {
if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
@@ -518,6 +524,12 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
LogicalResult matchAndRewrite(ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
+ // Masked reductions can't be folded until we can propagate the mask to the
+ // resulting operation.
+ auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
if (reductionOp.getVectorType().getDimSize(0) != 1)
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 081d94e5805fc..fb890b6a0f44a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1371,6 +1371,20 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
// -----
+// Masked reduction can't be folded.
+
+// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
+func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
+ %acc: vector<5x4x20xf32>,
+ %mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
+// CHECK: vector.mask %{{.*}} { vector.multi_reduction <mul>
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
+ vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
+ return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
// CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x20xf32>
func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x20xf32>) -> vector<5x1x20xf32> {
@@ -1921,6 +1935,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
+// CHECK: vector.mask %{{.*}} { vector.reduction <add>
+func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
+ %b: f32,
+ %mask: vector<1xi1>) -> f32 {
+ %s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
+ : vector<1xi1> -> f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduce_one_element_vector_mulf
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
More information about the Mlir-commits mailing list