[llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)

Chengji Yao via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 11:40:19 PST 2024


================
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  Value one = builder.create<ConstantOp>(loc, resultType,
+                                         builder.getOneAttr(resultType));
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
----------------
yaochengji wrote:

nit: add a newline at the end of the file

https://github.com/llvm/llvm-project/pull/81218


More information about the llvm-commits mailing list