[Mlir-commits] [mlir] 0d988da - [MLIR] Change ODS collective params build method to provide an empty default value for named attributes
Rahul Joshi
llvmlistbot at llvm.org
Mon Jul 13 13:36:30 PDT 2020
Author: Rahul Joshi
Date: 2020-07-13T13:35:44-07:00
New Revision: 0d988da6d13e16a397d58bc3b965a36adb7fee03
URL: https://github.com/llvm/llvm-project/commit/0d988da6d13e16a397d58bc3b965a36adb7fee03
DIFF: https://github.com/llvm/llvm-project/commit/0d988da6d13e16a397d58bc3b965a36adb7fee03.diff
LOG: [MLIR] Change ODS collective params build method to provide an empty default value for named attributes
- Provide default value for `ArrayRef<NamedAttribute> attributes` parameter of
the collective params build method.
- Change the `genSeparateArgParamBuilder` function to not generate build methods
that may be ambiguous with the new collective params build method.
- This change should help eliminate passing an empty NamedAttribute ArrayRef when the
collective params build method is used.
- Extend op-decl.td unit test to make sure the ambiguous build methods are not
generated.
Differential Revision: https://reviews.llvm.org/D83517
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index dad8bfc0173f..e59830fcef89 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -331,8 +331,7 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
return operation.emitError(
"bitwidth emulation is not implemented yet on unsigned op");
}
- rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
- ArrayRef<NamedAttribute>());
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
return success();
}
};
@@ -368,11 +367,11 @@ class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
if (!dstType)
return failure();
if (isBoolScalarOrVector(operands.front().getType())) {
- rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
- operation, dstType, operands, ArrayRef<NamedAttribute>());
+ rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
+ operands);
} else {
- rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
- operation, dstType, operands, ArrayRef<NamedAttribute>());
+ rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
+ operands);
}
return success();
}
@@ -529,8 +528,8 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
// Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(operation, operands.front());
} else {
- rewriter.template replaceOpWithNewOp<SPIRVOp>(
- operation, dstType, operands, ArrayRef<NamedAttribute>());
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
+ operands);
}
return success();
}
@@ -1046,8 +1045,7 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
auto dstType = typeConverter.convertType(xorOp.getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands,
- ArrayRef<NamedAttribute>());
+ rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 7aa26541ac27..c6a58a8dc5a8 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -418,8 +418,7 @@ Value Importer::processConstant(llvm::Constant *c) {
}
if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
- processGlobal(GV),
- ArrayRef<NamedAttribute>());
+ processGlobal(GV));
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
llvm::Instruction *i = ce->getAsInstruction();
@@ -727,7 +726,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
if (!calledValue)
return failure();
ops.insert(ops.begin(), calledValue);
- op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
+ op = b.create<CallOp>(loc, tys, ops);
}
if (!ci->getType()->isVoidTy())
v = op->getResult(0);
@@ -809,7 +808,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
Type type = processType(inst->getType());
if (!type)
return failure();
- v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
+ v = b.create<GEPOp>(loc, type, ops);
return success();
}
}
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index b596eee03829..f8ff60e35557 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -171,6 +171,56 @@ def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
// CHECK-LABEL: class GOp :
// CHECK: static ::mlir::LogicalResult inferReturnTypes
+// Check default value for collective params builder. Check that other builders
+// are generated as well.
+def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
+ let arguments = (ins AnyType:$a);
+ let results = (outs AnyType:$b);
+}
+
+// CHECK-LABEL: class HCollectiveParamsOp :
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
+
+// Check suppression of "separate arg, separate result" build method for an op
+// with single variadic arg and single variadic result (since it will be
+// ambiguous with the collective params build method).
+def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
+ let arguments = (ins Variadic<I32>:$a);
+ let results = (outs Variadic<I32>:$b);
+}
+
+// CHECK-LABEL: class HCollectiveParamsSuppress0Op :
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
+// Check suppression of "separate arg, collective result" build method for an op
+// with single variadic arg and non-variadic result (since it will be
+// ambiguous with the collective params build method).
+def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
+ let arguments = (ins Variadic<I32>:$a);
+ let results = (outs I32:$b);
+}
+
+// CHECK-LABEL: class HCollectiveParamsSuppress1Op :
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
+// Check suppression of "separate arg, collective result" build method for an op
+// with single variadic arg and > 1 variadic result (since it will be
+// ambiguous with the collective params build method). Note that "separate arg,
+// separate result" build method should be generated in this case as it's not
+// ambiguous with the collective params build method.
+def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVariadicResultSize]> {
+ let arguments = (ins Variadic<I32>:$a);
+ let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
+}
+// CHECK-LABEL: class HCollectiveParamsSuppress2Op :
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a);
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
// Check that default builders can be suppressed.
// ---
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b2b4245989b5..5e009e602524 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -955,14 +955,51 @@ void OpEmitter::genSeparateArgParamBuilder() {
llvm_unreachable("unhandled TypeParamKind");
};
+ // A separate arg param builder method will have a signature which is
+ // ambiguous with the collective params build method (generated in
+ // `genCollectiveParamBuilder` function below) if it has a single
+  // `ArrayRef<Type>` parameter for result types and a single `ValueRange`
+ // parameter for the operands, no parameters after that, and the collective
+ // params build method has `attributes` as its last parameter (with
+ // a default value). This will happen when all of the following are true:
+ // 1. [`attributes` as last parameter in collective params build method]:
+ // getNumVariadicRegions must be 0 (otherwise the collective params build
+ // method ends with a `numRegions` param, and we don't specify default
+ // value for attributes).
+  // 2. [single `ValueRange` parameter for operands, and no parameters
+ // after that]: numArgs() must be 1 (if not, each arg gets a separate param
+ // in the build methods generated here) and the single arg must be a
+ // non-attribute variadic argument.
+  // 3. [single `ArrayRef<Type>` parameter for result types]:
+ // 3a. paramKind should be Collective, or
+ // 3b. paramKind should be Separate and there should be a single variadic
+ // result
+ //
+ // In that case, skip generating such ambiguous build methods here.
+ bool hasSingleVariadicResult =
+ op.getNumResults() == 1 && op.getResult(0).isVariadic();
+
+ bool hasSingleVariadicArg =
+ op.getNumArgs() == 1 &&
+ op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+ op.getOperand(0).isVariadic();
+ bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
+
for (auto attrType : attrBuilderType) {
- emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
+ // Case 3b above.
+ if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
+ hasSingleVariadicResult))
+ emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType(op))
emit(attrType, TypeParamKind::None, /*inferType=*/true);
- // Emit separate arg build with collective type, unless there is only one
- // variadic result, in which case the above would have already generated
- // the same build method.
- if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength()))
+ // The separate arg + collective param kind method will be:
+ // (a) Same as the separate arg + separate param kind method if there is
+ // only one variadic result.
+ // (b) Ambiguous with the collective params method under conditions in (3a)
+ // above.
+ // In either case, skip generating such build method.
+ if (!hasSingleVariadicResult &&
+ !(hasNoVariadicRegions && hasSingleVariadicArg))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
@@ -1184,8 +1221,12 @@ void OpEmitter::genCollectiveParamBuilder() {
", ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange "
"operands, "
"::llvm::ArrayRef<::mlir::NamedAttribute> attributes";
- if (op.getNumVariadicRegions())
+ if (op.getNumVariadicRegions()) {
params += ", unsigned numRegions";
+ } else {
+    // Provide default value for `attributes` since it's the last parameter.
+ params += " = {}";
+ }
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
More information about the Mlir-commits
mailing list