[Mlir-commits] [mlir] 5c9c4ad - Add the inline interface to the shape dialect

Feng Liu llvmlistbot at llvm.org
Fri Aug 7 23:30:04 PDT 2020


Author: Feng Liu
Date: 2020-08-07T23:29:43-07:00
New Revision: 5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0

URL: https://github.com/llvm/llvm-project/commit/5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0
DIFF: https://github.com/llvm/llvm-project/commit/5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0.diff

LOG: Add the inline interface to the shape dialect

This patch also fixes a minor issue: shape.rank should be allowed to
return !shape.size, as shown by the shape.rank example in the dialect
documentation.

Differential Revision: https://reviews.llvm.org/D85556

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 47c592e51a40..7cf02149fa25 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -59,6 +60,32 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Inliner hooks that make shape dialect regions and ops legal to inline.
+struct ShapeInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// Returns true if region 'src' can be inlined into region 'dest', whose
+  /// parent op is registered to this dialect. Always true: no restrictions.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
+  /// Returns true if 'op', which is registered to this (shape) dialect, can be
+  /// inlined into the region 'dest'. Always true: every shape op is treated as
+  /// inlinable, with no check on 'dest' or the value mapping.
+  bool isLegalToInline(Operation *op, Region *dest,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void ShapeDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
@@ -66,6 +93,7 @@ void ShapeDialect::initialize() {
       >();
   addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
            WitnessType>();
+  addInterfaces<ShapeInlinerInterface>();
   // Allow unknown operations during prototyping and testing. As the dialect is
   // still evolving it makes it simple to start with an unregistered ops and
   // try different variants before actually defining the op.
@@ -640,11 +668,14 @@ struct RankShapeOfCanonicalizationPattern
         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
     if (!rankedTensorType)
       return failure();
-    assert(op.getType().isa<IndexType>() &&
-           "expected `rank(shape_of( ... )]` based on a shaped argument to "
-           "yield an index type");
     int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    if (op.getType().isa<IndexType>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    } else if (op.getType().isa<shape::SizeType>()) {
+      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
+    } else {
+      return failure();
+    }
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 21c5a68c3adc..8a38b42c5a71 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -624,6 +624,18 @@ func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
 
 // -----
 
+// Canonicalize `!shape.size`-typed `rank` when shape is from a ranked tensor.
+// CHECK-LABEL: @canonicalize_rank_size
+func @canonicalize_rank_size(%arg : tensor<1x2x?xf32>) -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
+  return %rank : !shape.size
+}
+
+// -----
+
 // Do not canonicalize `rank` when shape is derived from unranked tensor.
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index


        


More information about the Mlir-commits mailing list