[Mlir-commits] [mlir] 3813f24 - [mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.

Alexander Belyaev llvmlistbot at llvm.org
Mon Jun 15 08:55:08 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-15T17:54:50+02:00
New Revision: 3813f24e971bb406efa6436c48fdeb8e24a2654b

URL: https://github.com/llvm/llvm-project/commit/3813f24e971bb406efa6436c48fdeb8e24a2654b
DIFF: https://github.com/llvm/llvm-project/commit/3813f24e971bb406efa6436c48fdeb8e24a2654b.diff

LOG: [mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.

Differential Revision: https://reviews.llvm.org/D81694

Added: 
    mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
    mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d74e419dbb1a..48149ced5403 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -206,6 +206,15 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
   let constructor = "mlir::createConvertShapeToStandardPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
+  let summary = "Convert operations from the shape dialect to the SCF dialect";
+  let description = [{
+    Lowers `shape.reduce` into an `scf.for` loop over the extents of the
+    reduced shape. The shape is first materialized as an extent tensor
+    (`shape.to_extent_tensor`); inside the loop each extent is read with
+    `extract_element`, converted back to `!shape.size`, and the body of the
+    reduction is cloned with the loop-carried values as accumulators.
+    This is a partial conversion: all other shape dialect operations are
+    left untouched.
+  }];
+  let constructor = "mlir::createConvertShapeToSCFPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRVToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
new file mode 100644
index 000000000000..f953f6e2ddf1
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
@@ -0,0 +1,27 @@
+//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+
+class MLIRContext;
+class FunctionPass;
+class OwningRewritePatternList;
+
+/// Appends the patterns that lower shape dialect operations to the SCF
+/// dialect to `patterns`. Currently this is a single pattern that rewrites
+/// `shape.reduce` as an `scf.for` loop over the shape's extents.
+void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
+                                          MLIRContext *ctx);
+
+/// Creates a function pass that applies the patterns from
+/// populateShapeToSCFConversionPatterns via partial dialect conversion.
+std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index ca5c8e0dac46..a2810f3b270b 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 698dbd269b8e..e63b44cff782 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToStandard)
+add_subdirectory(ShapeToSCF)
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)

diff  --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
new file mode 100644
index 000000000000..60dd2b8514da
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRShapeToSCF
+  ShapeToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # The lowering creates standard-dialect ops (constant, dim, extract_element)
+  # and marks StandardOpsDialect legal, so link it explicitly.
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRShape
+  MLIRPass
+  MLIRSCF
+  MLIRStandardOps
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
new file mode 100644
index 000000000000..db7796d5c6a0
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -0,0 +1,99 @@
+//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Rewrites `shape.reduce` as an `scf.for` loop over the extents of the
+/// reduced shape. The reduction body is cloned into the loop body and the
+/// reduction's init values become the loop-carried values.
+struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp reduceOp,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
+                                   PatternRewriter &rewriter) const {
+  auto loc = reduceOp.getLoc();
+
+  // Materialize the shape as a 1-D extent tensor (tensor<?xindex>) so the
+  // loop can index into it; the loop trip count is its only dimension.
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value extentTensor = rewriter.create<ToExtentTensorOp>(
+      loc,
+      RankedTensorType::get({ShapedType::kDynamicSize},
+                            rewriter.getIndexType()),
+      reduceOp.shape());
+  Value size =
+      rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
+
+  // Iterate over [0, size) carrying the reduction's init values.
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, size, one, reduceOp.initVals(),
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        Value indexExtent =
+            b.create<ExtractElementOp>(nestedLoc, extentTensor, iv);
+        Value sizeExtent = b.create<IndexToSizeOp>(nestedLoc, indexExtent);
+
+        // The reduce body's block arguments are (index, extent, accumulators
+        // ...); map them onto the induction variable, the current extent and
+        // the loop-carried values.
+        SmallVector<Value, 2> mappedValues{iv, sizeExtent};
+        mappedValues.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *reduceBody = reduceOp.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
+        for (auto &nested : reduceBody->without_terminator())
+          b.clone(nested, mapping);
+
+        // Forward the yielded values as the next loop-carried values. A
+        // yielded value defined outside the reduce region has no entry in
+        // `mapping`; `lookup` would assert on it, whereas lookupOrDefault
+        // returns the value itself.
+        SmallVector<Value, 2> mappedResults;
+        for (auto result : reduceBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookupOrDefault(result));
+        b.create<scf::YieldOp>(nestedLoc, mappedResults);
+      });
+
+  rewriter.replaceOp(reduceOp, loop.getResults());
+  return success();
+}
+
+namespace {
+/// Function pass driving the Shape-to-SCF lowering; the conversion logic
+/// lives in runOnFunction.
+class ConvertShapeToSCFPass
+    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
+public:
+  void runOnFunction() override;
+};
+} // namespace
+
+/// Applies the Shape-to-SCF patterns to the current function as a partial
+/// conversion, signalling pass failure if an illegal op cannot be legalized.
+void ConvertShapeToSCFPass::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  populateShapeToSCFConversionPatterns(patterns, &ctx);
+
+  // Only `shape.reduce` must be rewritten; the rest of the shape dialect and
+  // the SCF/standard ops produced by the pattern remain legal.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
+  target.addIllegalOp<ReduceOp>();
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+/// Registers the Shape-to-SCF rewrite patterns (currently only the
+/// `shape.reduce` -> `scf.for` lowering) into `patterns`.
+void mlir::populateShapeToSCFConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ReduceOpConverter>(ctx);
+}
+
+/// Factory referenced as the constructor of the `convert-shape-to-scf` pass
+/// definition.
+std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
+  return std::make_unique<ConvertShapeToSCFPass>();
+}

diff  --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
new file mode 100644
index 000000000000..d0e6d0196dbf
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
+
+// Lowering of `shape.reduce`: the shape becomes an extent tensor, the
+// reduction becomes an scf.for over its extents carrying the accumulator, and
+// each extent is converted back to !shape.size before the cloned body runs.
+// CHECK-LABEL: shape_reduce
+// CHECK-SAME:   [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+func @shape_reduce(%shape : !shape.shape) -> !shape.size {
+  %init = shape.const_size 1
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = shape.mul %acc, %dim
+      shape.yield %new_acc : !shape.size
+  }
+  return %num_elements : !shape.size
+}
+// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
+// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT: [[C1:%.*]] = constant 1 : index

+// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
+// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>

+// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
+// CHECK-SAME:       step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
+// CHECK-NEXT:   [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
+// CHECK-NEXT:   [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
+// CHECK-NEXT:   [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
+// CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: }
+// CHECK-NEXT: return [[RESULT]] : !shape.size


        


More information about the Mlir-commits mailing list