[Mlir-commits] [mlir] 24acade - [mlir][Shape] Make shape_eq nary

Benjamin Kramer llvmlistbot at llvm.org
Wed Mar 3 07:27:54 PST 2021


Author: Benjamin Kramer
Date: 2021-03-03T16:26:40+01:00
New Revision: 24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc

URL: https://github.com/llvm/llvm-project/commit/24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc
DIFF: https://github.com/llvm/llvm-project/commit/24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc.diff

LOG: [mlir][Shape] Make shape_eq n-ary (variadic)

This removes a dubious `shape_eq %a, %a` fold, which folded shape_eq to true
even when %a is not a constant Attribute.

Differential Revision: https://reviews.llvm.org/D97728

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 0a6122801835..c651b84429b8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -168,20 +168,38 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
+                                            InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
-    Takes two shape or extent tensor operands and determines whether they are
-    equal. When extent tensors are compared to shapes they are regarded as their
-    equivalent non-error shapes. Error shapes can be tested for equality like
-    any other shape value, meaning that the error value is equal to itself.
+    Takes one or more shape or extent tensor operands and determines whether
+    they are equal. When extent tensors are compared to shapes they are regarded
+    as their equivalent non-error shapes. Error shapes can be tested for
+    equality like any other shape value, meaning that the error value is equal
+    to itself.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  // Convenience builder alias for the binary version.
+  let builders = [
+  OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+    [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return success();
+    };
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5f4396d73d88..2b5d619bf58e 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -474,46 +474,56 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
 LogicalResult
 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
+
+  Type i1Ty = rewriter.getI1Type();
+  if (op.shapes().size() <= 1) {
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
+                                            rewriter.getBoolAttr(true));
+    return success();
   }
 
   ShapeEqOp::Adaptor transformed(operands);
   auto loc = op.getLoc();
   Type indexTy = rewriter.getIndexType();
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
+  Value firstShape = transformed.shapes().front();
+  Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
+  Value result = nullptr;
+  // Generate a linear sequence of compares, all with firstShape as lhs.
+  for (Value shape : transformed.shapes().drop_front(1)) {
+    Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
+    Value eqRank =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
+    auto same = rewriter.create<IfOp>(
+        loc, i1Ty, eqRank,
+        [&](OpBuilder &b, Location loc) {
+          Value one = b.create<ConstantIndexOp>(loc, 1);
+          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+          auto loop = b.create<scf::ForOp>(
+              loc, zero, firstRank, one, ValueRange{init},
+              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+                Value conj = args[0];
+                Value lhsExtent =
+                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
+                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
+                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lhsExtent, rhsExtent);
+                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+              });
+          b.create<scf::YieldOp>(loc, loop.getResults());
+        },
+        [&](OpBuilder &b, Location loc) {
+          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+          b.create<scf::YieldOp>(loc, result);
+        });
+    result = !result ? same.getResult(0)
+                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
+  }
+  rewriter.replaceOp(op, result);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0a5daabcff48..719f4bddb58d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -629,15 +629,15 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
-  if (lhs() == rhs())
-    return BoolAttr::get(getContext(), true);
-  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (lhs == nullptr)
-    return {};
-  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (rhs == nullptr)
+  bool allSame = true;
+  if (!operands.empty() && !operands[0])
     return {};
-  return BoolAttr::get(getContext(), lhs == rhs);
+  for (Attribute operand : operands.drop_front(1)) {
+    if (!operand)
+      return {};
+    allSame = allSame && operand == operands[0];
+  }
+  return BoolAttr::get(getContext(), allSame);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 385e296177ad..d8aec027a11e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -295,6 +295,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
 
 // -----
 
+// CHECK-LABEL:  @shape_eq
+// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
+  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
 // Don't lower `shape.broadcast` if a `shape.shape` type is involved.
 // CHECK-LABEL: @broadcast
 func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b5828fe53bd8..5ee495d66f18 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -864,7 +864,8 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
   %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -877,7 +878,8 @@ func @shape_eq_fold_0() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -908,19 +910,6 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   return %result : i1
 }
 
-
-// -----
-
-// Fold `shape_eq` for non-constant but same shapes.
-// CHECK-LABEL: @shape_eq_do_fold
-// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
-func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
-  // CHECK: %[[RESULT:.*]] = constant true
-  // CHECK: return %[[RESULT]] : i1
-  %result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
-  return %result : i1
-}
-
 // -----
 
 // Fold `mul` for constant sizes.


        


More information about the Mlir-commits mailing list