[Mlir-commits] [mlir] e706533 - [mlir] Add reifyReturnShape to shaped type OpInterface

Jacques Pienaar llvmlistbot at llvm.org
Fri Feb 28 08:41:45 PST 2020


Author: Jacques Pienaar
Date: 2020-02-28T08:41:18-08:00
New Revision: e706533f0a52b2eb8929d4004d1daccd53feda29

URL: https://github.com/llvm/llvm-project/commit/e706533f0a52b2eb8929d4004d1daccd53feda29
DIFF: https://github.com/llvm/llvm-project/commit/e706533f0a52b2eb8929d4004d1daccd53feda29.diff

LOG: [mlir] Add reifyReturnShape to shaped type OpInterface

The new `reifyReturnTypeShapes` interface method inserts operations that
compute the return shape of the operation dynamically.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/InferTypeOpInterface.h
    mlir/include/mlir/Analysis/InferTypeOpInterface.td
    mlir/test/lib/TestDialect/TestDialect.cpp
    mlir/test/lib/TestDialect/TestPatterns.cpp
    mlir/test/mlir-tblgen/return-types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h
index 2a64983ad8c2..4c2628512cf4 100644
--- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"

diff  --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
index 621d586bd3e0..548cd09a14c3 100644
--- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -97,6 +97,18 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
                     "SmallVectorImpl<ShapedTypeComponents>&":
                       $inferedReturnShapes)
     >,
+    InterfaceMethod<
+      /*desc=*/[{Reify the shape computation for the operation.
+
+      Insert operations using the given OpBuilder that compute the result shape.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"reifyReturnTypeShapes",
+      /*args=*/(ins "OpBuilder&":$builder,
+                    "SmallVectorImpl<Value>&":$reifiedReturnShapes),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return failure(); }]
+    >,
   ];
 }
 

diff  --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 330b8041afdc..12ec279c1d67 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
@@ -312,24 +313,24 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     ArrayRef<NamedAttribute> attributes, RegionRange regions,
-    SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
-  // Create return type consisting of the first element of each shape of the
-  // input operands or unknown for unranked operand.
-  std::vector<int64_t> shape;
-  shape.reserve(operands.size());
-  for (auto operandType : operands.getTypes()) {
-    if (auto sval = operandType.dyn_cast<ShapedType>()) {
-      if (sval.hasRank())
-        shape.push_back(sval.getShape().front());
-      else
-        shape.push_back(ShapedType::kDynamicSize);
-    } else {
-      return emitOptionalError(location, "only shaped type operands allowed");
-    }
+    SmallVectorImpl<ShapedTypeComponents> &inferedReturnShapes) {
+  // Create return type consisting of the first dimension of the first operand.
+  auto operandType = *operands.getTypes().begin();
+  auto sval = operandType.dyn_cast<ShapedType>();
+  if (!sval) {
+    return emitOptionalError(location, "only shaped type operands allowed");
   }
-  inferedComponents.reserve(1);
+  int64_t dim =
+      sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
   auto type = IntegerType::get(17, context);
-  inferedComponents.emplace_back(shape, type);
+  inferedReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+  return success();
+}
+
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
+  shapes = SmallVector<Value, 1>{
+      builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
   return success();
 }
 

diff  --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index f89987610c99..decb5e246a81 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -82,6 +82,19 @@ static void invokeCreateWithInferedReturnType(Operation *op) {
   }
 }
 
+static void reifyReturnShape(Operation *op) {
+  OpBuilder b(op);
+
+  // Reify the result shapes and emit a remark for each shape's defining op.
+  auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
+  SmallVector<Value, 2> shapes;
+  if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
+    return;
+  for (auto it : llvm::enumerate(shapes))
+    op->emitRemark() << "value " << it.index() << ": "
+                     << it.value().getDefiningOp();
+}
+
 struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
   void runOnFunction() override {
     if (getFunction().getName() == "testCreateFunctions") {
@@ -100,6 +113,16 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
       };
       return;
     }
+    if (getFunction().getName() == "testReifyFunctions") {
+      std::vector<Operation *> ops;
+      // Collect ops to avoid triggering on inserted ops.
+      for (auto &op : getFunction().getBody().front())
+        if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
+          ops.push_back(&op);
+      // Reify the return shapes of each collected op.
+      for (auto *op : ops)
+        reifyReturnShape(op);
+    }
   }
 };
 } // end anonymous namespace

diff  --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index 3fcb22331fa1..d0eb364a6a9d 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -7,13 +7,13 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
 // CHECK: "test.no_attributes"
   %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
 // CHECK: "test.op_with_infer_type_if"
 // CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 // CHECK: "test.op_with_infer_type_if"
@@ -36,3 +36,14 @@ func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<2
   %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: testReifyFunctions
+func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
+  // expected-remark@+1 {{constant 10}}
+  %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17>
+  // expected-remark@+1 {{constant 20}}
+  %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17>
+  return
+}


        


More information about the Mlir-commits mailing list