[Mlir-commits] [mlir] 6673c6c - [MLIR][Shape] Limit shape-to-standard lowerings to their supported types

Frederik Gossen llvmlistbot at llvm.org
Wed Jul 29 06:57:16 PDT 2020


Author: Frederik Gossen
Date: 2020-07-29T13:56:52Z
New Revision: 6673c6cd82f79b76c1676ab1ab30a288286acb71

URL: https://github.com/llvm/llvm-project/commit/6673c6cd82f79b76c1676ab1ab30a288286acb71
DIFF: https://github.com/llvm/llvm-project/commit/6673c6cd82f79b76c1676ab1ab30a288286acb71.diff

LOG: [MLIR][Shape] Limit shape-to-standard lowerings to their supported types

The lowering patterns do not support all types of their source operations.
This change makes them fail in a well-defined manner: a pattern returns
failure() when it encounters an unsupported type, and the pass now applies a
partial instead of a full conversion so that the rejected operations are left
intact.

Differential Revision: https://reviews.llvm.org/D84443
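
For reference, the bail-out idiom this patch applies throughout is the
standard dialect-conversion one: a pattern that cannot handle an op returns
failure() from matchAndRewrite, which tells the conversion driver that the
pattern does not apply to that op. A minimal sketch of the idiom as a
concrete, hypothetical pattern (standalone illustration, not code from this
patch; it mirrors the templated BinaryOpConversion in the first hunk below):

    // Sketch only: lower shape.add to std.addi, but bail out on the
    // error-carrying !shape.size result type.
    struct AddOpConversion : public OpConversionPattern<shape::AddOp> {
      using OpConversionPattern<shape::AddOp>::OpConversionPattern;

      LogicalResult
      matchAndRewrite(shape::AddOp op, ArrayRef<Value> operands,
                      ConversionPatternRewriter &rewriter) const override {
        // Unsupported result type: report a well-defined mismatch.
        if (op.getType().isa<shape::SizeType>())
          return failure();
        shape::AddOp::Adaptor transformed(operands);
        rewriter.replaceOpWithNewOp<AddIOp>(op, transformed.lhs(),
                                            transformed.rhs());
        return success();
      }
    };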

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 4deaa8cd2df3..41d4d90b33d3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -49,8 +49,14 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
   LogicalResult
   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    typename SrcOpTy::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
+    typename SrcOpTy::Adaptor transformed(operands);
+
+    // For now, only error-free types are supported by this lowering.
+    if (op.getType().template isa<SizeType>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
+                                         transformed.rhs());
     return success();
   }
 };
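
A side note on the guard in the hunk above: because SrcOpTy is a template
parameter, op.getType() is a type-dependent expression, so C++ requires the
`template` disambiguator before the member template isa<>. A minimal sketch
of the same constraint (hypothetical helper, not code from this patch):

    // Sketch only: calling a member template on a dependent value needs
    // the `template` keyword; without it this does not parse.
    template <typename OpTy>
    static bool hasErrorCarryingResult(OpTy op) {
      return op.getType().template isa<shape::SizeType>();
    }
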
@@ -85,27 +91,31 @@ class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
 LogicalResult ShapeOfOpConversion::matchAndRewrite(
     ShapeOfOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  auto loc = op.getLoc();
-  auto tensorVal = transformed.arg();
-  auto tensorTy = tensorVal.getType();
+
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<ShapeType>())
+    return failure();
 
   // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value tensorVal = transformed.arg();
+  Type tensorTy = tensorVal.getType();
   if (tensorTy.isa<UnrankedTensorType>())
     return failure();
 
   // Build values for individual dimensions.
   SmallVector<Value, 8> dimValues;
-  auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+  RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
   int64_t rank = rankedTensorTy.getRank();
+  auto loc = op.getLoc();
   for (int64_t i = 0; i < rank; i++) {
     if (rankedTensorTy.isDynamicDim(i)) {
-      auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+      Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
       dimValues.push_back(dimVal);
     } else {
       int64_t dim = rankedTensorTy.getDimSize(i);
-      auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+      Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
       dimValues.push_back(dimVal);
     }
   }
@@ -187,11 +197,18 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   GetExtentOp::Adaptor transformed(operands);
 
-  // Derive shape extent directly from shape origin if possible.
-  // This circumvents the necessity to materialize the shape in memory.
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<SizeType>())
+    return failure();
+
+  // Derive shape extent directly from shape origin if possible. This
+  // circumvents the necessity to materialize the shape in memory.
   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
-    rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
-    return success();
+    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
+      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                         transformed.dim());
+      return success();
+    }
   }
 
   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
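
The new isa<ShapedType>() check in this hunk matters because the operand of
shape.shape_of is not necessarily a tensor, and std.dim is only well-formed
on shaped values; only that case may take the direct-dim fast path. A
minimal sketch of the resulting decision (hypothetical helper, not code from
this patch; it condenses the logic of the hunk above):

    // Sketch only: read the extent off the original tensor when possible,
    // otherwise extract it from the materialized shape.
    static Value lowerExtent(GetExtentOp op, Value shape, Value dim,
                             ConversionPatternRewriter &rewriter) {
      if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>())
        if (shapeOfOp.arg().getType().isa<ShapedType>())
          return rewriter.create<DimOp>(op.getLoc(), shapeOfOp.arg(), dim);
      return rewriter.create<ExtractElementOp>(
          op.getLoc(), rewriter.getIndexType(), shape, ValueRange{dim});
    }
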
@@ -241,7 +258,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
 
   // Apply conversion.
   auto module = getOperation();
-  if (failed(applyFullConversion(module, target, patterns)))
+  if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
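
The driver change in this last hunk is what makes the new failure() guards
workable end to end: applyFullConversion treats any op that remains illegal
after pattern application as a pass failure, whereas applyPartialConversion
leaves ops that no pattern converts untouched. A minimal sketch of the
distinction (hypothetical pass body; names outside this patch are
assumptions based on the dialect-conversion API of this commit's era):

    // Sketch only: same target and patterns, different tolerance for
    // leftover ops.
    void runSketch(ModuleOp module, MLIRContext *ctx) {
      ConversionTarget target(*ctx);
      target.addLegalDialect<StandardOpsDialect>();

      OwningRewritePatternList patterns;
      populateShapeToStandardConversionPatterns(patterns, ctx);

      // Full conversion would abort on the shape ops the patterns reject:
      //   (void)applyFullConversion(module, target, patterns);
      // Partial conversion lets those ops survive the pass unchanged.
      (void)applyPartialConversion(module, target, patterns);
    }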
 

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 0e30cc2bdf56..3f19de9c52f0 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -13,6 +13,30 @@ func @binary_ops(%lhs : index, %rhs : index) {
 
 // -----
 
+// Don't lower binary ops when they operate on `shape.size`.
+// CHECK-LABEL: @binary_ops_on_size
+// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
+func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
+  // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+  // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+  %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  return
+}
+
+// -----
+
+// Don't lower `shape_of` when its result is the error type `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of(%arg : tensor<1x2x3xf32>) {
+  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+  return
+}
+
+// -----
+
 // Lower `shape_of` for statically shaped tensor.
 // CHECK-LABEL: @shape_of_stat
 // CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
@@ -55,6 +79,17 @@ func @rank(%shape : tensor<?xindex>) -> index {
 
 // -----
 
+// Don't lower `get_extent` when its result is of type `shape.size`.
+// CHECK-LABEL: @get_extent
+func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
+  // CHECK: shape.get_extent
+  %result = shape.get_extent %shape, %idx
+      : tensor<?xindex>, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
 // Express `get_extent` as `std.dim` when it relies directly on the outcome of a
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of

