[Mlir-commits] [mlir] 3941355 - [mlir][vector] Support 0-D vector when eliding single element reduction

Kai Sasaki llvmlistbot at llvm.org
Tue Feb 7 19:12:02 PST 2023


Author: Kai Sasaki
Date: 2023-02-08T12:01:56+09:00
New Revision: 3941355d8fee763e99c259ecd02f6fe567583296

URL: https://github.com/llvm/llvm-project/commit/3941355d8fee763e99c259ecd02f6fe567583296
DIFF: https://github.com/llvm/llvm-project/commit/3941355d8fee763e99c259ecd02f6fe567583296.diff

LOG: [mlir][vector] Support 0-D vector when eliding single element reduction

ElideSingleElementReduction triggers an assertion failure when given a 0-D vector. This case can be folded by using the vector.extractelement op instead. It was originally reported in https://github.com/llvm/llvm-project/issues/60193.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D143242

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 32ae7b1017e8..8073757a3042 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -530,13 +530,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     if (maskableOp.isMasked())
       return failure();
 
-    if (reductionOp.getVectorType().getDimSize(0) != 1)
+    auto vectorType = reductionOp.getVectorType();
+    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
       return failure();
 
     Location loc = reductionOp.getLoc();
-    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
-                                              reductionOp.getVector(),
-                                              rewriter.getI64ArrayAttr(0));
+    Value result;
+    if (vectorType.getRank() == 0) {
+      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+    } else {
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
+                                          reductionOp.getVector(),
+                                          rewriter.getI64ArrayAttr(0));
+    }
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8fc1834ec6aa..cac24b396136 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2157,3 +2157,13 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
   %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_0d_vector_reduction
+func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
+  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: return %[[RES]] : f32
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+  return %0 : f32
+}


        


More information about the Mlir-commits mailing list