[Mlir-commits] [mlir] 2f025e0 - [mlir][shape] Add dim op
Jacques Pienaar
llvmlistbot at llvm.org
Fri Aug 12 11:02:13 PDT 2022
Author: Jacques Pienaar
Date: 2022-08-12T11:02:08-07:00
New Revision: 2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf
URL: https://github.com/llvm/llvm-project/commit/2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf
DIFF: https://github.com/llvm/llvm-project/commit/2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf.diff
LOG: [mlir][shape] Add dim op
Convenience op that allows simple expression of the common crossing of the
value/shape divide (i.e., `dim(x, i) == get_extent(shape_of(x), i)`).
Differential Revision: https://reviews.llvm.org/D131497
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6d0d84dbbd0fd..eeace76b632c0 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -328,6 +328,41 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
let hasFolder = 1;
}
+def Shape_DimOp : Shape_Op<"dim",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Gets the specified extent from the shape of a shaped input";
+  let description = [{
+    Gets the extent indexed by `dim` from the shape of the `value` operand. If
+    `dim` is an error value or is out of bounds, the op returns an invalid size
+    when the return type carries error information; otherwise the behavior is
+    undefined. A constant `dim` that is out of range for a ranked `value` is
+    rejected by the verifier.
+
+    This is a convenience op that performs the equivalent of getting the extent
+    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
+  }];
+  let arguments = (ins AnyShaped:$value,
+                       Shape_SizeOrIndexType:$dim);
+  let results = (outs Shape_SizeOrIndexType:$extent);
+  let assemblyFormat = "$value `,` $dim attr-dict `:` type($value) `,` type($dim) `->` "
+                       "type($extent)";
+
+  let builders = [
+    // Builder that allows passing a constant dimension as a simple integer.
+    // The integer is materialized as an `index`-typed `arith.constant`, so the
+    // built op has an `index` result.
+    OpBuilder<(ins "Value":$value, "int64_t":$dim), [{
+      Value dimValue =
+          $_builder.create<arith::ConstantIndexOp>($_state.location, dim);
+      build($_builder, $_state, $_builder.getIndexType(), value, dimValue);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns the `dim` operand as an integer if it is produced by a
+    /// `shape.const_size` or an `arith.constant`; otherwise returns None.
+    Optional<int64_t> getConstantDim();
+
+    /// Returns whether two result types are compatible for this op; method
+    /// used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
def Shape_GetExtentOp : Shape_Op<"get_extent",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from a shape or extent tensor";
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 4d41c51125f03..972484ba435a8 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -322,6 +322,28 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
return success();
}
+namespace {
+/// Lowers `shape.dim` by rewriting `dim(X, i)` into
+/// `get_extent(shape_of(X), i)`. The new ops are left to the other patterns
+/// registered by this conversion (e.g. the `get_extent` lowering).
+class DimOpConverter : public OpConversionPattern<DimOp> {
+  using OpConversionPattern<DimOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
+  // lowerings. This can be further optimized if needed to avoid intermediate
+  // steps.
+  // Read operands through the adaptor: within a dialect conversion these are
+  // the remapped values, whereas `op`'s operands may still refer to values
+  // that are scheduled for replacement.
+  auto shapeOf =
+      rewriter.create<shape::ShapeOfOp>(op.getLoc(), adaptor.getValue());
+  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+                                                  adaptor.getDim());
+  return success();
+}
+
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -693,6 +715,7 @@ void mlir::populateShapeToStandardConversionPatterns(
BroadcastOpConverter,
ConstShapeOpConverter,
ConstSizeOpConversion,
+ DimOpConverter,
IsBroadcastableOpConverter,
GetExtentOpConverter,
RankOpConverter,
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ff7065474bdb4..9c6ab1fbc2212 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1064,6 +1064,58 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
return operands[0];
}
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the `dim` operand as an integer when it is defined by a
+/// `shape.const_size` or an `arith.constant`; returns None for any other
+/// producer (including block arguments).
+Optional<int64_t> DimOp::getConstantDim() {
+  Value dimOperand = getDim();
+  if (auto sizeCst = dimOperand.getDefiningOp<ConstSizeOp>())
+    return sizeCst.getValue().getLimitedValue();
+  auto arithCst = dimOperand.getDefiningOp<arith::ConstantOp>();
+  if (!arithCst)
+    return llvm::None;
+  return arithCst.getValue().cast<IntegerAttr>().getInt();
+}
+
+/// Folds to a constant `index` attribute holding the extent when `value` is
+/// ranked, `dim` is a constant in [0, rank), and that dimension is static.
+/// Returns null (no fold) in every other case.
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+  auto valShapedType = getValue().getType().dyn_cast<ShapedType>();
+  if (!valShapedType || !valShapedType.hasRank())
+    return nullptr;
+  Optional<int64_t> dim = getConstantDim();
+  if (!dim.has_value())
+    return nullptr;
+  // Check both bounds before querying getDimSize(). The folder should not
+  // rely on the verifier having run. A `const_size` read back through
+  // getLimitedValue() can also wrap to a negative int64_t.
+  int64_t rank = valShapedType.getRank();
+  if (*dim < 0 || *dim >= rank)
+    return nullptr;
+  int64_t extent = valShapedType.getDimSize(*dim);
+  if (ShapedType::isDynamic(extent))
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(getContext()), extent);
+}
+
+/// The result (`extent`) always takes the type of the `dim` operand, so an
+/// `index` dim yields an `index` extent and a `!shape.size` dim a
+/// `!shape.size` extent.
+LogicalResult mlir::shape::DimOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type dimType = DimOpAdaptor(operands).getDim().getType();
+  inferredReturnTypes.assign({dimType});
+  return success();
+}
+
+// Compatibility is delegated to `eachHasOnlyOneOfTypes`, checked against the
+// {SizeType, IndexType} pair, which is the set of types `extent` may take.
+bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange lhs,
+                                                 TypeRange rhs) {
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(lhs, rhs);
+}
+
+/// Rejects a constant `dim` outside [0, rank) when `value` is ranked. Unranked
+/// inputs and non-constant `dim` operands are accepted unchecked.
+LogicalResult mlir::shape::DimOp::verify() {
+  auto shapedTy = getValue().getType().cast<ShapedType>();
+  // Nothing can be checked against an unranked input.
+  if (!shapedTy.hasRank())
+    return success();
+  Optional<int64_t> constDim = getConstantDim();
+  // Nothing to check when `dim` is not a recognized constant.
+  if (!constDim)
+    return success();
+  int64_t rank = shapedTy.getRank();
+  bool inBounds = *constDim >= 0 && *constDim < rank;
+  if (!inBounds)
+    return emitOpError("index is out of range");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 5c77a68445d27..cb3af973daee2 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -60,6 +60,18 @@ func.func @rank(%shape : !shape.shape) {
// -----
+// Express `shape.dim` as `tensor.dim` when valid. The conversion first rewrites
+// `shape.dim` into `shape.get_extent` of `shape.shape_of`, and that pair is in
+// turn lowered to a single `tensor.dim` on the original operand.
+// CHECK-LABEL: @dim
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
+func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
+  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
+  // CHECK: return %[[RESULT]] : index
+  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
+  return %result : index
+}
+
+// -----
+
// Express `get_extent` as `tensor.dim` when it relies directly on the outcome of a
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 0f442308b3f6a..ab35a69cb0602 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -216,6 +216,12 @@ func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
return %result : index
}
+// `shape.dim` on a memref operand with an `index`-typed constant dim and an
+// `index` result.
+func.func @get_dim(%arg : memref<?x?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %result = shape.dim %arg, %c0 : memref<?x?xindex>, index -> index
+  return %result : index
+}
+
func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
More information about the Mlir-commits
mailing list