[Mlir-commits] [mlir] 3813f24 - [mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Jun 15 08:55:08 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-15T17:54:50+02:00
New Revision: 3813f24e971bb406efa6436c48fdeb8e24a2654b
URL: https://github.com/llvm/llvm-project/commit/3813f24e971bb406efa6436c48fdeb8e24a2654b
DIFF: https://github.com/llvm/llvm-project/commit/3813f24e971bb406efa6436c48fdeb8e24a2654b.diff
LOG: [mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.
Differential Revision: https://reviews.llvm.org/D81694
Added:
mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d74e419dbb1a..48149ced5403 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -206,6 +206,15 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let constructor = "mlir::createConvertShapeToStandardPass()";
}
+//===----------------------------------------------------------------------===//
+// ShapeToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
+ let summary = "Convert operations from the shape dialect to the SCF dialect";
+ let constructor = "mlir::createConvertShapeToSCFPass()"; // declared in Conversion/ShapeToSCF/ShapeToSCF.h
+}
+
//===----------------------------------------------------------------------===//
// SPIRVToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
new file mode 100644
index 000000000000..f953f6e2ddf1
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
@@ -0,0 +1,27 @@
+//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+
+class MLIRContext;
+class FunctionPass;
+class OwningRewritePatternList;
+/// Adds the pattern that rewrites `shape.reduce` as `scf.for` to `patterns`.
+void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+/// Creates the function pass registered as `convert-shape-to-scf`.
+std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index ca5c8e0dac46..a2810f3b270b 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -26,6 +26,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 698dbd269b8e..e63b44cff782 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToStandard)
+add_subdirectory(ShapeToSCF)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
new file mode 100644
index 000000000000..60dd2b8514da
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRShapeToSCF
+ ShapeToSCF.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF # presumably the directory holding the public ShapeToSCF.h
+
+ DEPENDS
+ MLIRConversionPassIncGen # NOTE(review): presumably generates the pass base class reached via PassDetail.h -- confirm
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRShape # source dialect of the conversion
+ MLIRPass
+ MLIRSCF # target dialect of the conversion
+ MLIRTransforms # NOTE(review): presumably provides the DialectConversion framework -- confirm
+ )
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
new file mode 100644
index 000000000000..db7796d5c6a0
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -0,0 +1,99 @@
+//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Rewrite pattern that replaces a `shape.reduce` op with an `scf.for` loop
+/// over the extents of the reduced shape.
+struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp reduceOp,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Lowers `shape.reduce` to an `scf.for` that visits each extent of the shape,
+/// carrying the reduction's init values through the loop as iter_args.
+LogicalResult
+ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
+                                   PatternRewriter &rewriter) const {
+  auto loc = reduceOp.getLoc();
+  Type indexTy = rewriter.getIndexType();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value extentTensor = rewriter.create<ToExtentTensorOp>(
+      loc, RankedTensorType::get({ShapedType::kDynamicSize}, indexTy),
+      reduceOp.shape());
+  Value size = rewriter.create<DimOp>(loc, indexTy, extentTensor, zero);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, size, one, reduceOp.initVals(),
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        // Values for the body's block args: index, extent, accumulators.
+        Value indexExtent =
+            b.create<ExtractElementOp>(nestedLoc, extentTensor, iv);
+        Value sizeExtent = b.create<IndexToSizeOp>(nestedLoc, indexExtent);
+        SmallVector<Value, 2> mappedValues{iv, sizeExtent};
+        mappedValues.append(args.begin(), args.end());
+        // Clone the body into the loop with its block arguments remapped.
+        BlockAndValueMapping mapping;
+        Block *reduceBody = reduceOp.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
+        for (auto &nested : reduceBody->without_terminator())
+          b.clone(nested, mapping);
+        // Yield the remapped operands of the original terminator.
+        SmallVector<Value, 2> mappedResults;
+        for (auto result : reduceBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookup(result));
+        b.create<scf::YieldOp>(nestedLoc, mappedResults);
+      });
+
+  rewriter.replaceOp(reduceOp, loop.getResults());
+  return success();
+}
+
+namespace {
+/// Function pass that rewrites `shape.reduce` ops as `scf.for` loops.
+struct ConvertShapeToSCFPass
+    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void ConvertShapeToSCFPass::runOnFunction() {
+  // Mark only `shape.reduce` illegal; Shape, SCF and Standard ops stay legal.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
+  target.addIllegalOp<ReduceOp>();
+
+  OwningRewritePatternList patterns;
+  populateShapeToSCFConversionPatterns(patterns, &getContext());
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+void mlir::populateShapeToSCFConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ patterns.insert<ReduceOpConverter>(ctx); // the `shape.reduce` -> `scf.for` pattern
+}
+
+std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
+ return std::make_unique<ConvertShapeToSCFPass>(); // fresh pass instance owned by the caller
+}
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
new file mode 100644
index 000000000000..d0e6d0196dbf
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: shape_reduce
+// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+func @shape_reduce(%shape : !shape.shape) -> !shape.size {
+ %init = shape.const_size 1 // initial accumulator value
+ %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+ %new_acc = shape.mul %acc, %dim // multiply the accumulator by the current extent
+ shape.yield %new_acc : !shape.size
+ }
+ return %num_elements : !shape.size
+}
+// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
+// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
+
+// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
+// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
+
+// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
+// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
+// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
+// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
+// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
+// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: }
+// CHECK-NEXT: return [[RESULT]] : !shape.size
More information about the Mlir-commits
mailing list