[Mlir-commits] [mlir] 5d9f33a - [MLIR][Shape] Add conversion for missing ops to standard

Stephan Herhut llvmlistbot at llvm.org
Wed Jul 29 03:46:39 PDT 2020


Author: Stephan Herhut
Date: 2020-07-29T12:46:18+02:00
New Revision: 5d9f33aaa00cc02143f137387adc8dd1e51b71d3

URL: https://github.com/llvm/llvm-project/commit/5d9f33aaa00cc02143f137387adc8dd1e51b71d3
DIFF: https://github.com/llvm/llvm-project/commit/5d9f33aaa00cc02143f137387adc8dd1e51b71d3.diff

LOG: [MLIR][Shape] Add conversion for missing ops to standard

This adds conversions for const_size and to_extent_tensor. Also, cast-like operations are now folded away if the source and target types are the same.

Differential Revision: https://reviews.llvm.org/D84745

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 9ba838dbbb26..6ea61376c34d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -440,7 +440,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
     arguments.
   }];
 
-  let arguments = (outs Shape_SizeOrIndexType:$arg);
+  let arguments = (ins Shape_SizeOrIndexType:$arg);
   let results = (outs Index:$result);
 
   let assemblyFormat = "$arg attr-dict `:` type($arg)";

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index b84b6ba3b5d6..efeaa18e17c1 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -56,6 +56,21 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 };
 } // namespace
 
+namespace {
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
+public:
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
 public:
@@ -136,6 +151,27 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
   return success();
 }
 
+namespace {
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<ToExtentTensorOp> {
+public:
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ToExtentTensorOpAdaptor adaptor(operands);
+
+    if (!adaptor.input().getType().isa<RankedTensorType>())
+      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
+                                              op.getType());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -244,9 +280,11 @@ void mlir::populateShapeToStandardConversionPatterns(
       BinaryOpConversion<AddOp, AddIOp>,
       ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
+      ConstSizeOpConversion,
       GetExtentOpConverter,
       RankOpConverter,
-      ShapeOfOpConversion>(ctx);
+      ShapeOfOpConversion,
+      ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3c71e3409923..02fe7b8129f7 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -753,7 +753,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
   // `IntegerAttr`s which makes constant folding simple.
   if (Attribute arg = operands[0])
     return arg;
-  return {};
+  return impl::foldCastOp(*this);
 }
 
 void SizeToIndexOp::getCanonicalizationPatterns(
@@ -812,7 +812,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
 
 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0])
-    return nullptr;
+    return impl::foldCastOp(*this);
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 7f875f3bb19f..b94b24599351 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -149,3 +149,28 @@ func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
   return %result : tensor<?xindex>
 }
 
+// -----
+
+// Lower `const_size` to `std.constant`
+// CHECK-LABEL: @const_size
+func @const_size() -> index {
+  // CHECK: %[[RES:.*]] = constant 42 : index
+  %size = shape.const_size 42
+  %result = shape.size_to_index %size : !shape.size
+  // CHECK: return %[[RES]]
+  return %result : index
+}
+
+// -----
+
+// Lower `to_extent_tensor` to `std.tensor_cast`
+// The cast-like `to_extent_tensor` is converted rather than folded here, since the operand is not yet of the target type.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
+func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
+  // CHECK-NOT: to_extent_tensor
+  // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
+  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
+  // CHECK: return %[[RES]]
+  return %casted : tensor<3xindex>
+}

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5fe2ac108a69..ed5bf6999cd9 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -774,3 +774,22 @@ func @fold_mul_mixed() -> !shape.size {
   return %result : !shape.size
 }
 
+// -----
+
+// Fold index_cast when already on index.
+// CHECK-LABEL: @fold_index_cast_on_index
+func @fold_index_cast_on_index(%arg: index) -> index {
+  // CHECK-NOT: size_to_index
+  %casted = shape.size_to_index %arg : index
+  return %casted : index
+}
+
+// -----
+
+// Fold to_extent_tensor when already on tensor.
+// CHECK-LABEL: @fold_to_extent_tensor_on_tensor
+func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK-NOT: to_extent_tensor
+  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+  return %casted : tensor<?xindex>
+}


        


More information about the Mlir-commits mailing list