[Mlir-commits] [mlir] 5f8cefe - [mlir][vector] Fix crash in vector.reduction canonicalization

Thomas Raoux llvmlistbot at llvm.org
Tue Jul 12 16:15:40 PDT 2022


Author: Thomas Raoux
Date: 2022-07-12T23:15:30Z
New Revision: 5f8cefebd900bbbd96961162ed9b80056e2ab95f

URL: https://github.com/llvm/llvm-project/commit/5f8cefebd900bbbd96961162ed9b80056e2ab95f
DIFF: https://github.com/llvm/llvm-project/commit/5f8cefebd900bbbd96961162ed9b80056e2ab95f.diff

LOG: [mlir][vector] Fix crash in vector.reduction canonicalization

Since vector.reduction now supports an accumulator in all cases, remove the
assert that assumed the old definition.

Differential Revision: https://reviews.llvm.org/D129602

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 24c2ff5f636d9..d51c5592ee3bc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
 /// memory.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB);
+
+/// Return the result value of reducing two scalar/vector values with the
+/// corresponding arith operation.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+                         Value v1, Value v2);
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b84f1e28cda..b5e6bc1ae5747 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -34,11 +34,6 @@ namespace vector {
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-
-/// Return the result value of reducing two scalar/vector values with the
-/// corresponding arith operation.
-Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
 } // namespace vector
 
 /// Return the number of elements of basis, `0` if empty.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f803868c2150d..c50359af87b06 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
                                               reductionOp.getVector(),
                                               rewriter.getI64ArrayAttr(0));
 
-    if (Value acc = reductionOp.getAcc()) {
-      assert(reductionOp.getType().isa<FloatType>());
-      switch (reductionOp.getKind()) {
-      case CombiningKind::ADD:
-        result = rewriter.create<arith::AddFOp>(loc, result, acc);
-        break;
-      case CombiningKind::MUL:
-        result = rewriter.create<arith::MulFOp>(loc, result, acc);
-        break;
-      default:
-        assert(false && "invalid op!");
-      }
-    }
+    if (Value acc = reductionOp.getAcc())
+      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
+                                          result, acc);
 
     rewriter.replaceOp(reductionOp, result);
     return success();
@@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+                                       CombiningKind kind, Value v1, Value v2) {
+  Type t1 = getElementTypeOrSelf(v1.getType());
+  Type t2 = getElementTypeOrSelf(v2.getType());
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for ADD reduction");
+  case CombiningKind::AND:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+  case CombiningKind::MAXF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+  case CombiningKind::MINF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+  case CombiningKind::MAXSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+  case CombiningKind::MINSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+  case CombiningKind::MAXUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+  case CombiningKind::MINUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+  case CombiningKind::MUL:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for MUL reduction");
+  case CombiningKind::OR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+  case CombiningKind::XOR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+  };
+  llvm_unreachable("unknown CombiningKind");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e6d56aa622e7..b979033ab4716 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
-Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
-                                       CombiningKind kind, Value v1, Value v2) {
-  Type t1 = getElementTypeOrSelf(v1.getType());
-  Type t2 = getElementTypeOrSelf(v2.getType());
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for ADD reduction");
-  case CombiningKind::AND:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
-  case CombiningKind::MAXF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
-  case CombiningKind::MINF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
-  case CombiningKind::MAXSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
-  case CombiningKind::MINSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
-  case CombiningKind::MAXUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
-  case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
-  case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
-  case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
-  case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
-  };
-  llvm_unreachable("unknown CombiningKind");
-}
-
 /// Return the number of elements of basis, `0` if empty.
 int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   if (basis.empty())

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 702670095c8d5..54025a626f002 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @bitcast(
 //  CHECK-SAME:               %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
 //       CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>


        


More information about the Mlir-commits mailing list