[Mlir-commits] [mlir] 851d02f - Enhance InferShapedTypeOpInterface to make it accessible during dialect conversion

Mehdi Amini llvmlistbot at llvm.org
Tue May 18 19:51:25 PDT 2021


Author: Wenyi Zhao
Date: 2021-05-19T02:51:14Z
New Revision: 851d02f61e945d335021858111416f444139e2b2

URL: https://github.com/llvm/llvm-project/commit/851d02f61e945d335021858111416f444139e2b2
DIFF: https://github.com/llvm/llvm-project/commit/851d02f61e945d335021858111416f444139e2b2.diff

LOG: Enhance InferShapedTypeOpInterface to make it accessible during dialect conversion

The original interface methods are not safe to call during dialect
conversion. This is because some ops (e.g. `dynamic_reshape(input,
target_shape)`) depend on the values of their operands to calculate the
output shape. However, the operands may be out of reach during dialect
conversion (e.g. when converting from the tensor world to the buffer
world). This patch provides a new kind of interface that accepts
user-provided operands to solve this problem.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D102317
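To make the problem described above concrete, here is a small standalone C++ model of it. It does not use MLIR, and every name in it (`ToyValue`, `emitMemRefDim`, `ToyShapedOp`) is hypothetical. It assumes that, inside a conversion pattern, the op still holds its original tensor-typed operands while the pattern has already received buffer-typed replacements, and that the shape computation, like `memref.dim`, only accepts buffers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, greatly simplified stand-in for an SSA value: a name and a
// type tag ("tensor" before conversion, "memref" after).
struct ToyValue {
  std::string name;
  std::string type;
};

// Hypothetical stand-in for emitting `memref.dim %v, 0`. Like the real op it
// only accepts buffer (memref) values; it returns a textual form of the IR.
std::string emitMemRefDim(const ToyValue &v) {
  assert(v.type == "memref" && "memref.dim needs a buffer operand");
  return "memref.dim %" + v.name + ", 0";
}

// Hypothetical op whose result shape depends on its first operand, loosely
// mirroring OpWithShapedTypeInferTypeInterfaceOp in the test dialect.
struct ToyShapedOp {
  std::vector<ToyValue> ownOperands;  // what getOperand() would return

  // Mirrors the new signature: the shape is computed from the `operands`
  // supplied by the caller, never from ownOperands.
  bool reifyReturnTypeShapes(const std::vector<ToyValue> &operands,
                             std::vector<std::string> &shapes) const {
    // This sketch reports failure (like returning failure()) instead of
    // emitting ill-typed IR when it is handed an unconverted operand.
    if (operands.empty() || operands.front().type != "memref")
      return false;
    shapes = {emitMemRefDim(operands.front())};
    return true;
  }
};
```

Called with `op.ownOperands` (the analogue of `getOperand`), the hook fails because it sees a tensor; called with the converted operands, it produces `memref.dim %buf0, 0`. That is the reason, in this model, for threading the operands through the method rather than reading them from the op.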

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index f56d8d6ef7e5..485eed6a4f86 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -107,11 +107,23 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
       Insert operations using the given OpBuilder that computes the
       result shape. Only one of this method or
       `reifyReturnTypeShapesPerResultDim` needs to be overriden by the
-      operation.
+      operation. This interface is designed to work during dialect
+      conversion (e.g. when converting from the tensor world to the buffer
+      world), where `getOperand` may be invalid. For example, some ops (e.g.
+      `dynamic_reshape(input, target_shape)`) depend on their operands
+      to calculate the result shape. When the `matchAndRewrite` method
+      of a conversion pattern is called, the operands of the op to convert
+      may have been converted into other types, which makes it invalid to
+      call the `getOperand` method of such an op directly inside the
+      conversion pattern. To solve this problem, this interface follows
+      the design of conversion patterns: it accepts the operands as an
+      argument so that implementations do not need to call `getOperand`
+      directly.
       }],
       /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"reifyReturnTypeShapes",
       /*args=*/(ins "::mlir::OpBuilder&":$builder,
+          "::mlir::ValueRange":$operands,
           "::mlir::SmallVectorImpl<Value> &":$reifiedReturnShapes),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{ return ::mlir::failure(); }]

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8125c053e394..bb12cad29844 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -734,8 +734,8 @@ static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
   // check if the op implements the first interface method or the second, and
   // get the value to use appropriately.
   SmallVector<Value> reifiedResultShapes;
-  if (succeeded(
-          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
+  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
+          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
     if (reifiedResultShapes.size() <= resultNumber)
       return nullptr;
     Value resultShape = reifiedResultShapes[resultNumber];

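The call site above keeps its previous behaviour by forwarding `result.getOwner()->getOperands()`, since outside a conversion the op's own operands are still valid. A rough standalone sketch of the selection logic follows. `ReifyHook` and `shapeForResult` are hypothetical names, strings stand in for `Value`s, and an empty `std::optional` stands in for the `nullptr` return. The comment in the hunk indicates that the real function also handles the second interface method, `reifyReturnTypeShapesPerResultDim`; that part lies outside the hunk and is not modelled here.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical shape of a reify hook: fills `shapes` from the supplied
// operands and reports success, like reifyReturnTypeShapes(builder,
// operands, shapes) with the builder omitted.
using ReifyHook = std::function<bool(const std::vector<std::string> &operands,
                                     std::vector<std::string> &shapes)>;

// Sketch of the caller logic: outside a conversion the op's own operands are
// valid, so they are forwarded unchanged to the hook.
std::optional<std::string>
shapeForResult(const ReifyHook &reify,
               const std::vector<std::string> &ownOperands,
               std::size_t resultNumber) {
  std::vector<std::string> reifiedResultShapes;
  if (!reify(ownOperands, reifiedResultShapes))
    return std::nullopt;  // sketch only: the second method is not modelled
  if (reifiedResultShapes.size() <= resultNumber)
    return std::nullopt;  // mirrors `return nullptr;` in the hunk
  return reifiedResultShapes[resultNumber];
}
```

A conversion pattern would call the same hook but pass the operands it received instead of the op's own, which is the only difference the new signature needs to support.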
diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e996d0040e8d..79e43eeba099 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -748,9 +748,10 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
 }
 
 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
-    OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
   shapes = SmallVector<Value, 1>{
-      builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)};
+      builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
   return success();
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index cb5546874d68..32f47f40ad36 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -137,7 +137,7 @@ static void reifyReturnShape(Operation *op) {
   // Use permutations of 2 args as operands.
   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
   SmallVector<Value, 2> shapes;
-  if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) ||
+  if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
       !llvm::hasSingleElement(shapes))
     return;
   for (auto it : llvm::enumerate(shapes)) {

