[Mlir-commits] [mlir] 9ed1e58 - [mlir][shape] Start a pass that lowers shape constraints.
Sean Silva
llvmlistbot at llvm.org
Thu Sep 24 12:25:47 PDT 2020
Author: Sean Silva
Date: 2020-09-24T12:25:30-07:00
New Revision: 9ed1e5873c19eb817fb9e36d0262c7effee5d35e
URL: https://github.com/llvm/llvm-project/commit/9ed1e5873c19eb817fb9e36d0262c7effee5d35e
DIFF: https://github.com/llvm/llvm-project/commit/9ed1e5873c19eb817fb9e36d0262c7effee5d35e.diff
LOG: [mlir][shape] Start a pass that lowers shape constraints.
This pass converts shape.cstr_* ops to eager (side-effecting)
error-handling code. After that conversion is done, the witnesses are
trivially satisfied and are replaced with `shape.const_witness true`.
Differential Revision: https://reviews.llvm.org/D87941
Added:
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index dae59c9e792e..547b952b60b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -242,6 +242,21 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}
+def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
+  let summary = "Convert shape constraint operations to the standard dialect";
+  let description = [{
+    This pass eliminates shape constraints from the program, converting them to
+    eager (side-effecting) error handling code. Once that code is emitted the
+    constraints are satisfied, so their witnesses are replaced with
+    `shape.const_witness true`.
+
+    This pass is separate from the regular convert-shape-to-std, despite
+    converting between the same dialects, because converting shape constraints
+    can happen at a different part of the program than general shape
+    computation lowering.
+  }];
+  let constructor = "mlir::createConvertShapeConstraintsPass()";
+  // The lowering emits std ops (assert, cmpi, dim, ...) and scf ops (if, for).
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
+}
+
//===----------------------------------------------------------------------===//
// SPIRVToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
index 74e4ff758022..176f10183881 100644
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -13,6 +13,7 @@
namespace mlir {
+class FuncOp;
class MLIRContext;
class ModuleOp;
template <typename T>
@@ -24,6 +25,11 @@ void populateShapeToStandardConversionPatterns(
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
+/// Adds to `patterns` the rewrites that lower shape constraint ops
+/// (currently `shape.cstr_broadcastable` and `shape.cstr_require`) to eager
+/// error-handling code and replace their witnesses with
+/// `shape.const_witness true`.
+void populateConvertShapeConstraintsConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+/// Creates the function pass registered as `convert-shape-constraints`, which
+/// applies the patterns above.
+std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
+
} // namespace mlir
#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index 8750c331859e..eaea3de6c869 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRShapeToStandard
+ ConvertShapeConstraints.cpp
ShapeToStandard.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
new file mode 100644
index 000000000000..8f8342a47ea8
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -0,0 +1,143 @@
+//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers `shape.cstr_broadcastable` on extent tensors to eager checks.
+///
+/// The lower-rank operand is aligned with the trailing extents of the
+/// higher-rank operand. For every aligned pair of extents, the generated code
+/// asserts that they are equal or that at least one of them is 1. Once those
+/// assertions are emitted the constraint holds, so the witness is replaced
+/// with `shape.const_witness true`.
+///
+/// Values of `!shape.shape` type can carry an error value, which this
+/// lowering does not model; the pattern fails to match on them.
+class ConvertCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getType().isa<shape::ShapeType>() ||
+        op.lhs().getType().isa<shape::ShapeType>() ||
+        op.rhs().getType().isa<shape::ShapeType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert error-propagating shapes");
+    }
+
+    auto loc = op.getLoc();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+    Type indexTy = rewriter.getIndexType();
+
+    // The operands may have different extent tensor types (e.g.
+    // `tensor<3xindex>` and `tensor<?xindex>`), but both regions of the
+    // `scf.if` below must yield values of one shared result type. When the
+    // types differ, cast the operands to a dynamically sized extent tensor.
+    // When they already agree, nothing is inserted.
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+    Type extentTensorTy = lhs.getType();
+    if (lhs.getType() != rhs.getType()) {
+      extentTensorTy =
+          RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+      if (lhs.getType() != extentTensorTy)
+        lhs = rewriter.create<TensorCastOp>(loc, lhs, extentTensorTy);
+      if (rhs.getType() != extentTensorTy)
+        rhs = rewriter.create<TensorCastOp>(loc, rhs, extentTensorTy);
+    }
+
+    // Find smaller and greater rank and extent tensor.
+    Value lhsRank = rewriter.create<DimOp>(loc, lhs, zero);
+    Value rhsRank = rewriter.create<DimOp>(loc, rhs, zero);
+    Value lhsSmaller =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+    auto ifOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+        lhsSmaller,
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc,
+                                 ValueRange{lhsRank, lhs, rhsRank, rhs});
+        },
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc,
+                                 ValueRange{rhsRank, rhs, lhsRank, lhs});
+        });
+    Value lesserRank = ifOp.getResult(0);
+    Value lesserRankOperand = ifOp.getResult(1);
+    Value greaterRank = ifOp.getResult(2);
+    Value greaterRankOperand = ifOp.getResult(3);
+
+    // First index of the greater-rank operand that lines up with an extent of
+    // the lesser-rank operand. The leading extents before it have no
+    // counterpart, so the loop below does not visit them.
+    Value rankDiff =
+        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+
+    // Generate code to compare the shapes extent by extent, and emit errors for
+    // non-broadcast-compatible shapes.
+    // Two extents are broadcast-compatible if
+    // 1. they are both equal, or
+    // 2. at least one of them is 1.
+
+    rewriter.create<scf::ForOp>(
+        loc, rankDiff, greaterRank, one, llvm::None,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+              loc, greaterRankOperand, ValueRange{iv});
+          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+              loc, lesserRankOperand, ValueRange{ivShifted});
+
+          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+          Value extentsAgree =
+              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                               lesserRankOperandExtent);
+          auto broadcastIsValid =
+              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
+                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                            lesserRankOperandExtentIsOne));
+          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
+          b.create<scf::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Lowers `shape.cstr_require` to an `assert` carrying the same predicate and
+/// message. It then replaces the now-satisfied witness with
+/// `shape.const_witness true`.
+class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
+                                PatternRewriter &rewriter) const override {
+    Location requireLoc = op.getLoc();
+    Value condition = op.pred();
+    rewriter.create<AssertOp>(requireLoc, condition, op.msgAttr());
+    // The assert enforces the requirement eagerly, so the witness now holds.
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the rewrites for `shape.cstr_broadcastable` and
+/// `shape.cstr_require`, in that order.
+void mlir::populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ConvertCstrBroadcastableOp, ConvertCstrRequireOp>(ctx);
+}
+
+namespace {
+/// Function pass that eliminates shape constraints from the program by
+/// converting them to eager (side-effecting) error handling code. After that
+/// code is emitted the witnesses are satisfied, so they are replaced with
+/// `shape.const_witness true`.
+class ConvertShapeConstraints
+    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    OwningRewritePatternList patterns;
+    populateConvertShapeConstraintsConversionPatterns(patterns, context);
+
+    // Report a failed greedy rewrite as a failure of the whole pass.
+    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory used by the `convert-shape-constraints` pass registration.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertShapeConstraintsPass() {
+  std::unique_ptr<OperationPass<FuncOp>> pass =
+      std::make_unique<ConvertShapeConstraints>();
+  return pass;
+}
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
new file mode 100644
index 000000000000..1f7b6d60dd4f
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s
+
+// There is not much to verify here beyond the full lowered output, so the expected IR is matched in its entirety.
+// CHECK-LABEL: func @cstr_broadcastable(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[RET:.*]] = shape.const_witness true
+// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
+// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: } else {
+// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: }
+// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
+// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
+// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
+// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
+// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
+// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
+// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK: }
+// CHECK: return %[[RET]] : !shape.witness
+// CHECK: }
+// Both operands are dynamically sized extent tensors of the same type.
+func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+ %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+ return %witness : !shape.witness
+}
+
+// CHECK-LABEL: func @cstr_require
+func @cstr_require(%arg0: i1) -> !shape.witness {
+ // The witness constant is expected ahead of the assert in the output even
+ // though the pattern creates the assert first, presumably because the greedy
+ // driver hoists constant-like ops (TODO confirm).
+ // CHECK: %[[RET:.*]] = shape.const_witness true
+ // CHECK: assert %arg0, "msg"
+ // CHECK: return %[[RET]]
+ %witness = shape.cstr_require %arg0, "msg"
+ return %witness : !shape.witness
+}
More information about the Mlir-commits
mailing list