[Mlir-commits] [mlir] 24acade - [mlir][Shape] Make shape_eq nary
Benjamin Kramer
llvmlistbot at llvm.org
Wed Mar 3 07:27:54 PST 2021
Author: Benjamin Kramer
Date: 2021-03-03T16:26:40+01:00
New Revision: 24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc
URL: https://github.com/llvm/llvm-project/commit/24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc
DIFF: https://github.com/llvm/llvm-project/commit/24acadef8acb8ed9320b694b6ed4e1dfe2cc58bc.diff
LOG: [mlir][Shape] Make shape_eq nary
shape_eq now takes a variadic list of shape operands instead of exactly two.
This gets rid of a dubious `shape_eq %a, %a` fold that folded shape_eq to
true even when %a was not a constant Attribute.
Differential Revision: https://reviews.llvm.org/D97728
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 0a6122801835..c651b84429b8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -168,20 +168,38 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
let hasFolder = 1;
}
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
+ InferTypeOpInterface]> {
let summary = "Returns whether the input shapes or extent tensors are equal";
let description = [{
- Takes two shape or extent tensor operands and determines whether they are
- equal. When extent tensors are compared to shapes they are regarded as their
- equivalent non-error shapes. Error shapes can be tested for equality like
- any other shape value, meaning that the error value is equal to itself.
+ Takes one or more shape or extent tensor operands and determines whether
+ they are equal. When extent tensors are compared to shapes they are regarded
+ as their equivalent non-error shapes. Error shapes can be tested for
+ equality like any other shape value, meaning that the error value is equal
+ to itself.
}];
- let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
- Shape_ShapeOrExtentTensorType:$rhs);
+ let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs I1:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+ // Two-operand convenience builder: forwards {lhs, rhs} to the variadic builder.
+ let builders = [
+ OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+ [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+ ];
+ let extraClassDeclaration = [{
+ // TODO: This should really be automatic. Figure out how to not need this defined.
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+ inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+ /*width=*/1));
+ return success();
+ };
+ }];
+
+ let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
let hasFolder = 1;
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5f4396d73d88..2b5d619bf58e 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -474,46 +474,63 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
- if (op.lhs().getType().isa<ShapeType>() ||
- op.rhs().getType().isa<ShapeType>()) {
+ if (llvm::any_of(op.shapes(),
+ [](Value v) { return v.getType().isa<ShapeType>(); }))
return failure();
+
+ Type i1Ty = rewriter.getI1Type();
+ // Zero or one operand: the shapes are trivially equal.
+ if (op.shapes().size() <= 1) {
+ rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
+ rewriter.getBoolAttr(true));
+ return success();
}
ShapeEqOp::Adaptor transformed(operands);
auto loc = op.getLoc();
Type indexTy = rewriter.getIndexType();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
- Value eqRank =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
- Type i1Ty = rewriter.getI1Type();
- rewriter.replaceOpWithNewOp<IfOp>(
- op, i1Ty, eqRank,
- [&](OpBuilder &b, Location loc) {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
- auto loop = b.create<scf::ForOp>(
- loc, zero, lhsRank, one, ValueRange{init},
- [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
- Value conj = args[0];
- Value lhsExtent =
- b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
- Value rhsExtent =
- b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
- Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
- lhsExtent, rhsExtent);
- Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
- b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
- });
- b.create<scf::YieldOp>(loc, loop.getResults());
- },
- [&](OpBuilder &b, Location loc) {
- Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
- b.create<scf::YieldOp>(loc, result);
- });
+ Value firstShape = transformed.shapes().front();
+ Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
+ Value result = nullptr;
+ // Compare every other shape against firstShape only (equality is
+ // transitive, so a linear chain suffices) and AND the outcomes together.
+ for (Value shape : transformed.shapes().drop_front(1)) {
+ Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
+ Value eqRank =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
+ auto same = rewriter.create<IfOp>(
+ loc, i1Ty, eqRank,
+ [&](OpBuilder &b, Location loc) {
+ // Equal ranks: AND together the per-dimension extent comparisons.
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+ auto loop = b.create<scf::ForOp>(
+ loc, zero, firstRank, one, ValueRange{init},
+ [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+ Value conj = args[0];
+ Value lhsExtent =
+ b.create<tensor::ExtractOp>(nestedLoc, firstShape, iv);
+ Value rhsExtent = b.create<tensor::ExtractOp>(nestedLoc, shape, iv);
+ Value eqExtent = b.create<CmpIOp>(nestedLoc, CmpIPredicate::eq,
+ lhsExtent, rhsExtent);
+ Value conjNext = b.create<AndOp>(nestedLoc, conj, eqExtent);
+ b.create<scf::YieldOp>(nestedLoc, ValueRange({conjNext}));
+ });
+ b.create<scf::YieldOp>(loc, loop.getResults());
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Different ranks: the two shapes cannot be equal.
+ Value falseVal = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+ b.create<scf::YieldOp>(loc, falseVal);
+ });
+ // Conjoin with the outcome of the previous comparisons, if any.
+ result = !result ? same.getResult(0)
+ : rewriter.create<AndOp>(loc, result, same.getResult(0));
+ }
+ rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0a5daabcff48..719f4bddb58d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -629,15 +629,25 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
- if (lhs() == rhs())
- return BoolAttr::get(getContext(), true);
- auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
- if (lhs == nullptr)
- return {};
- auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
- if (rhs == nullptr)
- return {};
- return BoolAttr::get(getContext(), lhs == rhs);
+  // Constant operands are compared by attribute equality. Any two constant
+  // operands that differ decide the result (false) even if some other operand
+  // is not constant; otherwise fold to true only when every operand is
+  // constant (trivially so for zero operands).
+  Attribute reference;
+  bool allConstant = true;
+  for (Attribute operand : operands) {
+    if (!operand) {
+      allConstant = false;
+      continue;
+    }
+    if (!reference)
+      reference = operand;
+    else if (operand != reference)
+      return BoolAttr::get(getContext(), false);
+  }
+  if (!allConstant)
+    return {};
+  return BoolAttr::get(getContext(), true);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 385e296177ad..d8aec027a11e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -295,6 +295,55 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
// -----
+// Lower a 3-operand shape_eq: %b and %c are each compared against %a and the
+// two results are ANDed together.
+// CHECK-LABEL: @shape_eq_nary
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
+func @shape_eq_nary(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
+ // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[INIT:.*]] = constant true
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+ // CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
+ // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+ // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+ // CHECK: }
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: } else {
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: }
+ // CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
+ // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[INIT:.*]] = constant true
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+ // CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
+ // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+ // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+ // CHECK: }
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: } else {
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: }
+ // CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
+ // CHECK: return %[[RESULT]] : i1
+ %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+ return %result : i1
+}
+
+// -----
+
// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
// CHECK-LABEL: @broadcast
func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b5828fe53bd8..5ee495d66f18 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -864,7 +864,8 @@ func @shape_eq_fold_1() -> i1 {
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3] : !shape.shape
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
- %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
+ %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
+ %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
return %result : i1
}
@@ -877,7 +878,8 @@ func @shape_eq_fold_0() -> i1 {
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3] : tensor<?xindex>
%b = shape.const_shape [4, 5, 6] : tensor<?xindex>
- %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+ %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
+ %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
return %result : i1
}
@@ -908,19 +910,6 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
return %result : i1
}
-
-// -----
-
-// Fold `shape_eq` for non-constant but same shapes.
-// CHECK-LABEL: @shape_eq_do_fold
-// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
-func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
- // CHECK: %[[RESULT:.*]] = constant true
- // CHECK: return %[[RESULT]] : i1
- %result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
- return %result : i1
-}
-
// -----
// Fold `mul` for constant sizes.
More information about the Mlir-commits
mailing list