[Mlir-commits] [mlir] 2f025e0 - [mlir][shape] Add dim op

Jacques Pienaar llvmlistbot at llvm.org
Fri Aug 12 11:02:13 PDT 2022


Author: Jacques Pienaar
Date: 2022-08-12T11:02:08-07:00
New Revision: 2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf

URL: https://github.com/llvm/llvm-project/commit/2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf
DIFF: https://github.com/llvm/llvm-project/commit/2f025e0e78fd57923aceb49c7d4aeb3e5e2d34bf.diff

LOG: [mlir][shape] Add dim op

Convenience op that allows for simple expression of the common crossing of
the value/shape divide.

Differential Revision: https://reviews.llvm.org/D131497

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6d0d84dbbd0fd..eeace76b632c0 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -328,6 +328,41 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
   let hasFolder = 1;
 }
 
+def Shape_DimOp : Shape_Op<"dim",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Gets the specified extent from the shape of a shaped input";
+  let description = [{
+    Gets the extent indexed by `dim` from the shape of the `value` operand. If
+    `dim` is an error or out of bounds, then an invalid size is returned if the
+    return type carries error information; otherwise the behavior is undefined.
+
+    This is a convenience op that performs the equivalent of getting the extent
+    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
+  }];
+  let arguments = (ins AnyShaped:$value,
+                       Shape_SizeOrIndexType:$dim);
+  let results = (outs Shape_SizeOrIndexType:$extent);
+  let assemblyFormat = "$value `,` $dim attr-dict `:` type($value) `,` type($dim) `->` "
+                       "type($extent)";
+
+  let builders = [
+    // Builder that allows passing a constant dimension as a simple integer.
+    OpBuilder<(ins "Value":$value, "int64_t":$dim)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Get the `dim` value as integer if it is constant.
+    Optional<int64_t> getConstantDim();
+
+    /// Returns whether two result types are compatible for this op; method
+    /// used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def Shape_GetExtentOp : Shape_Op<"get_extent",
     [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the specified extent from a shape or extent tensor";

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 4d41c51125f03..972484ba435a8 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -322,6 +322,28 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   return success();
 }
 
+namespace {
+class DimOpConverter : public OpConversionPattern<DimOp> {
+  using OpConversionPattern<DimOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
+  // lowerings. This can be further optimized if needed to avoid intermediate
+  // steps.
+  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
+  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+                                                  op.getDim());
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -693,6 +715,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BroadcastOpConverter,
       ConstShapeOpConverter,
       ConstSizeOpConversion,
+      DimOpConverter,
       IsBroadcastableOpConverter,
       GetExtentOpConverter,
       RankOpConverter,

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ff7065474bdb4..9c6ab1fbc2212 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1064,6 +1064,58 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
   return operands[0];
 }
 
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+Optional<int64_t> DimOp::getConstantDim() {
+  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
+    return constSizeOp.getValue().getLimitedValue();
+  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return llvm::None;
+}
+
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+  Type valType = getValue().getType();
+  auto valShapedType = valType.dyn_cast<ShapedType>();
+  if (!valShapedType || !valShapedType.hasRank())
+    return nullptr;
+  Optional<int64_t> dim = getConstantDim();
+  if (!dim.has_value())
+    return nullptr;
+  if (dim.value() >= valShapedType.getRank())
+    return nullptr;
+  auto extent = valShapedType.getDimSize(*dim);
+  if (ShapedType::isDynamic(extent))
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(getContext()), extent);
+}
+
+LogicalResult mlir::shape::DimOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  DimOpAdaptor dimOp(operands);
+  inferredReturnTypes.assign({dimOp.getDim().getType()});
+  return success();
+}
+
+bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
+LogicalResult mlir::shape::DimOp::verify() {
+  auto st = getValue().getType().cast<ShapedType>();
+  if (!st.hasRank())
+    return success();
+  if (auto dim = getConstantDim()) {
+    if (*dim < 0 || *dim >= st.getRank())
+      return emitOpError("index is out of range");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DivOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 5c77a68445d27..cb3af973daee2 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -60,6 +60,18 @@ func.func @rank(%shape : !shape.shape) {
 
 // -----
 
+// Express `shape.dim` as `tensor.dim` when valid.
+// CHECK-LABEL: @dim
+// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
+func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
+  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
+  // CHECK: return %[[RESULT]] : index
+  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
+  return %result : index
+}
+
+// -----
+
 // Express `get_extent` as `tensor.dim` when it relies directly on the outcome of a
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 0f442308b3f6a..ab35a69cb0602 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -216,6 +216,12 @@ func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
   return %result : index
 }
 
+func.func @get_dim(%arg : memref<?x?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %result = shape.dim %arg, %c0 : memref<?x?xindex>, index -> index
+  return %result : index
+}
+
 func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
   %c0 = shape.const_size 0
   %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size


        


More information about the Mlir-commits mailing list