[Mlir-commits] [mlir] 0d988da - [MLIR] Change ODS collective params build method to provide an empty default value for named attributes

Rahul Joshi llvmlistbot at llvm.org
Mon Jul 13 13:36:30 PDT 2020


Author: Rahul Joshi
Date: 2020-07-13T13:35:44-07:00
New Revision: 0d988da6d13e16a397d58bc3b965a36adb7fee03

URL: https://github.com/llvm/llvm-project/commit/0d988da6d13e16a397d58bc3b965a36adb7fee03
DIFF: https://github.com/llvm/llvm-project/commit/0d988da6d13e16a397d58bc3b965a36adb7fee03.diff

LOG: [MLIR] Change ODS collective params build method to provide an empty default value for named attributes

- Provide default value for `ArrayRef<NamedAttribute> attributes` parameter of
  the collective params build method.
- Change the `genSeparateArgParamBuilder` function to not generate build methods
  that may be ambiguous with the new collective params build method.
- This change should help eliminate passing an empty `ArrayRef<NamedAttribute>`
  when the collective params build method is used.
- Extend op-decl.td unit test to make sure the ambiguous build methods are not
  generated.

Differential Revision: https://reviews.llvm.org/D83517

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index dad8bfc0173f..e59830fcef89 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -331,8 +331,7 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
       return operation.emitError(
           "bitwidth emulation is not implemented yet on unsigned op");
     }
-    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
-                                                  ArrayRef<NamedAttribute>());
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
     return success();
   }
 };
@@ -368,11 +367,11 @@ class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
     if (!dstType)
       return failure();
     if (isBoolScalarOrVector(operands.front().getType())) {
-      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
-          operation, dstType, operands, ArrayRef<NamedAttribute>());
+      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
+                                                           operands);
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
-          operation, dstType, operands, ArrayRef<NamedAttribute>());
+      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
+                                                           operands);
     }
     return success();
   }
@@ -529,8 +528,8 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(operation, operands.front());
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVOp>(
-          operation, dstType, operands, ArrayRef<NamedAttribute>());
+      rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
+                                                    operands);
     }
     return success();
   }
@@ -1046,8 +1045,7 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
   auto dstType = typeConverter.convertType(xorOp.getType());
   if (!dstType)
     return failure();
-  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands,
-                                                   ArrayRef<NamedAttribute>());
+  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
 
   return success();
 }

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 7aa26541ac27..c6a58a8dc5a8 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -418,8 +418,7 @@ Value Importer::processConstant(llvm::Constant *c) {
   }
   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
-                                      processGlobal(GV),
-                                      ArrayRef<NamedAttribute>());
+                                      processGlobal(GV));
 
   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
     llvm::Instruction *i = ce->getAsInstruction();
@@ -727,7 +726,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
       if (!calledValue)
         return failure();
       ops.insert(ops.begin(), calledValue);
-      op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
+      op = b.create<CallOp>(loc, tys, ops);
     }
     if (!ci->getType()->isVoidTy())
       v = op->getResult(0);
@@ -809,7 +808,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     Type type = processType(inst->getType());
     if (!type)
       return failure();
-    v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
+    v = b.create<GEPOp>(loc, type, ops);
     return success();
   }
   }

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index b596eee03829..f8ff60e35557 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -171,6 +171,56 @@ def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
 // CHECK-LABEL: class GOp :
 // CHECK: static ::mlir::LogicalResult inferReturnTypes
 
+// Check default value for collective params builder. Check that other builders
+// are generated as well.
+def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
+  let arguments = (ins AnyType:$a);
+  let results = (outs AnyType:$b);
+}
+
+// CHECK_LABEL: class NS_HCollectiveParamsOp :
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
+
+// Check suppression of "separate arg, separate result" build method for an op
+// with single variadic arg and single variadic result (since it will be
+// ambiguous with the collective params build method).
+def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
+  let arguments = (ins Variadic<I32>:$a);
+  let results = (outs Variadic<I32>:$b);
+}
+
+// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
+// Check suppression of "separate arg, collective result" build method for an op
+// with single variadic arg and non variadic result (since it will be
+// ambiguous with the collective params build method).
+def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
+  let arguments = (ins Variadic<I32>:$a);
+  let results = (outs I32:$b);
+}
+
+// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
+// Check suppression of "separate arg, collective result" build method for an op
+// with single variadic arg and > 1 variadic result (since it will be
+// ambiguous with the collective params build method). Note that "separate arg,
+// separate result" build method should be generated in this case as its not
+// ambiguous with the collective params build method.
+def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVariadicResultSize]> {
+  let arguments = (ins Variadic<I32>:$a);
+  let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
+}
+// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a);
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+
 // Check that default builders can be suppressed.
 // ---
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b2b4245989b5..5e009e602524 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -955,14 +955,51 @@ void OpEmitter::genSeparateArgParamBuilder() {
     llvm_unreachable("unhandled TypeParamKind");
   };
 
+  // A separate arg param builder method will have a signature which is
+  // ambiguous with the collective params build method (generated in
+  // `genCollectiveParamBuilder` function below) if it has a single
+  // `ArrayReg<Type>` parameter for result types and a single `ArrayRef<Value>`
+  // parameter for the operands, no parameters after that, and the collective
+  // params build method has `attributes` as its last parameter (with
+  // a default value). This will happen when all of the following are true:
+  // 1. [`attributes` as last parameter in collective params build method]:
+  //    getNumVariadicRegions must be 0 (otherwise the collective params build
+  //    method ends with a `numRegions` param, and we don't specify default
+  //    value for attributes).
+  // 2. [single `ArrayRef<Value>` parameter for operands, and no parameters
+  //    after that]: numArgs() must be 1 (if not, each arg gets a separate param
+  //    in the build methods generated here) and the single arg must be a
+  //    non-attribute variadic argument.
+  // 3. [single `ArrayReg<Type>` parameter for result types]:
+  //      3a. paramKind should be Collective, or
+  //      3b. paramKind should be Separate and there should be a single variadic
+  //          result
+  //
+  // In that case, skip generating such ambiguous build methods here.
+  bool hasSingleVariadicResult =
+      op.getNumResults() == 1 && op.getResult(0).isVariadic();
+
+  bool hasSingleVariadicArg =
+      op.getNumArgs() == 1 &&
+      op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+      op.getOperand(0).isVariadic();
+  bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
+
   for (auto attrType : attrBuilderType) {
-    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
+    // Case 3b above.
+    if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
+          hasSingleVariadicResult))
+      emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
     if (canInferType(op))
       emit(attrType, TypeParamKind::None, /*inferType=*/true);
-    // Emit separate arg build with collective type, unless there is only one
-    // variadic result, in which case the above would have already generated
-    // the same build method.
-    if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength()))
+    // The separate arg + collective param kind method will be:
+    // (a) Same as the separate arg + separate param kind method if there is
+    //     only one variadic result.
+    // (b) Ambiguous with the collective params method under conditions in (3a)
+    //     above.
+    // In either case, skip generating such build method.
+    if (!hasSingleVariadicResult &&
+        !(hasNoVariadicRegions && hasSingleVariadicArg))
       emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
 }
@@ -1184,8 +1221,12 @@ void OpEmitter::genCollectiveParamBuilder() {
       ", ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange "
       "operands, "
       "::llvm::ArrayRef<::mlir::NamedAttribute> attributes";
-  if (op.getNumVariadicRegions())
+  if (op.getNumVariadicRegions()) {
     params += ", unsigned numRegions";
+  } else {
+    // Provide default value for `attributes` since its the last parameter
+    params += " = {}";
+  }
   auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
   auto &body = m.body();
 


        


More information about the Mlir-commits mailing list