[llvm-branch-commits] [mlir] 0386b6a - [MLIR] Lower shape.num_elements -> shape.reduce.

Alexander Belyaev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Jun 7 07:35:30 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-07T16:25:14+02:00
New Revision: 0386b6ac330bd591f1f04c24c5416cc24b0c2c0a

URL: https://github.com/llvm/llvm-project/commit/0386b6ac330bd591f1f04c24c5416cc24b0c2c0a
DIFF: https://github.com/llvm/llvm-project/commit/0386b6ac330bd591f1f04c24c5416cc24b0c2c0a.diff

LOG: [MLIR] Lower shape.num_elements -> shape.reduce.

Differential Revision: https://reviews.llvm.org/D81279

Added: 
    mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
    mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Shape/Transforms/PassDetail.h
    mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
    mlir/test/Dialect/Shape/shape-to-shape.mlir

Modified: 
    mlir/docs/Passes.md
    mlir/include/mlir/Dialect/Shape/CMakeLists.txt
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/Shape/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index 0f48396b220e..451b76370c7e 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -36,6 +36,10 @@ This document describes the available MLIR passes and their contracts.
 
 [include "QuantPasses.md"]
 
+## `shape` Dialect Passes
+
+[include "ShapePasses.md"]
+
 ## `spv` Dialect Passes
 
 [include "SPIRVPasses.md"]

diff  --git a/mlir/include/mlir/Dialect/Shape/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt
index f33061b2d87c..9f57627c321f 100644
--- a/mlir/include/mlir/Dialect/Shape/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ac5bedf3d6e3..e9cd539cef05 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -327,7 +327,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
 
   let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
   let results = (outs Variadic<AnyType>:$result);
-  let regions = (region SizedRegion<1>:$body);
+  let regions = (region SizedRegion<1>:$region);
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..629b8c0db294
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Passes.h.inc provides the pass base classes (GEN_PASS_CLASSES) and the
+# registration hooks (GEN_PASS_REGISTRATION) for the passes in Passes.td.
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls)
+add_public_tablegen_target(MLIRShapeTransformsIncGen)
+
+# Emits ShapePasses.md, which docs/Passes.md pulls in under the `shape`
+# Dialect Passes section.
+add_mlir_doc(Passes -gen-pass-doc ShapePasses ./)

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
new file mode 100644
index 000000000000..29cf9d1b6715
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -0,0 +1,30 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors in the
+// shape transformation library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
+/// dialect to be convertible to Standard. For example, `shape.num_elements`
+/// is rewritten into `shape.reduce`, which can be lowered to SCF and Standard.
+std::unique_ptr<Pass> createShapeToShapeLowering();
+
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
new file mode 100644
index 000000000000..46dc4dc37160
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
+#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
+  let summary = "Legalize Shape dialect to be convertible to Standard";
+  let description = [{
+    Rewrites selected Shape dialect operations into other Shape dialect
+    operations that can be lowered further to SCF and Standard. Currently,
+    `shape.num_elements` is expanded into a `shape.reduce` over the shape that
+    starts from `shape.const_size 1` and multiplies each extent into the
+    accumulator with `shape.mul`.
+  }];
+  let constructor = "mlir::createShapeToShapeLowering()";
+}
+
+#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index c0c79d1c69ef..957153770351 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -37,6 +37,7 @@
 #include "mlir/Dialect/Quant/Passes.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
@@ -94,6 +95,10 @@ inline void registerAllPasses() {
   // Standard
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
+
+  // Shape
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
 }
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index 0a03849722cb..1f4f653fe514 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -18,3 +18,5 @@ add_mlir_dialect_library(MLIRShape
   MLIRIR
   MLIRSideEffectInterfaces
   )
+
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 04b1a51e986e..2688a459eded 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -504,7 +504,7 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
 
 static LogicalResult verify(ReduceOp op) {
   // Verify block arg types.
-  Block &block = op.body().front();
+  Block &block = op.region().front();
 
   auto blockArgsCount = op.initVals().size() + 2;
   if (block.getNumArguments() != blockArgsCount)
@@ -560,7 +560,7 @@ static void print(OpAsmPrinter &p, ReduceOp op) {
   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
     << ") ";
   p.printOptionalArrowTypeList(op.getResultTypes());
-  p.printRegion(op.body());
+  p.printRegion(op.region());
   p.printOptionalAttrDict(op.getAttrs());
 }
 

diff  --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..3c0ec3211e69
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRShapeOpsTransforms
+  ShapeToShapeLowering.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape/Transforms
+
+  DEPENDS
+  MLIRShapeTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRShape
+  MLIRSupport
+  )

diff  --git a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h
new file mode 100644
index 000000000000..abb5c21d66bb
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h
@@ -0,0 +1,24 @@
+//===- PassDetail.h - Shape Pass class details ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+// Pull in the pass base classes generated from Passes.td (e.g.
+// ShapeToShapeLoweringBase). This header lives under lib/ rather than
+// include/, so it is meant to be included only by this library's sources.
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+
+} // end namespace mlir
+
+#endif // DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_

diff  --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
new file mode 100644
index 000000000000..1ba68a0a94ee
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -0,0 +1,73 @@
+//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+
+/// Rewrites `shape.num_elements` as a `shape.reduce` over the extents of its
+/// shape operand. The accumulator starts at `shape.const_size 1` and each
+/// iteration yields `shape.mul` of the current extent and the accumulator, so
+/// the reduction computes the product of all extents.
+struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NumElementsOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value one = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+    ReduceOp reduceOp = rewriter.create<ReduceOp>(loc, op.shape(), one);
+
+    // Fill in the reduction body. Block arguments are (index, extent,
+    // accumulator); the body multiplies the latter two and yields the result.
+    Block *reductionBody = reduceOp.getBody();
+    OpBuilder bodyBuilder = OpBuilder::atBlockEnd(reductionBody);
+    Value extent = reductionBody->getArgument(1);
+    Value accumulator = reductionBody->getArgument(2);
+    Value product = bodyBuilder.create<MulOp>(loc, extent, accumulator);
+    bodyBuilder.create<YieldOp>(loc, product);
+
+    rewriter.replaceOp(op, reduceOp.result());
+    return success();
+  }
+};
+
+/// Function pass that applies NumElementsOpConverter as a partial conversion:
+/// every Shape dialect op is legal except `shape.num_elements`.
+struct ShapeToShapeLowering
+    : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
+  void runOnFunction() override {
+    MLIRContext &context = getContext();
+
+    OwningRewritePatternList patterns;
+    patterns.insert<NumElementsOpConverter>(&context);
+
+    ConversionTarget target(context);
+    target.addLegalDialect<ShapeDialect>();
+    target.addIllegalOp<NumElementsOp>();
+
+    // A `shape.num_elements` that cannot be converted makes the pass fail.
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
+  return std::make_unique<ShapeToShapeLowering>();
+}

diff  --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
new file mode 100644
index 000000000000..d2338cddc5e1
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s --dump-input-on-failure
+
+// `shape.num_elements` is expanded into a `shape.reduce` seeded with
+// `shape.const_size 1` whose body multiplies each extent into the accumulator.
+// CHECK-LABEL: func @num_elements_to_reduce(
+// CHECK-SAME:    [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
+func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
+  %num_elements = shape.num_elements %shape
+  return %num_elements : !shape.size
+}
+// CHECK: [[C1:%.*]] = shape.const_size 1
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]])  -> [[SIZE_TY]]
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
+// CHECK:   [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK:   shape.yield [[NEW_ACC]] : [[SIZE_TY]]
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
+


        


More information about the llvm-branch-commits mailing list