[Mlir-commits] [mlir] 67391a7 - [MLIR] Lower `shape.reduce` to `scf.for` only when argument is `tensor<?xindex>`

Frederik Gossen llvmlistbot at llvm.org
Thu Jul 16 06:56:07 PDT 2020


Author: Frederik Gossen
Date: 2020-07-16T13:55:48Z
New Revision: 67391a7045486c5d82b763dc1c32dba6d99ee31a

URL: https://github.com/llvm/llvm-project/commit/67391a7045486c5d82b763dc1c32dba6d99ee31a
DIFF: https://github.com/llvm/llvm-project/commit/67391a7045486c5d82b763dc1c32dba6d99ee31a.diff

LOG: [MLIR] Lower `shape.reduce` to `scf.for` only when argument is `tensor<?xindex>`

To make it clear when shape error values cannot occur, the shape operations can
operate on extent tensors (`tensor<?xindex>`). This change updates the lowering
for `shape.reduce` accordingly: it now converts to `scf.for` only when the
argument is an extent tensor.

Differential Revision: https://reviews.llvm.org/D83944
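
For illustration, the sketch below is adapted from the updated test case in
this commit (see the CHECK lines further down). A `shape.reduce` over a
`tensor<?xindex>` becomes an `scf.for` from 0 to the rank of the shape, with
the accumulator threaded through `iter_args`:

    // Input: reduction over an extent tensor.
    %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
      ^bb0(%index : index, %extent : index, %acc : index):
        %new_acc = muli %acc, %extent : index
        shape.yield %new_acc : index
    }

    // After -convert-shape-to-scf (sketch of the emitted loop):
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %rank = dim %shape, %c0 : tensor<?xindex>
    %result = scf.for %i = %c0 to %rank step %c1
        iter_args(%acc = %init) -> (index) {
      %extent = extract_element %shape[%i] : tensor<?xindex>
      %new_acc = muli %acc, %extent : index
      scf.yield %new_acc : index
    }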

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 1f1134757b3a..0caaacd75bed 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -17,45 +17,46 @@
 
 using namespace mlir;
 using namespace mlir::shape;
+using namespace mlir::scf;
 
 namespace {
 /// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ReduceOp op,
-                                PatternRewriter &rewriter) const final;
+  LogicalResult
+  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
 };
 } // namespace
 
 LogicalResult
-ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
-                                   PatternRewriter &rewriter) const {
-  auto loc = reduceOp.getLoc();
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands.
+  if (!op.shape().getType().isa<RankedTensorType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  shape::ReduceOp::Adaptor transformed(operands);
 
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  Value extentTensor = rewriter.create<ToExtentTensorOp>(
-      loc,
-      RankedTensorType::get({ShapedType::kDynamicSize},
-                            rewriter.getIndexType()),
-      reduceOp.shape());
-  Value size =
-      rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
+  Type indexTy = rewriter.getIndexType();
+  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
 
   auto loop = rewriter.create<scf::ForOp>(
-      loc, zero, size, one, reduceOp.initVals(),
-      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-        Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
-        Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
+      loc, zero, rank, one, op.initVals(),
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
 
-        SmallVector<Value, 2> mapped_values{iv, sizeExtent};
-        mapped_values.append(args.begin(), args.end());
+        SmallVector<Value, 2> mappedValues{iv, extent};
+        mappedValues.append(args.begin(), args.end());
 
         BlockAndValueMapping mapping;
-        Block *reduceBody = reduceOp.getBody();
-        mapping.map(reduceBody->getArguments(), mapped_values);
+        Block *reduceBody = op.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
         for (auto &nested : reduceBody->without_terminator())
           b.clone(nested, mapping);
 
@@ -65,7 +66,7 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
         b.create<scf::YieldOp>(loc, mappedResults);
       });
 
-  rewriter.replaceOp(reduceOp, loop.getResults());
+  rewriter.replaceOp(op, loop.getResults());
   return success();
 }
 
@@ -138,8 +139,8 @@ void ConvertShapeToSCFPass::runOnFunction() {
 
   // Setup target legality.
   ConversionTarget target(getContext());
-  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
+  target.addLegalOp<ModuleOp, FuncOp>();
 
   // Apply conversion.
   if (failed(applyPartialConversion(getFunction(), target, patterns)))

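Note the new guard at the top of the pattern: if the operand type is not a
ranked tensor, i.e. it is still a `!shape.shape`, the pattern returns failure.
Because the pass applies a partial conversion and no longer marks
`shape.reduce` illegal, such ops are simply left in place rather than
rejected. A minimal sketch (reusing the op form from the removed test lines
below; not itself part of this commit):

    // Left unconverted by -convert-shape-to-scf after this change:
    %r = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
      ^bb0(%index : index, %dim : !shape.size, %acc : !shape.size):
        %new_acc = shape.mul %acc, %dim
        shape.yield %new_acc : !shape.size
    }
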
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 9051054b3f18..6ba630aa4aa6 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -1,31 +1,26 @@
 // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @shape_reduce
-// CHECK-SAME:  ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
-func @shape_reduce(%shape : !shape.shape) -> !shape.size {
-  %init = shape.const_size 1
-  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
-    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
-      %new_acc = shape.mul %acc, %dim
-      shape.yield %new_acc : !shape.size
+// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+  %init = constant 1 : index
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index : index, %extent : index, %acc: index):
+      %new_acc = muli %acc, %extent : index
+      shape.yield %new_acc : index
   }
-  return %num_elements : !shape.size
+  return %num_elements : index
 }
-// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
-// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
-// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
-
-// CHECK-NEXT: [[EXTENTS:%.*]] = shape.to_extent_tensor [[SHAPE]]
-// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
-
-// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
-// CHECK-SAME:       step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
-// CHECK-NEXT:   [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
-// CHECK-NEXT:   [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
-// CHECK-NEXT:   [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
-// CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
 // CHECK-NEXT: }
-// CHECK-NEXT: return [[RESULT]] : !shape.size
+// CHECK-NEXT: return %[[RESULT]] : index
 
 // -----
 
