[Mlir-commits] [mlir] 9ed1e58 - [mlir][shape] Start a pass that lowers shape constraints.

Sean Silva llvmlistbot at llvm.org
Thu Sep 24 12:25:47 PDT 2020


Author: Sean Silva
Date: 2020-09-24T12:25:30-07:00
New Revision: 9ed1e5873c19eb817fb9e36d0262c7effee5d35e

URL: https://github.com/llvm/llvm-project/commit/9ed1e5873c19eb817fb9e36d0262c7effee5d35e
DIFF: https://github.com/llvm/llvm-project/commit/9ed1e5873c19eb817fb9e36d0262c7effee5d35e.diff

LOG: [mlir][shape] Start a pass that lowers shape constraints.

This pass converts shape.cstr_* ops to eager (side-effecting)
error-handling code. After that conversion is done, the witnesses are
trivially satisfied and are replaced with `shape.const_witness true`.

Differential Revision: https://reviews.llvm.org/D87941
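
As a quick illustration (a sketch reconstructed from the test added in this
commit, not part of the commit message itself), `shape.cstr_require` lowers
roughly as follows:

  // Before:
  func @cstr_require(%arg0: i1) -> !shape.witness {
    %witness = shape.cstr_require %arg0, "msg"
    return %witness : !shape.witness
  }

  // After -convert-shape-constraints (SSA names are illustrative; ordering
  // matches the CHECK lines in the new test below):
  func @cstr_require(%arg0: i1) -> !shape.witness {
    %ret = shape.const_witness true
    assert %arg0, "msg"
    return %ret : !shape.witness
  }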

Added: 
    mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
    mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
    mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index dae59c9e792e..547b952b60b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -242,6 +242,21 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
   let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
 }
 
+def ConvertShapeConstraints : Pass<"convert-shape-constraints", "FuncOp"> {
+  let summary = "Convert shape constraint operations to the standard dialect";
+  let description = [{
+    This pass eliminates shape constraints from the program, converting them to
+    eager (side-effecting) error handling code.
+
+    This pass is separate from the regular convert-shape-to-standard, despite
+    converting between the same dialects, because converting shape constraints
+    can happen at a different part of the program than general shape
+    computation lowering.
+  }];
+  let constructor = "mlir::createConvertShapeConstraintsPass()";
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRVToLLVM
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
index 74e4ff758022..176f10183881 100644
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -13,6 +13,7 @@
 
 namespace mlir {
 
+class FuncOp;
 class MLIRContext;
 class ModuleOp;
 template <typename T>
@@ -24,6 +25,11 @@ void populateShapeToStandardConversionPatterns(
 
 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
 
+void populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_

diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index 8750c331859e..eaea3de6c869 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRShapeToStandard
+  ConvertShapeConstraints.cpp
   ShapeToStandard.cpp
 
   ADDITIONAL_HEADER_DIRS

diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
new file mode 100644
index 000000000000..8f8342a47ea8
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -0,0 +1,143 @@
+//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+class ConvertCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getType().isa<shape::ShapeType>() ||
+        op.lhs().getType().isa<shape::ShapeType>() ||
+        op.rhs().getType().isa<shape::ShapeType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert error-propagating shapes");
+    }
+
+    auto loc = op.getLoc();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+    // Find smaller and greater rank and extent tensor.
+    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+    Value lhsSmaller =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+    Type indexTy = rewriter.getIndexType();
+    Type extentTensorTy = op.lhs().getType();
+    auto ifOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+        lhsSmaller,
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
+        },
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
+        });
+    Value lesserRank = ifOp.getResult(0);
+    Value lesserRankOperand = ifOp.getResult(1);
+    Value greaterRank = ifOp.getResult(2);
+    Value greaterRankOperand = ifOp.getResult(3);
+
+    Value rankDiff =
+        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+
+    // Generate code to compare the shapes extent by extent, and emit errors for
+    // non-broadcast-compatible shapes.
+    // Two extents are broadcast-compatible if
+    // 1. they are equal, or
+    // 2. at least one of them is 1.
+
+    rewriter.create<scf::ForOp>(
+        loc, rankDiff, greaterRank, one, llvm::None,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+              loc, greaterRankOperand, ValueRange{iv});
+          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+              loc, lesserRankOperand, ValueRange{ivShifted});
+
+          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+          Value extentsAgree =
+              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                               lesserRankOperandExtent);
+          auto broadcastIsValid =
+              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
+                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                            lesserRankOperandExtentIsOne));
+          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
+          b.create<scf::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+  patterns.insert<ConvertCstrRequireOp>(ctx);
+}
+
+namespace {
+// This pass eliminates shape constraints from the program, converting them to
+// eager (side-effecting) error handling code. After eager error handling code
+// is emitted, witnesses are satisfied, so they are replaced with
+// `shape.const_witness true`.
+class ConvertShapeConstraints
+    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
+  void runOnOperation() {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    OwningRewritePatternList patterns;
+    populateConvertShapeConstraintsConversionPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertShapeConstraintsPass() {
+  return std::make_unique<ConvertShapeConstraints>();
+}

diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
new file mode 100644
index 000000000000..1f7b6d60dd4f
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s
+
+// There is not much useful to check here beyond matching the full pasted output.
+// CHECK-LABEL:   func @cstr_broadcastable(
+// CHECK-SAME:                             %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME:                             %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[RET:.*]] = shape.const_witness true
+// CHECK:           %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
+// CHECK:           %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+// CHECK:             scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK:           } else {
+// CHECK:             scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK:           }
+// CHECK:           %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
+// CHECK:           scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
+// CHECK:             %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
+// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
+// CHECK:             %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK:             %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK:             %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK:             %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
+// CHECK:             %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
+// CHECK:             %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
+// CHECK:             assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK:           }
+// CHECK:           return %[[RET]] : !shape.witness
+// CHECK:         }
+func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  return %witness : !shape.witness
+}
+
+// CHECK-LABEL: func @cstr_require
+func @cstr_require(%arg0: i1) -> !shape.witness {
+  // CHECK: %[[RET:.*]] = shape.const_witness true
+  // CHECK: assert %arg0, "msg"
+  // CHECK: return %[[RET]]
+  %witness = shape.cstr_require %arg0, "msg"
+  return %witness : !shape.witness
+}
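
For reference, the broadcastable lowering that the CHECK lines above spell out,
rewritten as plain IR rather than FileCheck directives (a reconstruction; the
SSA value names are illustrative, the actual output uses numbered values):

  func @cstr_broadcastable(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) -> !shape.witness {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %ret = shape.const_witness true
    %lhs_rank = dim %lhs, %c0 : tensor<?xindex>
    %rhs_rank = dim %rhs, %c0 : tensor<?xindex>
    %lhs_smaller = cmpi "ule", %lhs_rank, %rhs_rank : index
    %sorted:4 = scf.if %lhs_smaller -> (index, tensor<?xindex>, index, tensor<?xindex>) {
      scf.yield %lhs_rank, %lhs, %rhs_rank, %rhs : index, tensor<?xindex>, index, tensor<?xindex>
    } else {
      scf.yield %rhs_rank, %rhs, %lhs_rank, %lhs : index, tensor<?xindex>, index, tensor<?xindex>
    }
    %rank_diff = subi %sorted#2, %sorted#0 : index
    // Compare the shapes extent by extent; %sorted#3 is the greater-rank
    // operand, %sorted#1 the lesser-rank operand.
    scf.for %iv = %rank_diff to %sorted#2 step %c1 {
      %greater_extent = extract_element %sorted#3[%iv] : tensor<?xindex>
      %iv_shifted = subi %iv, %rank_diff : index
      %lesser_extent = extract_element %sorted#1[%iv_shifted] : tensor<?xindex>
      %greater_is_one = cmpi "eq", %greater_extent, %c1 : index
      %lesser_is_one = cmpi "eq", %lesser_extent, %c1 : index
      %extents_agree = cmpi "eq", %greater_extent, %lesser_extent : index
      %either_is_one = or %greater_is_one, %lesser_is_one : i1
      %valid = or %extents_agree, %either_is_one : i1
      assert %valid, "invalid broadcast"
    }
    return %ret : !shape.witness
  }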


        

