[Mlir-commits] [mlir] 3a5a610 - [MLIR][Shape] Expose `getShapeVec` and add support for extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Fri Apr 16 05:00:17 PDT 2021
Author: Frederik Gossen
Date: 2021-04-16T13:59:20+02:00
New Revision: 3a5a610e275d227b0f6dc11ea5f943c7ca851052
URL: https://github.com/llvm/llvm-project/commit/3a5a610e275d227b0f6dc11ea5f943c7ca851052
DIFF: https://github.com/llvm/llvm-project/commit/3a5a610e275d227b0f6dc11ea5f943c7ca851052.diff
LOG: [MLIR][Shape] Expose `getShapeVec` and add support for extent tensors
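The helper previously lived in an anonymous namespace in Shape.cpp, visible only to the CstrBroadcastableOp canonicalization. It is now declared in Shape.h under the `shape` namespace, and it gains a third case: extents can also be read from a ConstantOp holding a DenseIntElementsAttr, i.e. a constant extent tensor.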
Differential Revision: https://reviews.llvm.org/D100636
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index db2862141ea91..47e61a8c47689 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,9 @@ namespace shape {
/// Alias type for extent tensors.
RankedTensorType getExtentTensorType(MLIRContext *ctx);
+/// Given an input shape Value, try to obtain the shape's values.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
+
/// The shape descriptor type represents rank and dimension sizes.
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
public:
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b82b28c88f61b..1455271071b05 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,26 @@ RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
+LogicalResult shape::getShapeVec(Value input,
+                                 SmallVectorImpl<int64_t> &shapeValues) {
+  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+    if (!type || !type.hasRank())
+      return failure();
+    shapeValues = llvm::to_vector<6>(type.getShape());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
+    shapeValues = llvm::to_vector<6>(
+        inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
+    return success();
+  } else {
+    return failure();
+  }
+}
+
static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
@@ -605,24 +625,6 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//
-namespace {
-// Given an input shape Value, try to obtain the shape's values.
-LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
-  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
-    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
-    if (!type.hasRank())
-      return failure();
-    shapeValues = llvm::to_vector<6>(type.getShape());
-    return success();
-  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
-    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
-    return success();
-  } else {
-    return failure();
-  }
-}
-} // namespace
-
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
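Since the helper is now declared in Shape.h, code outside Shape.cpp can call it directly. The following sketch is illustrative only and not part of this commit (the helper name `hasStaticExtents` is hypothetical): it recovers the extents behind a shape value and checks that none of them is dynamic.

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper (not in the commit): returns true if `shapeValue` can
// be traced to a shape.shape_of on a ranked operand, a shape.const_shape, or
// a constant extent tensor, and all recovered extents are static.
static bool hasStaticExtents(Value shapeValue) {
  SmallVector<int64_t, 6> extents;
  // getShapeVec fails for unsupported defining ops and unranked operands.
  if (failed(shape::getShapeVec(shapeValue, extents)))
    return false;
  return llvm::none_of(extents, [](int64_t extent) {
    return extent == ShapedType::kDynamicSize;
  });
}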
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index a2aaf7b84afc3..7a5170ec3504a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1259,3 +1259,14 @@ func @min_same_arg(%a: !shape.shape) -> !shape.shape {
  // CHECK: return %[[SHAPE]]
  return %1 : !shape.shape
}
+
+// -----
+
+// CHECK-LABEL: @cstr_broadcastable_folding
+func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
+  // CHECK: const_witness true
+  %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
+  %1 = constant dense<[4]> : tensor<1xindex>
+  %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
+  "use"(%2) : (!shape.witness) -> ()
+}
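In the new test, %0 is produced by shape.shape_of on a ranked operand and %1 is a constant extent tensor, the case this change adds. getShapeVec recovers the extents [-1, 4] and [4], which are broadcastable regardless of the dynamic dimension (the trailing extents agree and the missing one broadcasts as 1), so canonicalization can replace the constraint with a true witness, as the CHECK line expects.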