[Mlir-commits] [mlir] 3a5a610 - [MLIR][Shape] Expose `getShapeVec` and add support for extent tensors

Frederik Gossen llvmlistbot at llvm.org
Fri Apr 16 05:00:17 PDT 2021


Author: Frederik Gossen
Date: 2021-04-16T13:59:20+02:00
New Revision: 3a5a610e275d227b0f6dc11ea5f943c7ca851052

URL: https://github.com/llvm/llvm-project/commit/3a5a610e275d227b0f6dc11ea5f943c7ca851052
DIFF: https://github.com/llvm/llvm-project/commit/3a5a610e275d227b0f6dc11ea5f943c7ca851052.diff

LOG: [MLIR][Shape] Expose `getShapeVec` and add support for extent tensors

Differential Revision: https://reviews.llvm.org/D100636

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index db2862141ea91..47e61a8c47689 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,14 @@ namespace shape {
 /// Alias type for extent tensors.
 RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
+/// Given an input shape value, try to obtain the shape's static extents.
+/// Recognized producers are `shape.shape_of` (ranked operand only),
+/// `shape.const_shape`, and `constant` ops holding dense integer elements.
+/// On success, `shapeValues` is overwritten with the extents (unknown
+/// dimensions appear as `ShapedType::kDynamicSize`); on failure it is left
+/// unmodified.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
+
 /// The shape descriptor type represents rank and dimension sizes.
 class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
 public:

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b82b28c88f61b..1455271071b05 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,42 @@ RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
+/// Try to obtain the static extents described by the shape value `input`.
+/// Recognized producers:
+///   - `shape.shape_of` of a ranked shaped value (dynamic dimensions are
+///     reported as `ShapedType::kDynamicSize`),
+///   - `shape.const_shape`,
+///   - `constant` ops whose value is a `DenseIntElementsAttr`.
+/// On success `shapeValues` is overwritten with the extents. Returns failure,
+/// leaving `shapeValues` untouched, for any other producer, for `shape_of` of
+/// an unranked or non-shaped operand, and for non-dense-integer constants.
+LogicalResult shape::getShapeVec(Value input,
+                                 SmallVectorImpl<int64_t> &shapeValues) {
+  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+    // `shape_of` also accepts non-shaped operands (e.g. `!shape.value_shape`),
+    // for which `dyn_cast` yields a null type: check before calling hasRank.
+    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+    if (!type || !type.hasRank())
+      return failure();
+    shapeValues = llvm::to_vector<6>(type.getShape());
+    return success();
+  }
+  if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+    return success();
+  }
+  if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
+    // A generic constant may hold any attribute kind; only dense integer
+    // elements describe extents, so bail out instead of asserting in `cast`.
+    auto attr = inputOp.value().dyn_cast<DenseIntElementsAttr>();
+    if (!attr)
+      return failure();
+    shapeValues = llvm::to_vector<6>(attr.getValues<int64_t>());
+    return success();
+  }
+  return failure();
+}
+
 static bool isErrorPropagationPossible(TypeRange operandTypes) {
   return llvm::any_of(operandTypes, [](Type ty) {
     return ty.isa<SizeType, ShapeType, ValueShapeType>();
@@ -605,24 +625,6 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-// Given an input shape Value, try to obtain the shape's values.
-LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
-  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
-    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
-    if (!type.hasRank())
-      return failure();
-    shapeValues = llvm::to_vector<6>(type.getShape());
-    return success();
-  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
-    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
-    return success();
-  } else {
-    return failure();
-  }
-}
-} // namespace
-
 void CstrBroadcastableOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   // Canonicalization patterns have overlap with the considerations during

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index a2aaf7b84afc3..7a5170ec3504a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1259,3 +1259,18 @@ func @min_same_arg(%a: !shape.shape) -> !shape.shape {
   // CHECK: return %[[SHAPE]]
   return %1 : !shape.shape
 }
+
+// -----
+
+// Fold a broadcastability constraint whose operands are the `shape_of` of a
+// partially static tensor and a constant extent tensor.
+// CHECK-LABEL: @cstr_broadcastable_folding
+func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
+  // CHECK: %[[WITNESS:.*]] = shape.const_witness true
+  // CHECK: "use"(%[[WITNESS]])
+  %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
+  %1 = constant dense<[4]> : tensor<1xindex>
+  %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
+  "use"(%2) : (!shape.witness) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list