[Mlir-commits] [mlir] af29db6 - [mlir][shape] refine shape.func and shape.with_shape
Jacques Pienaar
llvmlistbot at llvm.org
Mon Aug 22 14:56:42 PDT 2022
Author: Jacques Pienaar
Date: 2022-08-22T14:52:18-07:00
New Revision: af29db64b2c7091070dd623c81872559657e7b3d
URL: https://github.com/llvm/llvm-project/commit/af29db64b2c7091070dd623c81872559657e7b3d
DIFF: https://github.com/llvm/llvm-project/commit/af29db64b2c7091070dd623c81872559657e7b3d.diff
LOG: [mlir][shape] refine shape.func and shape.with_shape
- shape.with_shape accepts an extent tensor (ExtentTensorType) as its shape operand
- add a builder and static `create` helpers for constructing shape.func
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D131977
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 4570772c7d393..8503b9d5633d6 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -736,7 +736,7 @@ def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
}];
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
- Shape_ShapeType:$shape);
+ Shape_ShapeOrExtentTensorType:$shape);
let results = (outs Shape_ValueShapeType:$result);
let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
@@ -1110,7 +1110,20 @@ def Shape_FuncOp : Shape_Op<"func",
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
+ let builders = [OpBuilder<(ins
+ "StringRef":$name, "FunctionType":$type,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+ CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
+ >];
+
let extraClassDeclaration = [{
+ static FuncOp create(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs = {});
+ static FuncOp create(Location location, StringRef name, FunctionType type,
+ Operation::dialect_attr_range attrs);
+ static FuncOp create(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<DictionaryAttr> argAttrs);
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 133e7577335f4..2b8cae83fd414 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1267,6 +1267,43 @@ void FunctionLibraryOp::print(OpAsmPrinter &p) {
// FuncOp
//===----------------------------------------------------------------------===//
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, // Creates a shape.func named `name` with signature `type` plus extra `attrs`; the op is not inserted into any block.
+ ArrayRef<NamedAttribute> attrs) {
+ OpBuilder builder(location->getContext()); // Only passed to FuncOp::build(); no insertion point is set here.
+ OperationState state(location, getOperationName());
+ FuncOp::build(builder, state, name, type, attrs);
+ return cast<FuncOp>(Operation::create(state)); // Detached op; presumably the caller must insert or erase it (confirm at call sites).
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, // Overload taking an Operation::dialect_attr_range; materializes it and forwards to the ArrayRef overload.
+ Operation::dialect_attr_range attrs) {
+ SmallVector<NamedAttribute, 8> attrRef(attrs); // Copy into contiguous storage because the forwarded-to overload takes an ArrayRef.
+ return create(location, name, type, llvm::makeArrayRef(attrRef));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, // Overload that additionally sets per-argument attribute dictionaries.
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<DictionaryAttr> argAttrs) {
+ FuncOp func = create(location, name, type, attrs);
+ func.setAllArgAttrs(argAttrs); // Argument attributes are attached after the op has been created.
+ return func;
+}
+
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, // Populates `state` for a shape.func: symbol-name and function-type attributes, extra `attrs`, one body region, optional arg attributes.
+ FunctionType type, ArrayRef<NamedAttribute> attrs,
+ ArrayRef<DictionaryAttr> argAttrs) {
+ state.addAttribute(FuncOp::getSymNameAttrName(state.name),
+ builder.getStringAttr(name));
+ state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
+ TypeAttr::get(type));
+ state.attributes.append(attrs.begin(), attrs.end());
+ state.addRegion(); // The body region is added empty; no entry block is created here.
+
+ if (argAttrs.empty()) // Argument attributes are optional.
+ return;
+ assert(type.getNumInputs() == argAttrs.size()); // When provided, there must be exactly one dictionary per function input.
+ function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+ /*resultAttrs=*/llvm::None); // Only argument attributes are recorded; no result attributes.
+}
+
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 33da6b2f15022..8a90ed88ffd68 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -268,6 +268,12 @@ func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) ->
return %2 : !shape.shape
}
+func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape { // Round-trip test: shape.with_shape with an extent-tensor shape operand.
+ %0 = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
+ %1 = shape.with_shape %b, %0 : !shape.value_shape, tensor<3xindex> // tensor<3xindex> shape operand; previously only !shape.shape was accepted.
+ return %1 : !shape.value_shape
+}
+
func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-> !shape.shape {
%result = shape.any %a, %b, %c
More information about the Mlir-commits mailing list