[Mlir-commits] [mlir] 67391a7 - [MLIR] Lower `shape.reduce` to `scf.for` only when argument is `tensor<?xindex>`
Frederik Gossen
llvmlistbot at llvm.org
Thu Jul 16 06:56:07 PDT 2020
Author: Frederik Gossen
Date: 2020-07-16T13:55:48Z
New Revision: 67391a7045486c5d82b763dc1c32dba6d99ee31a
URL: https://github.com/llvm/llvm-project/commit/67391a7045486c5d82b763dc1c32dba6d99ee31a
DIFF: https://github.com/llvm/llvm-project/commit/67391a7045486c5d82b763dc1c32dba6d99ee31a.diff
LOG: [MLIR] Lower `shape.reduce` to `scf.for` only when argument is `tensor<?xindex>`
To make it clear when shape error values cannot occur, the shape operations can
operate on extent tensors (`tensor<?xindex>`). This change updates the lowering
for `shape.reduce` accordingly.
Differential Revision: https://reviews.llvm.org/D83944
Added:
Modified:
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 1f1134757b3a..0caaacd75bed 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -17,45 +17,46 @@
using namespace mlir;
using namespace mlir::shape;
+using namespace mlir::scf;
namespace {
/// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(ReduceOp op,
- PatternRewriter &rewriter) const final;
+ LogicalResult
+ matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
-ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
- PatternRewriter &rewriter) const {
- auto loc = reduceOp.getLoc();
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // For now, this lowering is only defined on `tensor<?xindex>` operands.
+ if (!op.shape().getType().isa<RankedTensorType>())
+ return failure();
+
+ auto loc = op.getLoc();
+ shape::ReduceOp::Adaptor transformed(operands);
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- Value extentTensor = rewriter.create<ToExtentTensorOp>(
- loc,
- RankedTensorType::get({ShapedType::kDynamicSize},
- rewriter.getIndexType()),
- reduceOp.shape());
- Value size =
- rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
+ Type indexTy = rewriter.getIndexType();
+ Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
auto loop = rewriter.create<scf::ForOp>(
- loc, zero, size, one, reduceOp.initVals(),
- [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
- Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
- Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
+ loc, zero, rank, one, op.initVals(),
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
- SmallVector<Value, 2> mapped_values{iv, sizeExtent};
- mapped_values.append(args.begin(), args.end());
+ SmallVector<Value, 2> mappedValues{iv, extent};
+ mappedValues.append(args.begin(), args.end());
BlockAndValueMapping mapping;
- Block *reduceBody = reduceOp.getBody();
- mapping.map(reduceBody->getArguments(), mapped_values);
+ Block *reduceBody = op.getBody();
+ mapping.map(reduceBody->getArguments(), mappedValues);
for (auto &nested : reduceBody->without_terminator())
b.clone(nested, mapping);
@@ -65,7 +66,7 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
b.create<scf::YieldOp>(loc, mappedResults);
});
- rewriter.replaceOp(reduceOp, loop.getResults());
+ rewriter.replaceOp(op, loop.getResults());
return success();
}
@@ -138,8 +139,8 @@ void ConvertShapeToSCFPass::runOnFunction() {
// Setup target legality.
ConversionTarget target(getContext());
- target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
- target.addIllegalOp<ReduceOp, ShapeOfOp>();
+ target.addLegalDialect<SCFDialect, StandardOpsDialect>();
+ target.addLegalOp<ModuleOp, FuncOp>();
// Apply conversion.
if (failed(applyPartialConversion(getFunction(), target, patterns)))
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 9051054b3f18..6ba630aa4aa6 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -1,31 +1,26 @@
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
// CHECK-LABEL: @shape_reduce
-// CHECK-SAME: ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
-func @shape_reduce(%shape : !shape.shape) -> !shape.size {
- %init = shape.const_size 1
- %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
- ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
- %new_acc = shape.mul %acc, %dim
- shape.yield %new_acc : !shape.size
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+ %init = constant 1 : index
+ %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+ ^bb0(%index : index, %extent : index, %acc: index):
+ %new_acc = muli %acc, %extent : index
+ shape.yield %new_acc : index
}
- return %num_elements : !shape.size
+ return %num_elements : index
}
-// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
-// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
-// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
-
-// CHECK-NEXT: [[EXTENTS:%.*]] = shape.to_extent_tensor [[SHAPE]]
-// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
-
-// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
-// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
-// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
-// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
-// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
-// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT: %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT: scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
-// CHECK-NEXT: return [[RESULT]] : !shape.size
+// CHECK-NEXT: return %[[RESULT]] : index
// -----
More information about the Mlir-commits mailing list