[Mlir-commits] [mlir] 970bb4a - [MLIR] Add `to/from_extent_tensor` lowering to the standard dialect
Frederik Gossen
llvmlistbot at llvm.org
Mon Jun 8 02:38:58 PDT 2020
Author: Frederik Gossen
Date: 2020-06-08T09:38:18Z
New Revision: 970bb4a291c0a823786efee7d572a699569ec339
URL: https://github.com/llvm/llvm-project/commit/970bb4a291c0a823786efee7d572a699569ec339
DIFF: https://github.com/llvm/llvm-project/commit/970bb4a291c0a823786efee7d572a699569ec339.diff
LOG: [MLIR] Add `to/from_extent_tensor` lowering to the standard dialect
The operations `to_extent_tensor` and `from_extent_tensor` become no-ops when
lowered to the standard dialect.
This is possible because `shape.shape` is itself lowered to `tensor<?xindex>`, the extent-tensor type these operations convert to and from.
Differential Revision: https://reviews.llvm.org/D81162
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 8deb8b85c25e..5c74cdcaa241 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,16 +19,16 @@ using namespace mlir;
namespace {
/// Conversion patterns.
-class SizeToIndexOpConversion
- : public OpConversionPattern<shape::SizeToIndexOp> {
+/// Lowers `shape::FromExtentTensorOp` to a no-op: the op is removed and all
+/// uses of its `!shape.shape` result are rewired to its `input` operand. This
+/// is sound because ShapeTypeConverter maps `!shape.shape` to
+/// `tensor<?xindex>`, the type of the operand.
+class FromExtentTensorOpConversion
+ : public OpConversionPattern<shape::FromExtentTensorOp> {
public:
- using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+ using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::SizeToIndexOpOperandAdaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.arg());
+ // Read operands from the rewriter-supplied `operands` list (presumably the
+ // already type-converted values) rather than from `op` itself.
+ shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+ // No replacement op is built; the result is forwarded to the operand.
+ rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
@@ -47,6 +47,34 @@ class IndexToSizeOpConversion
}
};
+/// Lowers `shape::SizeToIndexOp` to a no-op: the op is removed and its result
+/// is replaced by its `arg` operand. This relies on ShapeTypeConverter mapping
+/// `shape::SizeType` to `IndexType`, so operand and result share a type after
+/// conversion.
+class SizeToIndexOpConversion
+ : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+ using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Use the rewriter-supplied operands, not the op's original ones.
+ shape::SizeToIndexOpOperandAdaptor transformed(operands);
+ // Forward the result to the operand; nothing new is created.
+ rewriter.replaceOp(op.getOperation(), transformed.arg());
+ return success();
+ }
+};
+
+/// Lowers `shape::ToExtentTensorOp` to a no-op: the op is removed and its
+/// `tensor<?xindex>` result is replaced by its `input` operand. This is sound
+/// because ShapeTypeConverter maps `!shape.shape` to `tensor<?xindex>`.
+class ToExtentTensorOpConversion
+ : public OpConversionPattern<shape::ToExtentTensorOp> {
+public:
+ using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // `operands` are the values supplied by the conversion driver; presumably
+ // the `!shape.shape` input has already been converted to a tensor here.
+ shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+ // Forward the result to the operand; nothing new is created.
+ rewriter.replaceOp(op.getOperation(), transformed.input());
+ return success();
+ }
+};
+
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
@@ -55,6 +83,7 @@ class ShapeTypeConverter : public TypeConverter {
ShapeTypeConverter(MLIRContext *ctx) {
// Add default pass-through conversion.
addConversion([&](Type type) { return type; });
+
addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
addConversion([ctx](shape::ShapeType type) {
return RankedTensorType::get({ShapedType::kDynamicSize},
@@ -99,8 +128,10 @@ void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ // Registers every Shape-to-Standard conversion pattern defined in this file.
+ // The list is kept in alphabetical order; clang-format is disabled around it,
+ // presumably to preserve the one-pattern-per-line layout.
// clang-format off
patterns.insert<
+ FromExtentTensorOpConversion,
IndexToSizeOpConversion,
- SizeToIndexOpConversion>(ctx);
+ SizeToIndexOpConversion,
+ ToExtentTensorOpConversion>(ctx);
// clang-format on
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 138a9b2f1db1..de420a96a70f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -39,3 +39,26 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape {
// CHECK: return %[[SHAPE]] : tensor<?xindex>
return %shape : !shape.shape
}
+
+// -----
+
+// Lower `to_extent_tensor` operation to no-op.
+// The `!shape.shape` argument is converted to `tensor<?xindex>`, which is
+// already the op's result type, so the op is erased and the function simply
+// returns its (converted) argument.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @to_extent_tensor(%shape : !shape.shape) -> tensor<?xindex> {
+ // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+ %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
+ return %tensor : tensor<?xindex>
+}
+
+// -----
+
+// Lower `from_extent_tensor` operation to no-op.
+// The `!shape.shape` result type is converted to `tensor<?xindex>`, matching
+// the operand, so the op is erased and the function (whose result type is
+// converted as well) returns its argument directly.
+// CHECK-LABEL: @from_extent_tensor
+// CHECK-SAME: (%[[TENSOR:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
+ // CHECK-NEXT: return %[[TENSOR]] : tensor<?xindex>
+ %shape = "shape.from_extent_tensor"(%tensor)
+ : (tensor<?xindex>) -> !shape.shape
+ return %shape : !shape.shape
+}
More information about the Mlir-commits
mailing list