[Mlir-commits] [mlir] 4d4cb17 - [mlir][OpDSL] Refactor function handling.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 25 07:08:58 PST 2022
Author: gysit
Date: 2022-02-25T15:05:32Z
New Revision: 4d4cb17da8509156ca690e3d7eaf2e00ab606780
URL: https://github.com/llvm/llvm-project/commit/4d4cb17da8509156ca690e3d7eaf2e00ab606780
DIFF: https://github.com/llvm/llvm-project/commit/4d4cb17da8509156ca690e3d7eaf2e00ab606780.diff
LOG: [mlir][OpDSL] Refactor function handling.
Prepare the OpDSL function handling for the introduction of more function classes. A follow-up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares for that split by adding a function-kind enum, so that different function types are handled by a single class at each level of the stack. For example, there is now one TensorFn and one ScalarFn instead of separate arithmetic-function and type-function classes.
Depends On D119718
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D120108
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 5ebd12103b973..fed9d39ed2f3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -45,29 +45,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- attr_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- attr_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul_unsigned
@@ -109,29 +113,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast_unsigned
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
@@ -183,51 +191,59 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
@@ -280,29 +296,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: accum
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: accum
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: AccumType
operands:
- !ScalarExpression
scalar_arg: lhs
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: AccumType
operands:
- !ScalarExpression
scalar_arg: rhs
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
@@ -345,29 +365,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
@@ -420,51 +444,59 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
@@ -505,29 +537,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
@@ -568,29 +604,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matvec
@@ -632,29 +672,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
@@ -694,29 +738,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d
@@ -757,29 +805,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d
@@ -822,29 +874,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d
@@ -890,29 +946,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d_nwc_wcf
@@ -970,29 +1030,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf
@@ -1064,29 +1128,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf_q
@@ -1171,51 +1239,59 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw_fchw
@@ -1287,29 +1363,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d_ndhwc_dhwcf
@@ -1383,29 +1463,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_1d_nwc_wc
@@ -1462,29 +1546,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwc
@@ -1551,29 +1639,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwc_q
@@ -1651,51 +1743,59 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwcm
@@ -1763,29 +1863,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwcm_q
@@ -1865,51 +1969,59 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_sum
@@ -1975,18 +2087,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max
@@ -2052,18 +2166,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max_unsigned
@@ -2129,18 +2245,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nchw_max
@@ -2206,18 +2324,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min
@@ -2283,18 +2403,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min_unsigned
@@ -2360,18 +2482,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_sum
@@ -2443,18 +2567,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_max
@@ -2526,18 +2652,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_min
@@ -2609,18 +2737,20 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_tensor
@@ -2651,12 +2781,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
@@ -2703,107 +2834,128 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: T
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2147483647 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_index: 1
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_index: 0
- fn_name: cast
- !ScalarExpression
scalar_arg: seed
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- fn_name: cast
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
@@ -2811,15 +2963,15 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: min
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2.3283063999999999E-10 : f64'
- fn_name: cast
- !ScalarExpression
scalar_arg: min
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: soft_plus_2d
@@ -2852,28 +3004,33 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: log
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: exp
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 68c08809e16a3..d26aa077096c7 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -133,55 +133,36 @@ def __repr__(self):
f"[{', '.join([repr(i) for i in self.indices])}]")
-class TensorArithFn(TensorExpression):
- """Application of an arithmetic function."""
+class TensorFn(TensorExpression):
+ """Application of a tensor function."""
- def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
- self.arith_fn = arith_fn
- self.args = tuple(args)
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarArithFn(self.arith_fn.fn_name,
- *[arg.to_scalar_expression() for arg in self.args
- ]).expr()
-
- def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
- super().visit_tensor_exprs(callback)
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
-
-
-class TensorTypeFn(TensorExpression):
- """Application of a type conversion function."""
-
- def __init__(self, type_fn: Optional["TypeFn"],
- operand_def: Optional["OperandDef"], type_var: TypeVar,
- arg: TensorExpression):
- if bool(type_fn) + bool(operand_def) != 1:
- raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
- self.type_fn = type_fn
+ def __init__(self, kind: "FunctionKind", name: Optional[str],
+ operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
+ args: Sequence[TensorExpression]):
+ if bool(name) + bool(operand_def) != 1:
+ raise ValueError("One of 'name', 'operand_def' must be specified")
+ self.name = name
+ self.kind = kind
self.operand_def = operand_def
self.type_var = type_var
- self.arg = arg
+ self.args = args
def to_scalar_expression(self) -> ScalarExpression:
if self.operand_def:
- assert self.operand_def.name, "TypeFnAttr not registered with an op"
- fn_name = self.type_fn.fn_name if self.type_fn else None
+ assert self.operand_def.name, "TensorFn not registered with an op"
attr_name = self.operand_def.name if self.operand_def else None
- return ScalarTypeFn(fn_name, attr_name, self.type_var,
- self.arg.to_scalar_expression()).expr()
+ args = [arg.to_scalar_expression() for arg in self.args]
+ return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
super().visit_tensor_exprs(callback)
- self.arg.visit_tensor_exprs(callback)
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
def __repr__(self):
- return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
- f"({self.type_var}, {self.arg})")
+ name = self.operand_def.name if self.operand_def else self.name
+ return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
+ f"args={', '.join(repr(a) for a in self.args)})")
class TensorReduceFn(TensorExpression):
@@ -194,7 +175,7 @@ def __init__(self, reduce_use: "ReduceFnUse",
args: Sequence[TensorExpression]):
self.reduce_use = reduce_use
self.lhs = None # type: Optional[TensorUse]
- self.args = tuple(args)
+ self.args = args
def to_scalar_expression(self) -> ScalarExpression:
if self.lhs is None:
@@ -202,7 +183,8 @@ def to_scalar_expression(self) -> ScalarExpression:
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
+ return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
+ None, full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
@@ -259,6 +241,11 @@ def __repr__(self):
###############################################################################
+class FunctionKind(Enum):
+ ARITH = 0
+ TYPE = 1
+
+
class TypeFnType:
"""Type conversion function.
@@ -269,8 +256,8 @@ class TypeFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
- return TensorTypeFn(self, None, type_var, arg)
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
def __repr__(self):
return f"{self.fn_name}"
@@ -301,8 +288,8 @@ class ArithFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, *args) -> "TensorArithFn":
- return TensorArithFn(self, args)
+ def __call__(self, *args) -> "TensorFn":
+ return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
def __repr__(self):
return f"{self.fn_name}"
@@ -562,8 +549,8 @@ def __init__(self, default: "TypeFnType"):
self.operand_def = OperandDef(
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
- return TensorTypeFn(None, self.operand_def, type_var, arg)
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
###############################################################################
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index fc8c13bfe6ec9..07050f56fa640 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -270,19 +270,19 @@ def expression(self, expr: ScalarExpression) -> Value:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
- elif expr.arith_fn:
- fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
+ elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
+ fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
operand_values = [
- self.expression(operand) for operand in expr.arith_fn.operands
+ self.expression(operand) for operand in expr.scalar_fn.operands
]
return fn(*operand_values)
- elif expr.type_fn:
- fn_name = expr.type_fn.fn_name
- if expr.type_fn.attr_name:
- fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
+ elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
+ fn_name = expr.scalar_fn.fn_name
+ if expr.scalar_fn.attr_name:
+ fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
fn = self._get_function(f"_typefn_{fn_name}")
- operand = self.expression(expr.type_fn.operand)
- return fn(expr.type_fn.type_var.name, operand)
+ operand_value = self.expression(expr.scalar_fn.operands[0])
+ return fn(expr.scalar_fn.type_var.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def yield_outputs(self, *output_names: str):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index af21b40cf27ab..aa894dc10954f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -15,13 +15,13 @@
from typing import Optional, Sequence
-from .yaml_helper import *
+from .comprehension import *
from .types import *
+from .yaml_helper import *
__all__ = [
"ScalarAssign",
- "ScalarArithFn",
- "ScalarTypeFn",
+ "ScalarFn",
"ScalarArg",
"ScalarConst",
"ScalarIndex",
@@ -29,36 +29,27 @@
]
-class ScalarArithFn:
- """A type of ScalarExpression that applies an arithmetic function."""
-
- def __init__(self, fn_name: str, *operands: "ScalarExpression"):
- self.fn_name = fn_name
- self.operands = operands
-
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(arith_fn=self)
-
- def __repr__(self):
- return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
-
-
-class ScalarTypeFn:
- """A type of ScalarExpression that applies a type conversion function."""
+class ScalarFn:
+ """A type of ScalarExpression that applies a function."""
- def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
- type_var: TypeVar, operand: "ScalarExpression"):
+ def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
+ attr_name: Optional[str], type_var: Optional["TypeVar"],
+ operands: Sequence["ScalarExpression"]):
+ if bool(fn_name) + bool(attr_name) != 1:
+ raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+ self.kind = kind
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
- self.operand = operand
+ self.operands = operands
def expr(self) -> "ScalarExpression":
- return ScalarExpression(type_fn=self)
+ return ScalarExpression(scalar_fn=self)
def __repr__(self):
- return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
- f"({self.type_var}, {self.operand})")
+ name = self.fn_name if self.fn_name else self.attr_name
+ return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+ f"operands=[{', '.join(map(repr, self.operands))}])")
class ScalarArg:
@@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject):
"""An expression on scalar values.
Can be one of:
- - ScalarArithFn
- - ScalarTypeFn
+ - ScalarFn
- ScalarArg
- ScalarConst
- ScalarIndex
- - ScalarSymbolicCast
"""
yaml_tag = "!ScalarExpression"
def __init__(self,
- arith_fn: Optional[ScalarArithFn] = None,
- type_fn: Optional[ScalarTypeFn] = None,
+ scalar_fn: Optional[ScalarFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
- if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
+ if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
bool(scalar_index)) != 1:
- raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
- "'scalar_const', 'scalar_index', must be specified")
- self.arith_fn = arith_fn
- self.type_fn = type_fn
+ raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+ "'scalar_index' must be specified")
+ self.scalar_fn = scalar_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
- if self.arith_fn:
- return dict(
- arith_fn=dict(
- fn_name=self.arith_fn.fn_name,
- operands=list(self.arith_fn.operands),
- ))
- if self.type_fn:
- # Note that even though operands must be arity 1, we write it the
- # same way as for apply because it allows handling code to be more
- # generic vs having a special form.
- type_fn_dict = dict(
- type_var=self.type_fn.type_var.name,
- operands=[self.type_fn.operand],
- )
- if self.type_fn.fn_name:
- type_fn_dict["fn_name"] = self.type_fn.fn_name
- if self.type_fn.attr_name:
- type_fn_dict["attr_name"] = self.type_fn.attr_name
- return dict(type_fn=type_fn_dict)
+ if self.scalar_fn:
+ scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+ if self.scalar_fn.fn_name:
+ scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+ if self.scalar_fn.attr_name:
+ scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+ if self.scalar_fn.type_var:
+ scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+ scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+ return dict(scalar_fn=scalar_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index f4019e85d5d79..660637e669a67 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -39,23 +39,26 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_const: '42 : i64'
- attr_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_index: 1
- attr_name: cast
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
@@ -236,7 +239,8 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
fn_name: cast
type_var: U
operands:
diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index 926a05e579158..5b87216ca4372 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -9,22 +9,24 @@
# CHECK: -
# CHECK: arg: C
# CHECK: value:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
# CHECK: fn_name: add
# CHECK: operands:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
# CHECK: fn_name: mul
# CHECK: operands:
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: attr_name: cast
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: A
+# CHECK: scalar_fn:
+# CHECK: kind: type
# CHECK: attr_name: cast
-# CHECK: type_fn:
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: B
-# CHECK: attr_name: cast
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
@@ -39,21 +41,28 @@ def matmul(
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: sub
# CHECK: operands:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '3.1415926535897931 : f64'
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '42 : i64'
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@@ -70,7 +79,8 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 7c850e6f4672b..d1fc9ac944942 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -90,28 +90,23 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
-struct ScalarArithFn {
- std::string fnName;
- // NOTE: Must be pure heap allocated container (not SmallVector)
- // due to recursive data type.
- std::vector<ScalarExpression> operands;
-};
+enum class ScalarFnKind { Arith, Type };
-struct ScalarTypeFn {
- std::string typeVar;
+struct ScalarFn {
+ ScalarFnKind kind;
+ Optional<std::string> fnName;
+ Optional<std::string> attrName;
+ Optional<std::string> typeVar;
- // NOTE: This must be of arity 1, but to break the self-referential cycle,
- // we use a heap allocated vector.
+ // NOTE: Must be pure heap allocated container (not SmallVector) due to
+ // recursive data type. Arity depends on `kind` (type functions take one).
std::vector<ScalarExpression> operands;
- Optional<std::string> fnName;
- Optional<std::string> attrName;
};
struct ScalarExpression {
Optional<std::string> arg;
Optional<std::string> constant;
Optional<int64_t> index;
- Optional<ScalarArithFn> arithFn;
- Optional<ScalarTypeFn> typeFn;
+ Optional<ScalarFn> scalarFn;
};
struct ScalarAssign {
@@ -265,16 +260,23 @@ struct MappingTraits<ScalarAssign> {
/// - `scalar_arg`: An operation argument.
/// - `scalar_const`: A constant definition.
/// - `scalar_index`: An iteration index.
-/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
-/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
+/// - `scalar_fn`: A named function (see `ScalarFn`).
template <>
struct MappingTraits<ScalarExpression> {
static void mapping(IO &io, ScalarExpression &info) {
io.mapOptional("scalar_arg", info.arg);
io.mapOptional("scalar_const", info.constant);
io.mapOptional("scalar_index", info.index);
- io.mapOptional("arith_fn", info.arithFn);
- io.mapOptional("type_fn", info.typeFn);
+ io.mapOptional("scalar_fn", info.scalarFn);
+ }
+};
+
+/// Scalar function kind enum.
+template <>
+struct ScalarEnumerationTraits<ScalarFnKind> {
+ static void enumeration(IO &io, ScalarFnKind &value) {
+ io.enumCase(value, "arith", ScalarFnKind::Arith);
+ io.enumCase(value, "type", ScalarFnKind::Type);
}
};
@@ -284,20 +286,13 @@ struct MappingTraits<ScalarExpression> {
/// - `add(lhs, rhs)`
/// - `mul(lhs, rhs)`
template <>
-struct MappingTraits<ScalarArithFn> {
- static void mapping(IO &io, ScalarArithFn &info) {
- io.mapRequired("fn_name", info.fnName);
- io.mapRequired("operands", info.operands);
- }
-};
-
-template <>
-struct MappingTraits<ScalarTypeFn> {
- static void mapping(IO &io, ScalarTypeFn &info) {
- io.mapRequired("type_var", info.typeVar);
- io.mapRequired("operands", info.operands);
+struct MappingTraits<ScalarFn> {
+ static void mapping(IO &io, ScalarFn &info) {
+ io.mapRequired("kind", info.kind);
io.mapOptional("fn_name", info.fnName);
io.mapOptional("attr_name", info.attrName);
+ io.mapOptional("type_var", info.typeVar);
+ io.mapRequired("operands", info.operands);
}
};
@@ -1060,11 +1055,12 @@ if ({0}Iter != attrs.end()) {{
cppIdent, *expression.index));
return cppIdent;
}
- if (expression.arithFn) {
+ if (expression.scalarFn &&
+ expression.scalarFn->kind == ScalarFnKind::Arith) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
- for (ScalarExpression &operand : expression.arithFn->operands) {
+ for (ScalarExpression &operand : expression.scalarFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
return None;
@@ -1073,28 +1069,30 @@ if ({0}Iter != attrs.end()) {{
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
- expression.arithFn->fnName,
+ expression.scalarFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
- if (expression.typeFn) {
+ if (expression.scalarFn &&
+ expression.scalarFn->kind == ScalarFnKind::Type) {
// Symbolic cast.
// Operands must be arity 1.
- if (expression.typeFn->operands.size() != 1) {
+ if (expression.scalarFn->operands.size() != 1) {
emitError(genContext.getLoc())
<< "type conversion operand arity must be 1";
return None;
}
Optional<std::string> operandCppValue =
- generateExpression(expression.typeFn->operands[0]);
+ generateExpression(expression.scalarFn->operands[0]);
if (!operandCppValue)
return None;
+ assert(expression.scalarFn->typeVar.hasValue());
Optional<std::string> typeCppValue =
- findTypeValue(expression.typeFn->typeVar, args);
+ findTypeValue(expression.scalarFn->typeVar.getValue(), args);
if (!typeCppValue) {
emitError(genContext.getLoc())
- << "type variable " << expression.typeFn->typeVar
+ << "type variable " << expression.scalarFn->typeVar.getValue()
<< ", used in a type conversion, must map to a predefined or "
<< "an argument type but it does not";
return None;
@@ -1102,17 +1100,17 @@ if ({0}Iter != attrs.end()) {{
// Use the function name or the attribute to build the type function.
std::string typeFunc = llvm::formatv(
- "TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
- if (expression.typeFn->attrName) {
+ "TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
+ if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
- arg.name == expression.typeFn->attrName.getValue();
+ arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing type function attribute "
- << expression.typeFn->attrName.getValue();
+ << expression.scalarFn->attrName.getValue();
}
- typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
+ typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv(
More information about the Mlir-commits
mailing list