[Mlir-commits] [mlir] 92a836d - [mlir] Attach InferTypeOpInterface on SameOperandsAndResultType operations when possible

River Riddle llvmlistbot at llvm.org
Thu Apr 28 12:58:40 PDT 2022


Author: River Riddle
Date: 2022-04-28T12:57:59-07:00
New Revision: 92a836da07596a9e409c3b4231fe727e0924d0e4

URL: https://github.com/llvm/llvm-project/commit/92a836da07596a9e409c3b4231fe727e0924d0e4
DIFF: https://github.com/llvm/llvm-project/commit/92a836da07596a9e409c3b4231fe727e0924d0e4.diff

LOG: [mlir] Attach InferTypeOpInterface on SameOperandsAndResultType operations when possible

This allows inferring the result types of an operation from the type of one of its operands, in
certain situations (specifically, when the operation has the SameOperandsAndResultType trait and
at least one non-variadic operand to use as the type anchor). As a result, many more operations
automatically support type inference with no additional effort; e.g., nearly all Arithmetic
operations now support result type inference with no additional changes.

Differential Revision: https://reviews.llvm.org/D124581

Added: 
    

Modified: 
    mlir/examples/standalone/include/Standalone/StandaloneOps.h
    mlir/examples/standalone/lib/Standalone/CMakeLists.txt
    mlir/include/mlir/Dialect/Math/IR/Math.h
    mlir/include/mlir/Dialect/Quant/QuantOps.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/lib/Dialect/Quant/IR/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
    mlir/lib/TableGen/Operator.cpp
    mlir/test/Analysis/test-shape-fn-report.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/unittests/TableGen/OpBuildGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/standalone/include/Standalone/StandaloneOps.h b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
index a56c2867b1c8b..a9006e0de3813 100644
--- a/mlir/examples/standalone/include/Standalone/StandaloneOps.h
+++ b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt
index c8b16b7ae4aba..eadc695d39f35 100644
--- a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt
+++ b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRStandalone
 
 	LINK_LIBS PUBLIC
 	MLIRIR
+        MLIRInferTypeOpInterface
 	)

diff  --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h
index b300dd89eefa4..6af358bf57b37 100644
--- a/mlir/include/mlir/Dialect/Math/IR/Math.h
+++ b/mlir/include/mlir/Dialect/Math/IR/Math.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/QuantOps.h
index 9dc6b1ddd5e0d..14fb3035ab0d3 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 52539c469c60a..ec40a6c628883 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
index 6115c51fa178f..7b871df5bc144 100644
--- a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRQuant
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   )

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index 6b94ee010b7cd..25b58a51357ae 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -11,5 +11,6 @@ add_mlir_dialect_library(MLIRSparseTensor
   LINK_LIBS PUBLIC
   MLIRDialect
   MLIRIR
+  MLIRInferTypeOpInterface
   MLIRSupport
   )

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 35afb8d7f6946..b3b6d36ee397e 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -333,8 +333,25 @@ void Operator::populateTypeInferenceInfo(
 
   // Skip cases currently being custom generated.
   // TODO: Remove special cases.
-  if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
+  if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
+    // Check for a non-variable length operand to use as the type anchor.
+    auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
+      NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
+      return operand && !operand->isVariableLength();
+    });
+    if (operandI == arguments.end())
+      return;
+
+    // Map each of the result types to the anchor operation.
+    int operandIdx = operandI - arguments.begin();
+    resultTypeMapping.resize(getNumResults());
+    for (int i = 0; i < getNumResults(); ++i)
+      resultTypeMapping[i].emplace_back(operandIdx);
+
+    allResultsHaveKnownTypes = true;
+    traits.push_back(Trait::create(inferTrait->getDefInit()));
     return;
+  }
 
   // We create equivalence classes of argument/result types where arguments
   // and results are mapped into the same index space and indices corresponding

diff  --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
index 54d97483b14dc..fbf5c47c597a7 100644
--- a/mlir/test/Analysis/test-shape-fn-report.mlir
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -5,9 +5,9 @@ module attributes {shape.lib = [@shape_lib]} {
 // expected-remark at +1 {{associated shape function: same_result_shape}}
 func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
     attributes {shape.function = @shape_lib::@same_result_shape} {
-  // expected-remark at +1 {{no associated way}}
+  // expected-remark at +1 {{implements InferType op interface}}
   %0 = math.tanh %arg : tensor<10x20xf32>
-  // expected-remark at +1 {{associated shape function: same_result_shape}}
+  // expected-remark at +1 {{implements InferType op interface}}
   %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
   return %1 : tensor<10x20xf32>
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8fd6fd6035a6d..5ace00169546f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2608,15 +2608,9 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
    }];
 }
 
-// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
-// Tests suppression of ambiguous build methods for operations with
-// SameOperandsAndResultType and InferTypeOpInterface.
-def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
-    "tblgen_build_5", [SameOperandsAndResultType]>;
-
 // Op with InferTypeOpInterface and regions.
-def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
-    "tblgen_build_6", [InferTypeOpInterface]> {
+def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
+    "tblgen_build_5", [InferTypeOpInterface]> {
   let regions = (region AnyRegion:$body);
 }
 

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index ada37117e73a7..9ca8efacdf624 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -199,7 +199,7 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
   let results = (outs AnyType:$b);
 }
 
-// CHECK_LABEL: class NS_HCollectiveParamsOp :
+// CHECK_LABEL: class HCollectiveParamsOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
@@ -212,7 +212,7 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
   let results = (outs Variadic<I32>:$b);
 }
 
-// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
+// CHECK_LABEL: class HCollectiveParamsSuppress0Op :
 // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
@@ -224,7 +224,7 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
   let results = (outs I32:$b);
 }
 
-// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
+// CHECK_LABEL: class HCollectiveParamsSuppress1Op :
 // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
@@ -237,7 +237,7 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
   let arguments = (ins Variadic<I32>:$a);
   let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
 }
-// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
+// CHECK_LABEL: class HCollectiveParamsSuppress2Op :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
 // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@@ -247,11 +247,11 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
   let arguments = (ins AnyType:$a, AnyType:$b);
   let results = (outs AnyType:$r);
 }
-// CHECK_LABEL: class NS_IOp :
+// CHECK_LABEL: class IOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 // Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder
@@ -259,7 +259,7 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
   let arguments = (ins AnyType:$a, AnyType:$b);
   let results = (outs AnyType:$r);
 }
-// CHECK_LABEL: class NS_JOp :
+// CHECK_LABEL: class JOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
@@ -292,14 +292,14 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
   let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1);
   let results = (outs AnyType:$r);
 }
-// CHECK_LABEL: class NS_LOp :
+// CHECK_LABEL: class LOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 

diff  --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index d4d8746ac4416..5955dcd0c4187 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -27,7 +27,12 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
 // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x)
 // CHECK:   odsState.addTypes(y);
 // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x)
-// CHECK:   odsState.addTypes({x.getType()});
+// CHECK:   ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
+// CHECK:   if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(),
+// CHECK:             odsState.location, odsState.operands,
+// CHECK:             odsState.attributes.getDictionary(odsState.getContext()),
+// CHECK:             /*regions=*/{}, inferredReturnTypes)))
+// CHECK:     odsState.addTypes(inferredReturnTypes);
 
 def OpC : NS_Op<"three_normal_result_op", []> {
   let results = (outs I32:$x, /*unnamed*/I32, I32:$z);

diff  --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 9b985e32062e5..cf85762f2ef41 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -204,7 +204,7 @@ TEST_F(OpBuildGenTest,
   verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
 }
 
-// The next 2 tests test supression of ambiguous build methods for ops that
+// The next test checks supression of ambiguous build methods for ops that
 // have a single variadic input, and single non-variadic result, and which
 // support the SameOperandsAndResultType trait and and optionally the
 // InferOpTypeInterface interface. For such ops, the ODS framework generates
@@ -213,14 +213,8 @@ TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
   testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
 }
 
-TEST_F(
-    OpBuildGenTest,
-    BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
-  testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
-}
-
 TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
-  auto op = builder.create<test::TableGenBuildOp6>(
+  auto op = builder.create<test::TableGenBuildOp5>(
       loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
   ASSERT_EQ(op->getNumRegions(), 1u);
   verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);


        


More information about the Mlir-commits mailing list