[Mlir-commits] [mlir] 3941355 - [mlir][vector] Support 0-D vector when eliding single element reduction
Kai Sasaki
llvmlistbot at llvm.org
Tue Feb 7 19:12:02 PST 2023
Author: Kai Sasaki
Date: 2023-02-08T12:01:56+09:00
New Revision: 3941355d8fee763e99c259ecd02f6fe567583296
URL: https://github.com/llvm/llvm-project/commit/3941355d8fee763e99c259ecd02f6fe567583296
DIFF: https://github.com/llvm/llvm-project/commit/3941355d8fee763e99c259ecd02f6fe567583296.diff
LOG: [mlir][vector] Support 0-D vector when eliding single element reduction
ElideSingleElementReduction triggers an assertion failure when given a 0-D vector. This case can be folded by using the vector.extractelement op instead. The issue was originally reported in https://github.com/llvm/llvm-project/issues/60193.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D143242
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 32ae7b1017e8..8073757a3042 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -530,13 +530,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (maskableOp.isMasked())
return failure();
- if (reductionOp.getVectorType().getDimSize(0) != 1)
+ auto vectorType = reductionOp.getVectorType();
+ if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
return failure();
Location loc = reductionOp.getLoc();
- Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
- reductionOp.getVector(),
- rewriter.getI64ArrayAttr(0));
+ Value result;
+ if (vectorType.getRank() == 0) {
+ result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+ } else {
+ result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
+ reductionOp.getVector(),
+ rewriter.getI64ArrayAttr(0));
+ }
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8fc1834ec6aa..cac24b396136 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2157,3 +2157,13 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
%1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
return %1 : f32
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_0d_vector_reduction
+func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
+ // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+ // CHECK-NEXT: return %[[RES]] : f32
+ %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+ return %0 : f32
+}
More information about the Mlir-commits
mailing list