[Mlir-commits] [mlir] 5f8cefe - [mlir][vector] Fix crash in vector.reduction canonicalization
Thomas Raoux
llvmlistbot at llvm.org
Tue Jul 12 16:15:40 PDT 2022
Author: Thomas Raoux
Date: 2022-07-12T23:15:30Z
New Revision: 5f8cefebd900bbbd96961162ed9b80056e2ab95f
URL: https://github.com/llvm/llvm-project/commit/5f8cefebd900bbbd96961162ed9b80056e2ab95f
DIFF: https://github.com/llvm/llvm-project/commit/5f8cefebd900bbbd96961162ed9b80056e2ab95f.diff
LOG: [mlir][vector] Fix crash in vector.reduction canonicalization
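Since vector.reduction now supports an accumulator in all cases, remove the
assert that assumed the old definition (float-only, ADD/MUL-only) and build
the combining op with vector::makeArithReduction instead, which this change
moves from VectorUtils into VectorOps.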
Differential Revision: https://reviews.llvm.org/D129602
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 24c2ff5f636d9..d51c5592ee3bc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
/// memory.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);
+
+/// Return the result value of reducing two scalar/vector values with the
+/// corresponding arith operation.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+ Value v1, Value v2);
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b84f1e28cda..b5e6bc1ae5747 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -34,11 +34,6 @@ namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-
-/// Return the result value of reducing two scalar/vector values with the
-/// corresponding arith operation.
-Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value v2);
} // namespace vector
/// Return the number of elements of basis, `0` if empty.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f803868c2150d..c50359af87b06 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
- if (Value acc = reductionOp.getAcc()) {
- assert(reductionOp.getType().isa<FloatType>());
- switch (reductionOp.getKind()) {
- case CombiningKind::ADD:
- result = rewriter.create<arith::AddFOp>(loc, result, acc);
- break;
- case CombiningKind::MUL:
- result = rewriter.create<arith::MulFOp>(loc, result, acc);
- break;
- default:
- assert(false && "invalid op!");
- }
- }
+ if (Value acc = reductionOp.getAcc())
+ result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
+ result, acc);
rewriter.replaceOp(reductionOp, result);
return success();
@@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+ CombiningKind kind, Value v1, Value v2) {
+ Type t1 = getElementTypeOrSelf(v1.getType());
+ Type t2 = getElementTypeOrSelf(v2.getType());
+ switch (kind) {
+ case CombiningKind::ADD:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for ADD reduction");
+ case CombiningKind::AND:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+ case CombiningKind::MAXF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+ case CombiningKind::MINF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+ case CombiningKind::MAXSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+ case CombiningKind::MINSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+ case CombiningKind::MAXUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+ case CombiningKind::MINUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+ case CombiningKind::MUL:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for MUL reduction");
+ case CombiningKind::OR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+ case CombiningKind::XOR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+ };
+ llvm_unreachable("unknown CombiningKind");
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e6d56aa622e7..b979033ab4716 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
-Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
- CombiningKind kind, Value v1, Value v2) {
- Type t1 = getElementTypeOrSelf(v1.getType());
- Type t2 = getElementTypeOrSelf(v2.getType());
- switch (kind) {
- case CombiningKind::ADD:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::AddIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::AddFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for ADD reduction");
- case CombiningKind::AND:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::AndIOp>(loc, v1, v2);
- case CombiningKind::MAXF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
- "expected float values");
- return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
- case CombiningKind::MINF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
- "expected float values");
- return b.createOrFold<arith::MinFOp>(loc, v1, v2);
- case CombiningKind::MAXSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
- case CombiningKind::MINSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
- case CombiningKind::MAXUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
- case CombiningKind::MINUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
- case CombiningKind::MUL:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::MulIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::MulFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for MUL reduction");
- case CombiningKind::OR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::OrIOp>(loc, v1, v2);
- case CombiningKind::XOR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
- };
- llvm_unreachable("unknown CombiningKind");
-}
-
/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 702670095c8d5..54025a626f002 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @bitcast(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>
More information about the Mlir-commits mailing list