[Mlir-commits] [mlir] 25bb616 - [mlir][linalg][python] Add attribute support to the YAML codegen.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jun 24 05:45:26 PDT 2021
Author: Tobias Gysi
Date: 2021-06-24T12:33:48Z
New Revision: 25bb61649085c0a6e66630bbffe7faa54cd67829
URL: https://github.com/llvm/llvm-project/commit/25bb61649085c0a6e66630bbffe7faa54cd67829
DIFF: https://github.com/llvm/llvm-project/commit/25bb61649085c0a6e66630bbffe7faa54cd67829.diff
LOG: [mlir][linalg][python] Add attribute support to the YAML codegen.
Extend the YAML code generation to support the index attributes added to the OpDSL by https://reviews.llvm.org/D104711.
Differential Revision: https://reviews.llvm.org/D104712
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9f5bf5d62755d..58872da9b1dab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -13,19 +13,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: input
- shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ usage: InputOperand
type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- !LinalgOperandDefConfig
name: B
- usage: input
- shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+ usage: InputOperand
type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- !LinalgOperandDefConfig
name: C
- usage: output
- shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ usage: OutputOperand
type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -75,19 +75,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: input
- shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ usage: InputOperand
type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- !LinalgOperandDefConfig
name: B
- usage: input
- shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+ usage: InputOperand
type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
- !LinalgOperandDefConfig
name: C
- usage: output
- shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ usage: OutputOperand
type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -138,19 +138,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: input
- shape: affine_map<()[s0, s1] -> (s0, s1)>
+ usage: InputOperand
type_var: T1
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: y
- usage: input
- shape: affine_map<()[s0, s1] -> (s1)>
+ usage: InputOperand
type_var: T2
+ shape_map: affine_map<()[s0, s1] -> (s1)>
- !LinalgOperandDefConfig
name: x
- usage: output
- shape: affine_map<()[s0, s1] -> (s0)>
+ usage: OutputOperand
type_var: U
+ shape_map: affine_map<()[s0, s1] -> (s0)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -199,19 +199,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: y
- usage: input
- shape: affine_map<()[s0, s1] -> (s1)>
+ usage: InputOperand
type_var: T1
+ shape_map: affine_map<()[s0, s1] -> (s1)>
- !LinalgOperandDefConfig
name: A
- usage: input
- shape: affine_map<()[s0, s1] -> (s1, s0)>
+ usage: InputOperand
type_var: T2
+ shape_map: affine_map<()[s0, s1] -> (s1, s0)>
- !LinalgOperandDefConfig
name: x
- usage: output
- shape: affine_map<()[s0, s1] -> (s0)>
+ usage: OutputOperand
type_var: U
+ shape_map: affine_map<()[s0, s1] -> (s0)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -260,19 +260,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: input
- shape: affine_map<()[s0] -> (s0)>
+ usage: InputOperand
type_var: T1
+ shape_map: affine_map<()[s0] -> (s0)>
- !LinalgOperandDefConfig
name: B
- usage: input
- shape: affine_map<()[s0] -> (s0)>
+ usage: InputOperand
type_var: T2
+ shape_map: affine_map<()[s0] -> (s0)>
- !LinalgOperandDefConfig
name: C
- usage: output
- shape: affine_map<()[s0] -> ()>
+ usage: OutputOperand
type_var: U
+ shape_map: affine_map<()[s0] -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0)[s0] -> (d0)>
@@ -306,6 +306,83 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
+ cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
+ doc: A depth-wise 2-D convolution operation.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ usage: InputOperand
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s6, s7, s3)>
+ - !LinalgOperandDefConfig
+ name: K
+ usage: InputOperand
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s4, s5, s3)>
+ - !LinalgOperandDefConfig
+ name: O
+ usage: OutputOperand
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s2, s3)>
+ - !LinalgOperandDefConfig
+ name: strides
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s8, s9)>
+ - !LinalgOperandDefConfig
+ name: dilations
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s10, s11)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d4, d5, d3)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d0, d1, d2, d3)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
cpp_class_name: FillRng2DOp
@@ -323,21 +400,21 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: min
- usage: input
+ usage: InputOperand
type_var: F64
- !LinalgOperandDefConfig
name: max
- usage: input
+ usage: InputOperand
type_var: F64
- !LinalgOperandDefConfig
name: seed
- usage: input
+ usage: InputOperand
type_var: I32
- !LinalgOperandDefConfig
name: O
- usage: output
- shape: affine_map<()[s0, s1] -> (s0, s1)>
+ usage: OutputOperand
type_var: T
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> ()>
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 35fe9fe69a5a7..b40ab139c3e73 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -30,6 +30,36 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
// -----
+func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+ %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+ return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+ %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+ return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 00a6528ddcd17..471890e5f4a45 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -21,9 +21,9 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: O
- usage: output
- shape: affine_map<()[s0, s1] -> (s0, s1)>
+ usage: OutputOperand
type_var: T
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -86,12 +86,13 @@ structured_op: !LinalgStructuredOpConfig
# @linalg_structured_op
# def test2(I=TensorDef(T, S.M, S.N),
-# O=TensorDef(T, S.M, S.N, output=True)):
+# O=TensorDef(T, S.M, S.N, output=True),
+# strides=AttributeDef(S.SM, S.SN)):
# """Title.
# Detailed description.
# """
-# O[D.m, D.n] = I[D.n, D.m]
+# O[D.m, D.n] = I[D.n * S.SM, D.m * S.SN]
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
@@ -103,25 +104,25 @@ metadata: !LinalgOpMetadata
Detailed description.
structured_op: !LinalgStructuredOpConfig
args:
- - !LinalgOperandDefConfig
- name: value
- usage: input
- type_var: T
- !LinalgOperandDefConfig
name: I
- usage: input
- shape: affine_map<()[s0, s1] -> (s1, s0)>
+ usage: InputOperand
type_var: T
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: O
- usage: output
- shape: affine_map<()[s0, s1] -> (s0, s1)>
+ usage: OutputOperand
type_var: T
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: strides
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- - affine_map<(d0, d1)[s0, s1] -> ()>
- - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
- - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
+ - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>
iterator_types:
- parallel
- parallel
@@ -129,23 +130,41 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: value
- - !ScalarExpression
- scalar_arg: I
+ scalar_arg: I
-# IMPL-LABEL: Test2Op::iterator_types()
-# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
+# ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2"
+
+# ODS: let arguments =
+# ODS-NEXT: Variadic<AnyType>:$inputs,
+# ODS-NEXT: Variadic<AnyShaped>:$outputs,
+# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides
+
+# ODS: "Attribute":$strides
+# ODS: $_state.addAttribute("strides", strides);
+
+# ODS: bool hasDynamicIndexingMaps();
+# ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes();
+
+# IMPL: getSymbolBindings(Test2Op self)
+# IMPL: cst2 = self.strides().getValue<int64_t>({ 0 });
+# IMPL-NEXT: getAffineConstantExpr(cst2, context)
+# IMPL: cst3 = self.strides().getValue<int64_t>({ 1 });
+# IMPL-NEXT: getAffineConstantExpr(cst3, context)
# IMPL: Test2Op::indexing_maps()
-# IMPL: "affine_map<(d0, d1)[s0, s1] -> ()>"
-# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
-# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
+# IMPL: = getSymbolBindings(*this);
+# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>"
+# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>"
+
+# IMPL: Test2Op::getNumRegionArgs() { return 2; }
+
+# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; }
+# IMPL: Test2Op::verifyIndexingMapRequiredAttributes()
+# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
+# IMPL: "missing indexing map required attribute 'strides'"
# IMPL: void Test2Op::regionBuilder(
-# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+# IMPL-NEXT: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
-# IMPL: = helper.applyfn__add(block.getArgument(0), block.getArgument(1));
+# IMPL: yields.push_back(block.getArgument(0));
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index 6c94bec316293..a70e3cdeca99b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -55,7 +55,7 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
@linalg_structured_op
def strided_copy(
- I=TensorDef(T, S.W, S.H),
+ I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
- strides=AttributeDef(S.S0, S.S1)):
- O[D.oh, D.ow] = I[D.h * S.S0, D.w * S.S1]
+ strides=AttributeDef(S.SH, S.SW)):
+ O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW]
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index 14217014fcd98..e315a5fe9889e 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -210,6 +210,36 @@ def fill_on_buffers(min, max, seed, out):
test_fill_generic()
+def test_conv_builtin():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64),
+ MemRefType.get((1, 2, 4, 1), i32))
+ def conv_on_buffers(input, filter, output):
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+ input, filter, outs=[output], strides=[2, 4], dilations=[1, 2])
+
+ execution_engine = ExecutionEngine(transform(module, conv_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 8
+
+
+test_conv_builtin()
+
+
def test_conv_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index e7660cbd6286e..00c4096d095cf 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -62,13 +62,14 @@ struct SerializedAffineMap {
AffineMap affineMap() { return affineMapAttr.getValue(); }
};
-enum class LinalgOperandDefUsage { input, output };
+enum class LinalgOperandDefUsage { input, output, attribute };
struct LinalgOperandDef {
std::string name;
LinalgOperandDefUsage usage;
- Optional<SerializedAffineMap> shape;
std::string typeVar;
+ Optional<SerializedAffineMap> shapeMap;
+ Optional<SerializedAffineMap> attributeMap;
};
enum class LinalgIteratorTypeDef {
@@ -149,8 +150,8 @@ struct MappingTraits<LinalgOpConfig> {
};
/// A structured op models (at most) a single contraction by modeling
-/// - A list of named arguments (`LinalgOperandDef`), which can be inputs or
-/// outputs.
+/// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
+/// outputs, or index attributes.
/// - List of indexing maps (see `LinalgIndexingMaps`).
/// - Iterator types (see `LinalgIteratorTypeDef`).
/// - List of scalar level assignment (see `ScalarAssign`).
@@ -164,21 +165,28 @@ struct MappingTraits<LinalgStructuredOpConfig> {
}
};
-/// Maps a named tensor- or scalar-argument to an operation, consisting of:
+/// Maps a named tensor, scalar or attribute argument to an operation,
+/// consisting of:
/// - `name`: Must be unique within the operation.
-/// - `usage`: How the argument is used (input, output, etc).
-/// - `shape`: An optional AffineMap from all op symbols to the shape of the
-/// argument. Only tensor-arguments have a shape. Each shape must be
-/// normalized over the same list of symbols and have no dimension inputs.
+/// - `usage`: How the argument is used (input, output, attribute, etc).
/// - `type_var`: The symbolic type variable that binds to the element or self
-/// type of the tensor- or scalar-argument, respectively.
+/// type of the tensor or scalar argument, respectively.
+/// - `shape_map`: An optional AffineMap from all op symbols to the shape of
+/// the argument. Only tensor arguments have a `shape_map`. Each shape must
+/// be normalized over the same list of symbols and have no dimension
+/// inputs.
+/// - `attribute_map`: An optional AffineMap from all op symbols to the
+/// attribute symbols. During op creation these symbols are replaced by the
+/// corresponding `name` attribute values. Only attribute arguments have
+/// an `attribute_map`.
template <>
struct MappingTraits<LinalgOperandDef> {
static void mapping(IO &io, LinalgOperandDef &info) {
io.mapRequired("name", info.name);
io.mapRequired("usage", info.usage);
- io.mapOptional("shape", info.shape);
io.mapRequired("type_var", info.typeVar);
+ io.mapOptional("shape_map", info.shapeMap);
+ io.mapOptional("attribute_map", info.attributeMap);
}
};
@@ -186,8 +194,9 @@ struct MappingTraits<LinalgOperandDef> {
template <>
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
- io.enumCase(value, "input", LinalgOperandDefUsage::input);
- io.enumCase(value, "output", LinalgOperandDefUsage::output);
+ io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
+ io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
+ io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute);
}
};
@@ -425,9 +434,8 @@ static const char bannerFormat[] = R"FMT(
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
-// {5}: the number of arguments for the op region
-// {6}: builder methods taking standalone attribute parameters
-// {7}: additional methods for attributes used by indexing maps
+// {5}: builder methods taking standalone attribute parameters
+// {6}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
@@ -491,7 +499,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion();
}]>
- {6}
+ {5}
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{
@@ -514,11 +522,37 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
- {7}
+ {6}
}];
}
)FMT";
+// Builder method taking attribute parameters. Parameters:
+// {0}: Class name
+// {1}: Comma interleaved attribute parameters
+// {2}: Attribute initialization
+static const char structuredOpBuilderFormat[] = R"FMT(
+ , OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, {1}),
+ [{{
+ $_state.addOperands(inputs);
+ $_state.addOperands(outputs);
+ $_state.addTypes(resultTensorTypes);
+ $_state.addAttribute(
+ "operand_segment_sizes",
+ $_builder.getI32VectorAttr({{
+ static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputs.size())}));
+ createAndFillStructuredOpRegion<{0}>(
+ $_builder,
+ $_state,
+ TypeRange(inputs),
+ TypeRange(outputs));
+ {2}
+ }]>
+)FMT";
+
// The iterator_types() method implementation. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
@@ -560,24 +594,53 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
std::string doc;
if (opConfig.metadata->doc) {
- const char *docFmt = R"FMT(
- let summary = [{ {0} }];
- let description = [{
- {1}
- }];
- )FMT";
+ static const char structuredOpDocFmt[] = R"FMT(
+ let summary = [{ {0} }];
+ let description = [{
+ {1}
+ }];
+)FMT";
StringRef summary, description;
std::tie(summary, description) =
StringRef(*opConfig.metadata->doc).trim().split('\n');
- doc = llvm::formatv(docFmt, summary.trim(), description.trim());
+ doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
}
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
- os << llvm::formatv(
- structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName,
- opConfig.metadata->name, interfaceNameList, doc, attrList,
- opConfig.structuredOp->args.size(), attrBuilder, attrMethods);
+ // Assemble the attribute specific logic required for the op definition.
+ if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+ return arg.usage == LinalgOperandDefUsage::attribute;
+ })) {
+ SmallVector<std::string> attrDefs;
+ SmallVector<std::string> attrParams;
+ SmallVector<std::string> attrStmts;
+ for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+ if (arg.usage != LinalgOperandDefUsage::attribute)
+ continue;
+ assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
+ static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}";
+ static const char paramFmt[] = "\"Attribute\":${0}";
+ static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
+ attrDefs.push_back(llvm::formatv(
+ defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name));
+ attrParams.push_back(llvm::formatv(paramFmt, arg.name));
+ attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
+ }
+ attrList = ",\n" + llvm::join(attrDefs, ",\n");
+ attrMethods = R"(
+ bool hasDynamicIndexingMaps();
+ LogicalResult verifyIndexingMapRequiredAttributes();
+ )";
+ attrBuilder = llvm::formatv(
+ structuredOpBuilderFormat, opConfig.metadata->cppClassName,
+ llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
+ }
+
+ os << llvm::formatv(structuredOpOdsHeaderFormat,
+ opConfig.metadata->cppClassName, opConfig.metadata->name,
+ interfaceNameList, doc, attrList, attrBuilder,
+ attrMethods);
return success();
}
@@ -595,6 +658,12 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
std::string bannerComment = llvm::formatv("Implementation of {0}", className);
os << llvm::formatv(bannerFormat, bannerComment);
+ // Compute the number of scalar and tensor arguments.
+ int64_t numOfArgs =
+ llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+ return arg.usage != LinalgOperandDefUsage::attribute;
+ });
+
// Reference iterators.
{
std::string iteratorsStr;
@@ -627,7 +696,6 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
// For each symbol, generate a declaration for it, either with an
// AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
// an attribute).
- // TODO: Implement attribute constants.
// TODO: Possibly lift into a top-level method.
static const char structuredOpSymbolBindingsFormat[] = R"FMT(
static SmallVector<AffineExpr> getSymbolBindings({0} self) {
@@ -641,10 +709,33 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
unsigned symbolCount = firstMap.getNumSymbols();
SmallVector<std::string> symbolBindings;
for (unsigned i = 0; i < symbolCount; ++i) {
- // TODO: Switch and emit constants for attribute bound symbols.
symbolBindings.push_back(llvm::formatv(
" exprs.push_back(getAffineSymbolExpr({0}, context));", i));
}
+
+ // Access an index attribute. Parameters:
+ // {0}: Attribute name
+ // {1}: Symbol position
+ // {2}: Attribute index
+ static const char structuredOpAccessAttrFormat[] = R"FMT(
+int64_t cst{1} = self.{0}().getValue<int64_t>({ {2} });
+exprs.push_back(getAffineConstantExpr(cst{1}, context));
+)FMT";
+ // Update all symbol bindings mapped to an attribute.
+ for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+ if (arg.usage != LinalgOperandDefUsage::attribute)
+ continue;
+ assert(arg.attributeMap.hasValue());
+ for (auto &en :
+ llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
+ if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+ symbolBindings[symbol.getPosition()] =
+ llvm::formatv(structuredOpAccessAttrFormat, arg.name,
+ symbol.getPosition(), en.index());
+ }
+ }
+ }
+
std::string symbolBindingsStr;
llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
@@ -726,7 +817,7 @@ ArrayAttr {0}::indexing_maps() {
unsigned {0}::getNumRegionArgs() {{ return {1}; }
)FMT";
os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
- opConfig.structuredOp->args.size());
+ numOfArgs);
}
// getLibraryCallName()
@@ -741,6 +832,50 @@ std::string {0}::getLibraryCallName() {{
os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
}
+ // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
+ if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+ return arg.usage == LinalgOperandDefUsage::attribute;
+ })) {
+ std::vector<std::string> attrVerifications;
+ for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+ if (arg.usage != LinalgOperandDefUsage::attribute)
+ continue;
+ assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
+ // Verify index attribute. Paramters:
+ // {0}: Attribute name
+ // {1}: Attribute size
+ static const char attrFmt[] = R"FMT(
+if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
+ if (!attr.getType().getElementType().isInteger(64))
+ return op->emitError(
+ "incorrect element type for indexing map required attribute '{0}'");
+ if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
+ return op->emitError(
+ "incorrect shape for indexing map required attribute '{0}'");
+} else {
+ return op->emitError(
+ "missing indexing map required attribute '{0}'");
+}
+)FMT";
+ attrVerifications.push_back(llvm::formatv(
+ attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults()));
+ }
+
+ // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
+ // {0}: Class name
+ // {1}: Attribute verification
+ static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
+bool {0}::hasDynamicIndexingMaps() {{ return true; }
+LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
+ Operation *op = getOperation();
+ {1}
+ return success();
+}
+)FMT";
+ os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
+ className, llvm::join(attrVerifications, "\n"));
+ }
+
// regionBuilder()
{
// Generates a regionBuilder method. Parameters.
@@ -861,7 +996,6 @@ void {0}::regionBuilder(
return emitError(genContext.getLoc())
<< "mismatched number of assignments vs output arguments";
- int64_t numOfArgs = args.size();
os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
interleaveToString(stmts, "\n "));
}
More information about the Mlir-commits mailing list