[Mlir-commits] [mlir] 970bb4a - [MLIR] Add `to/from_extent_tensor` lowering to the standard dialect

Frederik Gossen llvmlistbot at llvm.org
Mon Jun 8 02:38:58 PDT 2020


Author: Frederik Gossen
Date: 2020-06-08T09:38:18Z
New Revision: 970bb4a291c0a823786efee7d572a699569ec339

URL: https://github.com/llvm/llvm-project/commit/970bb4a291c0a823786efee7d572a699569ec339
DIFF: https://github.com/llvm/llvm-project/commit/970bb4a291c0a823786efee7d572a699569ec339.diff

LOG: [MLIR] Add `to/from_extent_tensor` lowering to the standard dialect

The operations `to_extent_tensor` and `from_extent_tensor` become no-ops when
lowered to the standard dialect.
This is possible with a lowering from `shape.shape` to `tensor<?xindex>`.

Differential Revision: https://reviews.llvm.org/D81162

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 8deb8b85c25e..5c74cdcaa241 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,16 +19,16 @@ using namespace mlir;
 namespace {
 
 /// Conversion patterns.
-class SizeToIndexOpConversion
-    : public OpConversionPattern<shape::SizeToIndexOp> {
+class FromExtentTensorOpConversion
+    : public OpConversionPattern<shape::FromExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+  using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::SizeToIndexOpOperandAdaptor transformed(operands);
-    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
@@ -47,6 +47,34 @@ class IndexToSizeOpConversion
   }
 };
 
+class SizeToIndexOpConversion
+    : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::SizeToIndexOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    return success();
+  }
+};
+
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<shape::ToExtentTensorOp> {
+public:
+  using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -55,6 +83,7 @@ class ShapeTypeConverter : public TypeConverter {
   ShapeTypeConverter(MLIRContext *ctx) {
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
+
     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
     addConversion([ctx](shape::ShapeType type) {
       return RankedTensorType::get({ShapedType::kDynamicSize},
@@ -99,8 +128,10 @@ void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
-      SizeToIndexOpConversion>(ctx);
+      SizeToIndexOpConversion,
+      ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
 

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 138a9b2f1db1..de420a96a70f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -39,3 +39,26 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape {
   // CHECK: return %[[SHAPE]] : tensor<?xindex>
   return %shape : !shape.shape
 }
+
+// -----
+
+// Lower `to_extent_tensor` operation to no-op.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @to_extent_tensor(%shape : !shape.shape) -> tensor<?xindex> {
+  // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+  %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
+  return %tensor : tensor<?xindex>
+}
+
+// -----
+
+// Lower `from_extent_tensor` operation to no-op.
+// CHECK-LABEL: @from_extent_tensor
+// CHECK-SAME: (%[[TENSOR:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
+  // CHECK-NEXT: return %[[TENSOR]] : tensor<?xindex>
+  %shape = "shape.from_extent_tensor"(%tensor)
+      : (tensor<?xindex>) -> !shape.shape
+  return %shape : !shape.shape
+}


        


More information about the Mlir-commits mailing list