[Mlir-commits] [mlir] 9a2769d - [mlir][Python][linalg] Support OpDSL extensions in C++.
Tobias Gysi
llvmlistbot@llvm.org
Wed May 19 06:37:33 PDT 2021
Author: Tobias Gysi
Date: 2021-05-19T13:36:56Z
New Revision: 9a2769db801d4c45edb939223abfb3e1a639732f
URL: https://github.com/llvm/llvm-project/commit/9a2769db801d4c45edb939223abfb3e1a639732f
DIFF: https://github.com/llvm/llvm-project/commit/9a2769db801d4c45edb939223abfb3e1a639732f.diff
LOG: [mlir][Python][linalg] Support OpDSL extensions in C++.
The patch extends the YAML code generation to support the following new OpDSL constructs, sketched in the example below:
- captures
- constants
- iteration index accesses
- predefined types
These changes have been introduced by revision
https://reviews.llvm.org/D101364.
Differential Revision: https://reviews.llvm.org/D102075
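For orientation, a minimal sketch of how the four constructs read in OpDSL. The op name demo_fill and its body are illustrative only (modeled on the fill_rng_2d op added below), not part of this patch:

from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def demo_fill(
    O=TensorDef(T, S.M, S.N, output=True),
    seed=CaptureDef(I32)):               # capture: an SSA value bound at emission time
  offset = cast(F64, const(42))          # constant; the predefined type F64 picks the width
  rand = cast(I32, index(D.m)) + seed    # iteration index access
  O[D.m, D.n] = cast(T, cast(F64, rand) + offset)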
Added:
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/arguments.py
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/CMakeLists.txt
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/lit.cfg.py
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 085eaed8a8d29..7e8d560e9bca6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,7 +1,7 @@
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
- cpp_op_name: MatmulOp
+ cpp_class_name: MatmulOp
doc: |-
Performs a matrix multiplication of two 2D inputs.
@@ -63,7 +63,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
- cpp_op_name: BatchMatmulOp
+ cpp_class_name: BatchMatmulOp
doc: |-
Performs a batched matrix multiplication of two 3D inputs.
@@ -126,7 +126,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
- cpp_op_name: MatvecOp
+ cpp_class_name: MatvecOp
doc: |-
Performs a matrix-vector multiplication.
@@ -187,7 +187,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
- cpp_op_name: VecmatOp
+ cpp_class_name: VecmatOp
doc: |-
Performs a vector-matrix multiplication.
@@ -248,7 +248,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
- cpp_op_name: DotOp
+ cpp_class_name: DotOp
doc: |-
Performs a dot product of two vectors to a scalar result.
@@ -305,4 +305,160 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: fill_rng_2d
+ cpp_class_name: FillRng2DOp
+ doc: |-
+ Fills the output tensor with pseudo random numbers.
+
+ The operation generates pseudo random numbers using a linear congruential
+ generator. It provides no guarantees regarding the distribution of the
+ generated random numbers. Instead of generating the random numbers
+ sequentially, it instantiates one random number generator per data element
+ and runs them in parallel. The seed operand and the indices of the data
+ element seed the random number generation. The min and max operands limit
+ the range of the generated random numbers.
+ Note: The captures are hard-coded till there is capture support on the C++
+ side.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: O
+ usage: output
+ shape: affine_map<()[s0, s1] -> (s0, s1)>
+ element_type_var: T
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ symbolic_cast:
+ type_var: T
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_const: '2147483647 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_index: 1
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_index: 0
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_const: '42 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_const: '1103515245 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_const: '12345 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_const: '1103515245 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: I32
+ operands:
+ - !ScalarExpression
+ scalar_const: '12345 : i64'
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_const: '1000 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_const: '-1000 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_const: '2.3283063999999999E-10 : f64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: F64
+ operands:
+ - !ScalarExpression
+ scalar_const: '-1000 : i64'
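Read bottom-up, the nested ScalarExpression tree above encodes one linear congruential update per element. A plain-Python sketch of the same arithmetic (i32 wrap-around omitted for clarity; the constants match the scalar_const entries above):

def fill_rng_2d_reference(i, j):
  # Per-element LCG: seed 42, multiplier 1103515245, increment 12345.
  rand1 = (i + 42) * 1103515245 + 12345
  rand2 = (j + rand1) * 1103515245 + 12345
  min_v, max_v = -1000.0, 1000.0
  scaling = (max_v - min_v) * 2.3283064e-10      # (max - min) / 2^32
  return (2147483647.0 + float(rand2)) * scaling + min_v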
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1801f277069e3..ee2136a8f1105 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,14 +220,15 @@ namespace {
class RegionBuilderHelper {
public:
- RegionBuilderHelper(Block &block) : block(block) {}
+ RegionBuilderHelper(MLIRContext *context, Block &block)
+ : context(context), block(block) {}
// Generates operations to cast the given operand to a specified type.
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand) {
- OpBuilder builder = getBuilder(operand);
+ OpBuilder builder = getBuilder();
auto loc = operand.getLoc();
if (operand.getType() == toType)
@@ -236,11 +237,14 @@ class RegionBuilderHelper {
// If operand is floating point, cast directly to the int type.
if (operand.getType().isa<FloatType>())
return builder.create<FPToSIOp>(loc, toType, operand);
+ // Cast index operands directly to the int type.
+ if (operand.getType().isIndex())
+ return builder.create<IndexCastOp>(loc, toType, operand);
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
// Either sign extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth())
return builder.create<SignExtendIOp>(loc, toType, operand);
- else if (toIntType.getWidth() < fromIntType.getWidth())
+ if (toIntType.getWidth() < fromIntType.getWidth())
return builder.create<TruncateIOp>(loc, toType, operand);
}
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
@@ -251,7 +255,7 @@ class RegionBuilderHelper {
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
if (toFloatType.getWidth() > fromFloatType.getWidth())
return builder.create<FPExtOp>(loc, toFloatType, operand);
- else if (toFloatType.getWidth() < fromFloatType.getWidth())
+ if (toFloatType.getWidth() < fromFloatType.getWidth())
return builder.create<FPTruncOp>(loc, toFloatType, operand);
}
}
@@ -262,19 +266,28 @@ class RegionBuilderHelper {
}
Value applyfn__add(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder(lhs);
+ OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
- else if (isInteger(lhs))
+ if (isInteger(lhs))
return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
+ Value applyfn__sub(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder();
+ if (isFloatingPoint(lhs))
+ return builder.create<SubFOp>(lhs.getLoc(), lhs, rhs);
+ if (isInteger(lhs))
+ return builder.create<SubIOp>(lhs.getLoc(), lhs, rhs);
+ llvm_unreachable("unsupported non numeric type");
+ }
+
Value applyfn__mul(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder(lhs);
+ OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
- else if (isInteger(lhs))
+ if (isInteger(lhs))
return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
@@ -284,18 +297,39 @@ class RegionBuilderHelper {
if (values.empty())
return;
Value first = values.front();
- OpBuilder builder = getBuilder(first);
+ OpBuilder builder = getBuilder();
builder.create<YieldOp>(first.getLoc(), values);
}
+ Value constant(std::string value) {
+ OpBuilder builder = getBuilder();
+ Location loc = builder.getUnknownLoc();
+ Attribute valueAttr = parseAttribute(value, builder.getContext());
+ return builder.create<ConstantOp>(loc, valueAttr.getType(), valueAttr);
+ }
+
+ Value index(int64_t dim) {
+ OpBuilder builder = getBuilder();
+ return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+ }
+
+ Type getIntegerType(unsigned width) {
+ return IntegerType::get(context, width);
+ }
+
+ Type getFloat32Type() { return Float32Type::get(context); }
+
+ Type getFloat64Type() { return Float64Type::get(context); }
+
private:
+ MLIRContext *context;
Block &block;
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
- OpBuilder getBuilder(Value value) {
- OpBuilder builder(value.getContext());
+ OpBuilder getBuilder() {
+ OpBuilder builder(context);
builder.setInsertionPointToEnd(&block);
return builder;
}
@@ -1476,7 +1510,6 @@ computeReshapeCollapsedType(MemRefType type,
MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
}
-
template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
unsigned pos = 0;
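The new index handling slots into the existing cast precedence. A small self-contained sketch of the decision order in RegionBuilderHelper::cast for an IntegerType target (the string labels are illustrative, not the C++ API):

def cast_kind(from_kind, from_width, to_width):
  if from_kind == "float":
    return "fptosi"          # floating point casts directly to the int type
  if from_kind == "index":
    return "index_cast"      # new in this patch
  if to_width > from_width:
    return "sign_extend"     # widen integers
  if to_width < from_width:
    return "truncate"        # narrow integers
  return "none"              # same width: operand returned as-is

assert cast_kind("index", 64, 32) == "index_cast"
assert cast_kind("int", 64, 32) == "truncate"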
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 9b93d33b32467..2ac0641a309f7 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -8,7 +8,7 @@
represent actual op definitions (i.e. YAML).
"""
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from mlir import ir as _ir
@@ -36,8 +36,8 @@ def _get_all_dim_defs(self) -> Set[DimDef]:
results = set()
def visit_dim_def(dim_def):
- if isinstance(dim_def, DimDef):
- results.add(dim_def)
+ if isinstance(dim_def, DimDef):
+ results.add(dim_def)
def visit_affine_exprs(expr):
if isinstance(expr, TensorUse):
@@ -52,23 +52,29 @@ def visit_affine_exprs(expr):
def collect_uses(self, uses: Set["TensorUse"]):
"""Collects all TensorUses reachable through this expression."""
+
def visit_tensor_use(expr):
if isinstance(expr, TensorUse):
uses.add(expr)
+
self.visit_tensor_exprs(visit_tensor_use)
def collect_indices(self, indices: Set["index"]):
"""Collects all index accesses reachable through this expression."""
+
def visit_index(expr):
if isinstance(expr, index):
indices.add(expr)
+
self.visit_tensor_exprs(visit_index)
def collect_captures(self, captures: Set["CaptureDef"]):
"""Collects all CaptureDefs reachable through this expression."""
+
def visit_capture_def(expr):
if isinstance(expr, CaptureDef):
captures.add(expr)
+
self.visit_tensor_exprs(visit_capture_def)
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
@@ -159,8 +165,8 @@ def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"):
def __getitem__(self, dims) -> TensorUse:
assert self.owner, "TensorDef is not attached to an op"
- state = AffineBuildState(global_state=self.owner._affine_state,
- allow_new_symbols=False)
+ state = AffineBuildState(
+ global_state=self.owner._affine_state, allow_new_symbols=False)
if not isinstance(dims, tuple):
dims = (dims,) # Handle single subscript case.
# Special case: (None) is a 0d-scalar use.
@@ -196,6 +202,7 @@ def __repr__(self):
return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
f"shape={self.shape})")
+
class CaptureDef(TensorExpression):
"""Defines an SSA value captured by the operation.
@@ -226,6 +233,7 @@ def to_scalar_expression(self) -> ScalarExpression:
def __repr__(self):
return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
+
class Comprehension:
"""Represents a single comprehension."""
@@ -334,23 +342,27 @@ def visit_tensor_exprs(self, callback):
def __repr__(self):
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
+
class const(TensorExpression):
"""Returns the given constant floating point or integer value."""
- def __init__(self, type_var: TypeVar, value: Any):
- if not isinstance(type_var, TypeVar):
- raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}")
- if not (isinstance(value, float) or isinstance(value, int)):
- raise ValueError(f"const requires int or float. Got: {type(value)}")
- self.type_var = type_var
- self.value = value
+ def __init__(self, value: Any):
+ with _ir.Context():
+ if isinstance(value, float):
+ self.value = str(_ir.FloatAttr.get_f64(float(value)))
+ elif isinstance(value, int):
+ self.value = str(
+ _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
+ else:
+ raise ValueError(f"const requires int or float. Got: {type(value)}")
def to_scalar_expression(self) -> ScalarExpression:
- return ScalarConst(self.type_var, self.value).expr()
+ return ScalarConst(self.value).expr()
def __repr__(self):
return f"const({self.type_var}, {self.value})"
+
class index(TensorExpression):
"""Returns the iteration index for a given dimension name.
@@ -358,7 +370,7 @@ class index(TensorExpression):
domain of the operation.
"""
- def __init__(self, dim : DimDef):
+ def __init__(self, dim: DimDef):
self.dim_def = dim
self.dim = -1
@@ -433,7 +445,8 @@ class OpMetadataDef(YAMLObject):
"""Metadata about the op (generally not behavior impacting)."""
yaml_tag = "!LinalgOpMetadata"
- def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
+ def __init__(self, name: str, cpp_class_name: Optional[str],
+ doc: Optional[str]):
self.name = name
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
self.doc = doc
@@ -457,7 +470,8 @@ def __init__(self,
name: str,
cpp_class_name: Optional[str] = None,
doc: Optional[str] = None):
- self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
+ self.metadata = OpMetadataDef(
+ name=name, cpp_class_name=cpp_class_name, doc=doc)
self.registered_tensors = dict() # type: Dict[str, TensorDef]
self.registered_captures = dict() # type: Dict[str, CaptureDef]
self.comprehensions = list() # type: List[Comprehension]
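With this change, const infers the attribute from the Python value (ints become i64 attributes, floats f64 attributes) and the element type is selected by an explicit cast instead of a TypeVar argument. A minimal sketch of the new usage:

from mlir.dialects.linalg.opdsl.lang import *

seed = cast(I32, const(42))      # was: const(I32, 42); serializes as scalar_const: '42 : i64'
limit = cast(F64, const(1e+3))   # serializes as an f64 attribute string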
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index a67d18cc37adb..9026e2030e1f2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -11,7 +11,7 @@
to helpers on the comprehension objects themselves.
"""
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
from mlir import ir as _ir
@@ -70,6 +70,7 @@ def to_yaml_custom_dict(self):
def __repr__(self):
return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
+
class CaptureDefConfig(YAMLObject):
"""Wrapper around a CaptureDef."""
yaml_tag = "LinalgCaptureDef"
@@ -113,8 +114,7 @@ def to_yaml_custom_dict(self):
class LinalgStructuredOpConfig(YAMLObject):
- """Configuration for metadata sufficient to construct a linalg single
- contraction named op."""
+ """Configuration for metadata sufficient to construct a linalg named op."""
yaml_tag = "!LinalgStructuredOpConfig"
@@ -156,8 +156,8 @@ def __init__(self,
for cuse in self.uses.values():
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
for cdef in self.tensor_args.values():
- cdef.shape_map = self._normalize_affine_map(cdef.shape_map,
- with_dims=False)
+ cdef.shape_map = self._normalize_affine_map(
+ cdef.shape_map, with_dims=False)
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
@@ -198,8 +198,8 @@ def __init__(self,
for index in collected_indices:
if index.dim_def.dimname not in self.affine_state.all_dims:
raise ValueError(
- f"The dimension {index.dim.dimname} is not part of the iteration "
- f"domain {self.affine_state.all_dims}")
+ f"The dimension {index.dim.dimname} is not part of the iteration "
+ f"domain {self.affine_state.all_dims}")
index.resolve_dimension_name(self.affine_state)
# Generate the scalar assignments (used to build a body).
@@ -210,18 +210,21 @@ def __init__(self,
@property
def ordered_tensor_args(self) -> Sequence[TensorDefConfig]:
- return sorted(self.tensor_args.values(),
- key=lambda tdc: tdc.tensor_def.registered_index)
+ return sorted(
+ self.tensor_args.values(),
+ key=lambda tdc: tdc.tensor_def.registered_index)
@property
def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]:
- return sorted(self.uses.values(),
- key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
+ return sorted(
+ self.uses.values(),
+ key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
@property
def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
- return sorted(self.capture_args.values(),
- key=lambda cdc: cdc.capture_def.registered_index)
+ return sorted(
+ self.capture_args.values(),
+ key=lambda cdc: cdc.capture_def.registered_index)
@property
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
@@ -252,15 +255,14 @@ def add_tensor_arg(self, tensor_def: TensorDef):
if tensor_def in self.tensor_args:
return
with self.context:
- local_state = AffineBuildState(global_state=self.affine_state,
- allow_new_dims=False)
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_dims=False)
exprs = []
for expr in tensor_def.shape:
exprs.append(expr.build(state=local_state))
assert local_state.local_dim_count == 0
- indexing_map = _ir.AffineMap.get(dim_count=0,
- symbol_count=local_state.symbol_count,
- exprs=exprs)
+ indexing_map = _ir.AffineMap.get(
+ dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
def_config = TensorDefConfig(tensor_def, indexing_map)
self.tensor_args[tensor_def] = def_config
@@ -269,15 +271,16 @@ def add_use(self, tensor_use: TensorUse):
if tensor_use in self.uses:
return
with self.context:
- local_state = AffineBuildState(global_state=self.affine_state,
- allow_new_symbols=False)
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False)
exprs = []
for expr in tensor_use.indices:
exprs.append(expr.build(state=local_state))
assert local_state.local_symbol_count == 0
- indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count,
- symbol_count=local_state.symbol_count,
- exprs=exprs)
+ indexing_map = _ir.AffineMap.get(
+ dim_count=local_state.dim_count,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs)
use_config = TensorUseConfig(tensor_use, indexing_map)
self.uses[tensor_use] = use_config
@@ -299,16 +302,15 @@ def _normalize_affine_map(self,
exprs=list(affine_map.results))
def to_yaml_custom_dict(self):
- self_dict = dict(
- args=self.ordered_tensor_args,
- captures=self.ordered_capture_args,
- # TODO: Refactor the hierarchy internally when supporting more
- # than static (preserving this serialized form).
- indexing_maps=LinalgIndexingMapsConfig(
- static_indexing_maps=self.indexing_maps),
- iterator_types=self.iterator_types,
- assignments=self.assignments,
- )
+ self_dict = dict(args=self.ordered_tensor_args)
+ if self.ordered_capture_args:
+ self_dict["captures"] = self.ordered_capture_args
+ # TODO: Refactor the hierarchy internally when supporting more
+ # than static (preserving this serialized form).
+ self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
+ static_indexing_maps=self.indexing_maps)
+ self_dict["iterator_types"] = self.iterator_types
+ self_dict["assignments"] = self.assignments
return self_dict
def __repr__(self):
@@ -359,9 +361,10 @@ def from_linalg_op_def(
assert len(
tc_op_def.comprehensions) == 1, "Only one comprehension supported"
return [
- LinalgOpConfig(tc_op_def.metadata,
- structured_op=LinalgStructuredOpConfig(
- tc_op_def.comprehensions[0], context)),
+ LinalgOpConfig(
+ tc_op_def.metadata,
+ structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0],
+ context)),
]
def __repr__(self):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 85c77d52fe5f8..5538a9e42e102 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Any, Dict, Sequence
+from typing import Dict, Sequence
from mlir.ir import *
from mlir.dialects import linalg
@@ -19,16 +19,17 @@
"emit_named_structured_op",
]
-def isa(cls : Type, ty : Type):
+
+def isa(cls: Type, ty: Type):
try:
cls(ty)
return True
except ValueError:
return False
+
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
- *ins: Value,
- outs: Sequence[Value],
+ *ins: Value, outs: Sequence[Value],
captures: Sequence[Value]):
all_arg_defs = op_config.ordered_tensor_args
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
@@ -82,11 +83,13 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
# Emit the generic op.
# TODO: Support emission of pure memref form.
- indexing_maps_attr = ArrayAttr.get(
- [AffineMapAttr.get(am)
- # TODO: linalg verification does not currently allow symbols.
- # Compress them for now.
- for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
+ indexing_maps_attr = ArrayAttr.get([
+ AffineMapAttr.get(am)
+ # TODO: linalg verification does not currently allow symbols.
+ # Compress them for now.
+ for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
+ Context.current)
+ ])
iterator_types_attr = ArrayAttr.get(
[StringAttr.get(s) for s in op_config.iterator_types])
@@ -144,7 +147,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
# If we get here, there must exist a builtin class `op_class_name`.
ctx = Context.current
- fully_qualified_name = 'linalg.' + op_name
+ fully_qualified_name = "linalg." + op_name
if (not ctx.is_registered_operation(fully_qualified_name) or
not op_class_name in linalg.__dict__.keys()):
raise NotImplementedError(
@@ -156,7 +159,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
# Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
# attribute that the non-yaml path does not. The non-yaml path hardcodes the
# indexing_maps in C++ directly.
- named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
+ named_op.operation.attributes[
+ "linalg.memoized_indexing_maps"] = indexing_maps_attr
# iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
if len(result_types) == 1:
@@ -168,8 +172,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
class _BodyBuilder:
"""Constructs a structured op body by evaluating assignments."""
- def __init__(self,
- type_mapping: Dict[str, Type],
+ def __init__(self, type_mapping: Dict[str, Type],
block_arg_mapping: Dict[str, Value],
capture_arg_mapping: Dict[str, Value]):
self.type_mapping = type_mapping
@@ -195,12 +198,16 @@ def expression(self, expr: ScalarExpression) -> Value:
try:
return self.capture_arg_mapping[expr.scalar_capture.capture]
except KeyError:
- raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for "
- f"this structured op.")
+ raise ValueError(
+ f"Capture {expr.scalar_capture.capture} is not bound for "
+ f"this structured op.")
elif expr.scalar_const:
- return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value)
+ value_attr = Attribute.parse(expr.scalar_const.value)
+ return std.ConstantOp(value_attr.type, value_attr).result
elif expr.scalar_index:
- return self.index(expr.scalar_index.dim)
+ dim_attr = IntegerAttr.get(
+ IntegerType.get_signless(64), expr.scalar_index.dim)
+ return linalg.IndexOp(IndexType.get(), dim_attr).result
elif expr.scalar_apply:
try:
fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -217,25 +224,6 @@ def expression(self, expr: ScalarExpression) -> Value:
return self.cast(expr.symbolic_cast.to_type.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
- def constant(self, type_var_name: str, value: Any) -> Value:
- try:
- type = self.type_mapping[type_var_name]
- except KeyError:
- raise ValueError(f"Unbound type variable '{type_var_name}' ("
- f"expected one of {self.type_mappings.keys()}")
- try:
- if(_is_floating_point_type(type)):
- return std.ConstantOp(type, FloatAttr.get(type, float(value))).result
- elif(_is_integer_type(type)):
- return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result
- except ValueError:
- raise ValueError(f"Unable to cast value {value} to type {type}")
- raise NotImplementedError(f"Unimplemented constant type {type}")
-
- def index(self, dim: int) -> Value:
- dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim)
- return linalg.IndexOp(IndexType.get(), dim_attr).result
-
def cast(self, type_var_name: str, operand: Value) -> Value:
try:
to_type = self.type_mapping[type_var_name]
@@ -248,6 +236,7 @@ def cast(self, type_var_name: str, operand: Value) -> Value:
return self._cast_to_integer(to_type, operand)
elif _is_floating_point_type(to_type):
return self._cast_to_floating_point(to_type, operand)
+
def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
@@ -345,6 +334,7 @@ def _get_tensor_def_names(
*tensor_def_configs: TensorDefConfig) -> Sequence[str]:
return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
+
def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
if name in type_mapping:
if type_mapping[name] != type:
@@ -352,18 +342,22 @@ def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
f"{type_mapping[name]} by type {type}")
type_mapping[name] = type
+
def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
return (F64Type.isinstance(t) or F32Type.isinstance(t) or
F16Type.isinstance(t) or BF16Type.isinstance(t))
+
def _is_integer_type(t: Type) -> bool:
return IntegerType.isinstance(t)
+
def _is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)
+
def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index bb1938d71f07b..2cc426b6211a0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -13,7 +13,7 @@
can be easily consumed from the C++ side, not necessarily for ergonomics.
"""
-from typing import Any, Optional, Sequence
+from typing import Optional, Sequence
from .yaml_helper import *
from .types import *
@@ -56,6 +56,7 @@ def expr(self) -> "ScalarExpression":
def __repr__(self):
return f"(ScalarArg({self.arg})"
+
class ScalarCapture:
"""A type of ScalarExpression that references a named capture."""
@@ -68,23 +69,24 @@ def expr(self) -> "ScalarExpression":
def __repr__(self):
return f"(ScalarCapture({self.capture})"
+
class ScalarConst:
"""A type of ScalarExpression representing a constant."""
- def __init__(self, type_var: TypeVar, value: Any):
- self.type_var = type_var
+ def __init__(self, value: str):
self.value = value
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_const=self)
def __repr__(self):
- return f"(ScalarConst({self.type_var}, {self.value})"
+ return f"(ScalarConst({self.value})"
+
class ScalarIndex:
"""A type of ScalarExpression accessing an iteration index."""
- def __init__(self, dim : int):
+ def __init__(self, dim: int):
self.dim = dim
def expr(self) -> "ScalarExpression":
@@ -93,9 +95,9 @@ def expr(self) -> "ScalarExpression":
def __repr__(self):
return f"(ScalarIndex({self.dim})"
+
class ScalarSymbolicCast:
- """A type of ScalarExpression that symbolically casts an operand to a TypeVar.
- """
+ """A type of ScalarExpression that symbolically casts an operand to a TypeVar."""
def __init__(self, to_type: TypeVar, operand: "ScalarExpression"):
self.to_type = to_type
@@ -142,25 +144,27 @@ def __init__(self,
def to_yaml_custom_dict(self):
if self.scalar_apply:
- return dict(scalar_apply=dict(
- fn_name=self.scalar_apply.fn_name,
- operands=list(self.scalar_apply.operands),
- ))
+ return dict(
+ scalar_apply=dict(
+ fn_name=self.scalar_apply.fn_name,
+ operands=list(self.scalar_apply.operands),
+ ))
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_capture:
return dict(scalar_capture=self.scalar_capture.capture)
elif self.scalar_const:
- return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name,
- attributes=[self.scalar_const.value]))
+ return dict(scalar_const=self.scalar_const.value)
elif self.scalar_index:
return dict(scalar_index=self.scalar_index.dim)
elif self.symbolic_cast:
# Note that even though operands must be arity 1, we write it the
# same way as for apply because it allows handling code to be more
# generic vs having a special form.
- return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name,
- operands=[self.symbolic_cast.operand]))
+ return dict(
+ symbolic_cast=dict(
+ type_var=self.symbolic_cast.to_type.name,
+ operands=[self.symbolic_cast.operand]))
else:
raise ValueError(f"Unexpected ScalarExpression type: {self}")
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b52a0e2d65e1c..ad79963450cee 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -7,9 +7,10 @@
@linalg_structured_op
-def matmul(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
"""Performs a matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -20,9 +21,10 @@ def matmul(A=TensorDef(T1, S.M, S.K),
@linalg_structured_op
-def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
+def batch_matmul(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True)):
"""Performs a batched matrix multiplication of two 3D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -33,9 +35,10 @@ def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
@linalg_structured_op
-def matvec(A=TensorDef(T1, S.M, S.N),
- y=TensorDef(T2, S.N),
- x=TensorDef(U, S.M, output=True)):
+def matvec(
+ A=TensorDef(T1, S.M, S.N),
+ y=TensorDef(T2, S.N),
+ x=TensorDef(U, S.M, output=True)):
"""Performs a matrix-vector multiplication.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -46,9 +49,10 @@ def matvec(A=TensorDef(T1, S.M, S.N),
@linalg_structured_op
-def vecmat(y=TensorDef(T1, S.M),
- A=TensorDef(T2, S.M, S.N),
- x=TensorDef(U, S.N, output=True)):
+def vecmat(
+ y=TensorDef(T1, S.M),
+ A=TensorDef(T2, S.M, S.N),
+ x=TensorDef(U, S.N, output=True)):
"""Performs a vector-matrix multiplication.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -59,8 +63,8 @@ def vecmat(y=TensorDef(T1, S.M),
@linalg_structured_op
-def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
- output=True)):
+def dot(
+ A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
"""Performs a dot product of two vectors to a scalar result.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -68,3 +72,31 @@ def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
"""
implements(ContractionOpInterface)
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
+
+@linalg_structured_op
+def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)):
+ """Fills the output tensor with pseudo random numbers.
+
+ The operation generates pseudo random numbers using a linear congruential
+ generator. It provides no guarantees regarding the distribution of the
+ generated random numbers. Instead of generating the random numbers
+ sequentially, it instantiates one random number generator per data element
+ and runs them in parallel. The seed operand and the indices of the data
+ element seed the random number generation. The min and max operands limit
+ the range of the generated random numbers.
+
+ Note: The captures are hard-coded till there is capture support on the C++
+ side.
+ """
+ min = cast(F64, const(-1000))
+ max = cast(F64, const(+1000))
+ seed = cast(I32, const(42))
+ multiplier = cast(I32, const(1103515245))
+ increment = cast(I32, const(12345))
+ rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+ rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
+ inv_range = cast(F64, const(2.3283064e-10))
+ offset = cast(F64, const(2147483647))
+ scaling = (max - min) * inv_range
+ O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
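A sketch of emitting the new op from Python, mirroring the fill_on_buffers function in the opsrun.py test below (the captures are hard-coded in the op definition for now, so only the output is passed):

from mlir.ir import *
from mlir.dialects import builtin, linalg

with Context(), Location.unknown():
  module = Module.create()
  i32 = IntegerType.get_signless(32)
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
    def fill_on_buffers(out):
      linalg.fill_rng_2d(outs=[out])

  print(module)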
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 614a990bace05..79b8d8950b4e2 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -63,6 +63,7 @@ set(MLIR_TEST_DEPENDS
mlir-capi-sparse-tensor-test
mlir-cpu-runner
mlir-linalg-ods-gen
+ mlir-linalg-ods-yaml-gen
mlir-lsp-server
mlir-opt
mlir-reduce
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 251dfe609606c..4a431bd1a54a6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -28,6 +28,54 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
+// -----
+
+func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_fill_rng_2d_f32
+// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>)
+// CHECK-DAG: %[[MIN:.+]] = constant -1000 : i64
+// CHECK-DAG: %[[MAX:.+]] = constant 1000 : i64
+// CHECK-DAG: %[[SEED:.+]] = constant 42 : i32
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK-DAG: %[[VAL0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+// CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32
+// CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32
+// CHECK-DAG: %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
+// CHECK-DAG: %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
+// Skip random number computation for the second index.
+// CHECK-DAG: %[[MIN_CAST1:.+]] = sitofp %[[MIN]] : i64 to f64
+// CHECK-DAG: %[[MAX_CAST:.+]] = sitofp %[[MAX]] : i64 to f64
+// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX_CAST]], %[[MIN_CAST1]] : f64
+// CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
+// CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
+// CHECK-DAG: %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
+// CHECK-DAG: %[[MIN_CAST2:.+]] = sitofp %[[MIN]] : i64 to f64
+// CHECK-DAG: %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN_CAST2]] : f64
+// CHECK-DAG: %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
+// CHECK-NEXT: linalg.yield %[[VAL6]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
+ return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_fill_rng_2d_i32
+// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>)
+// Verifies floating point to integer cast.
+// CHECK: %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
+// CHECK-NEXT: linalg.yield %[[VAL6]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index ad46220f0b609..44f2ff12e5cc4 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -21,7 +21,7 @@
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.test']
+config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@@ -64,6 +64,7 @@
'mlir-edsc-builder-api-test',
'mlir-cpu-runner',
'mlir-linalg-ods-gen',
+ 'mlir-linalg-ods-yaml-gen',
'mlir-reduce',
'mlir-sdbm-api-test',
]
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
new file mode 100644
index 0000000000000..72b7f6fe7dc98
--- /dev/null
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -0,0 +1,137 @@
+# RUN: mlir-linalg-ods-yaml-gen %s --o-ods-decl=- | FileCheck %s --check-prefix=ODS
+# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL
+
+# @linalg_structured_op
+# def test1(O=TensorDef(T, S.M, S.N, output=True)):
+# """Title.
+
+# Detailed description.
+# """
+# O[D.m, D.n] = cast(T, const(42)) + cast(T, index(D.n))
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: test1
+ cpp_class_name: Test1Op
+ doc: |-
+ Title.
+
+ Detailed description.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: O
+ usage: output
+ shape: affine_map<()[s0, s1] -> (s0, s1)>
+ element_type_var: T
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: T
+ operands:
+ - !ScalarExpression
+ scalar_const: '42 : i64'
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: T
+ operands:
+ - !ScalarExpression
+ scalar_index: 1
+
+# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
+
+# ODS: let summary = [{ Title. }];
+# ODS-NEXT: let description = [{
+# ODS-NEXT: Detailed description.
+# ODS-NEXT: }];
+
+# ODS: let arguments =
+# ODS-NEXT: Variadic<AnyShaped>:$inputs,
+# ODS-NEXT: Variadic<AnyShaped>:$outputs
+
+# ODS: let builders =
+# ODS: $_state.addOperands(inputs);
+# ODS-NEXT: $_state.addOperands(outputs);
+# ODS-NEXT: $_state.addAttribute(
+# ODS-NEXT: "operand_segment_sizes",
+# ODS-NEXT: $_builder.getI32VectorAttr({
+# ODS-NEXT: static_cast<int32_t>(inputs.size()),
+# ODS-NEXT: static_cast<int32_t>(outputs.size())}));
+# ODS-NEXT: createAndFillStructuredOpRegion<Test1Op>(
+# ODS-NEXT: $_builder,
+# ODS-NEXT: $_state,
+# ODS-NEXT: TypeRange(inputs),
+# ODS-NEXT: TypeRange(outputs)
+
+# IMPL-LABEL: void Test1Op::regionBuilder
+# IMPL-SAME: (Block &block, ValueRange captures)
+# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
+# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
+# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
+# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]);
+# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
+
+
+# @linalg_structured_op
+# def test2(I=TensorDef(T, S.M, S.N),
+# O=TensorDef(T, S.M, S.N, output=True)):
+# """Title.
+
+# Detailed description.
+# """
+# O[D.m, D.n] = I[D.n, D.m]
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: test2
+ cpp_class_name: Test2Op
+ doc: |-
+ Title.
+
+ Detailed description.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: I
+ usage: input
+ shape: affine_map<()[s0, s1] -> (s0, s1)>
+ element_type_var: T
+ - !<LinalgTensorDef>
+ name: O
+ usage: output
+ shape: affine_map<()[s0, s1] -> (s0, s1)>
+ element_type_var: T
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_arg: I
+
+# IMPL-LABEL: Test2Op::iterator_types()
+# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
+
+# IMPL: Test2Op::indexing_maps()
+# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
+# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
+
+# IMPL: void Test2Op::regionBuilder(Block &block, ValueRange captures)
+# IMPL: yields.push_back(block.getArgument(0));
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
new file mode 100644
index 0000000000000..ce11188ba32dc
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -0,0 +1,37 @@
+# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+
+# CHECK: ---
+# CHECK-LABEL: matmul
+# CHECK: args:
+# CHECK: name: A
+# CHECK: usage: input
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: element_type_var: T
+# CHECK: name: B
+# CHECK: usage: input
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: element_type_var: T
+# CHECK: name: C
+# CHECK: usage: output
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: element_type_var: U
+@linalg_structured_op
+def matmul(
+ A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# CHECK: ---
+# CHECK-LABEL: fill
+# CHECK: captures:
+# CHECK: - !<LinalgCaptureDef>
+# CHECK: name: value
+# CHECK: type_var: T
+@linalg_structured_op
+def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+ O[D.m, D.n] = value
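At emission time a CaptureDef becomes a value passed through the captures keyword. A sketch assuming the fill op defined above is in scope (mirrors the fill_rng test in emit_structured_generic.py):

from mlir.ir import *
from mlir.dialects import builtin

with Context(), Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), f32)
    def fill_tensor(init, value):
      # value is bound to the CaptureDef operand of fill.
      return fill(outs=[init], captures=[value])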
diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index e96bc0de2204a..32c56d1649ad2 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -2,6 +2,7 @@
from mlir.dialects.linalg.opdsl.lang import *
+
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: assignments:
@@ -23,7 +24,65 @@
# CHECK: operands:
# CHECK: scalar_arg: B
@linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
- B=TensorDef(T, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+ A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# CHECK: ---
+# CHECK-LABEL: constants
+# CHECK: assignments:
+# CHECK: -
+# CHECK: arg: O
+# CHECK: scalar_apply:
+# CHECK: fn_name: sub
+# CHECK: operands:
+# CHECK: scalar_apply:
+# CHECK: fn_name: add
+# CHECK: operands:
+# CHECK: symbolic_cast:
+# CHECK: type_var: T
+# CHECK: operands:
+# CHECK: scalar_const: '3.1415926535897931 : f64'
+# CHECK: symbolic_cast:
+# CHECK: type_var: T
+# CHECK: operands:
+# CHECK: scalar_const: '42 : i64'
+# CHECK: symbolic_cast:
+# CHECK: type_var: T
+# CHECK: operands:
+# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
+@linalg_structured_op
+def constants(O=TensorDef(T, S.M, S.K, output=True)):
+ pi = cast(T, const(3.1415926535897931))
+ cst42 = cast(T, const(42))
+ cst1000 = cast(T, const(1e+3))
+ O[D.m, D.n] = pi + cst42 - cst1000
+
+
+# CHECK: ---
+# CHECK-LABEL: indices
+# CHECK: assignments:
+# CHECK: -
+# CHECK: arg: O
+# CHECK: scalar_apply:
+# CHECK: fn_name: add
+# CHECK: operands:
+# CHECK: scalar_index: 1
+# CHECK: scalar_index: 0
+@linalg_structured_op
+def indices(O=TensorDef(T, S.M, S.K, output=True)):
+ O[D.m, D.n] = index(D.n) + index(D.m)
+
+
+# CHECK: ---
+# CHECK-LABEL: fill
+# CHECK: assignments:
+# CHECK: -
+# CHECK: arg: O
+# CHECK: scalar_capture: value
+@linalg_structured_op
+def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+ O[D.m, D.n] = value
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 91274dd797403..f84db9b407a70 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -1,7 +1,5 @@
# RUN: %PYTHON %s | FileCheck %s
-from typing import Optional, Sequence
-
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
@@ -11,30 +9,36 @@
@linalg_structured_op
-def matmul_mono(A=TensorDef(T, S.M, S.K),
- B=TensorDef(T, S.K, S.N),
- C=TensorDef(T, S.M, S.N, output=True)):
+def matmul_mono(
+ A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(T, S.M, S.N, output=True)):
C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
@linalg_structured_op
-def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
- B=TensorDef(TV.T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
+def matmul_poly(
+ A=TensorDef(TV.T1, S.M, S.K),
+ B=TensorDef(TV.T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
@linalg_structured_op
-def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
- min=CaptureDef(F64),
- max=CaptureDef(F64),
- seed=CaptureDef(I32)):
- multiplier = const(I32, 1103515245)
- increment = const(I32, 12345)
- temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
- temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
- inv_randmax = const(F64, 2.3283064e-10)
- scaling = (max - min) * inv_randmax
- A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
+def fill_rng(
+ O=TensorDef(T, S.M, S.N, output=True),
+ min=CaptureDef(F64),
+ max=CaptureDef(F64),
+ seed=CaptureDef(I32)):
+ multiplier = cast(I32, const(1103515245))
+ increment = cast(I32, const(12345))
+ rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+ rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
+ inv_range = cast(F64, const(2.3283064e-10))
+ offset = cast(F64, const(2147483647))
+ scaling = (max - min) * inv_range
+ O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+
with Context() as ctx, Location.unknown():
module = Module.create()
@@ -64,8 +68,8 @@ def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
# CHECK-SAME: ins(%[[A]], %[[B]]
# CHECK-SAME: outs(%[[INITC]]
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
- RankedTensorType.get((16, 8), f32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32))
def test_matmul_mono(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
return matmul_mono(lhs, rhs, outs=[init_result.result])
@@ -78,9 +82,9 @@ def test_matmul_mono(lhs, rhs):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
# CHECK-NEXT: -> tensor<4x8xi32>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), i32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), i32))
def test_i8i8i32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -92,9 +96,9 @@ def test_i8i8i32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
# CHECK-NEXT: -> tensor<4x8xi32>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i16),
- RankedTensorType.get((4, 8), i32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16),
+ RankedTensorType.get((4, 8), i32))
def test_i8i16i32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -106,9 +110,9 @@ def test_i8i16i32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
# CHECK-NEXT: linalg.yield %[[ADD]] : i16
# CHECK-NEXT: -> tensor<4x8xi16>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
- RankedTensorType.get((16, 8), i32),
- RankedTensorType.get((4, 8), i16))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32),
+ RankedTensorType.get((4, 8), i16))
def test_i32i32i16_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -120,9 +124,9 @@ def test_i32i32i16_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), f32))
def test_i8i8f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -134,9 +138,9 @@ def test_i8i8f32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
- RankedTensorType.get((16, 8), f16),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16),
+ RankedTensorType.get((4, 8), f32))
def test_f16f16f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -148,33 +152,36 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
- RankedTensorType.get((16, 8), f64),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64),
+ RankedTensorType.get((4, 8), f32))
def test_f64f64f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
- # CHECK-LABEL: @test_fill_rng_2d
+ # CHECK-LABEL: @test_fill_rng
# CHECK-SAME: %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
# CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
# CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
# CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
# CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
# CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
- # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32
- # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32
- # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32
- # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32
- # CHECK: %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64
+ # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64
+ # CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
+ # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i64
+ # CHECK-DAG: %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32
+ # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32
+ # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32
+ # Skip random number computation for the second index.
# CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
- # CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
- # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
- # CHECK-DAG: %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64
+ # CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
+ # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
+ # CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
# CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
# CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32
- @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
- f64, f64, i32)
- def test_fill_rng_2d(init_result, min, max, seed):
- return fill_rng_2d(outs=[init_result], captures=[min, max, seed])
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), i32), f64, f64, i32)
+ def test_fill_rng(init_result, min, max, seed):
+ return fill_rng(outs=[init_result], captures=[min, max, seed])
+
print(module)
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index c46863ba90360..8d48f0a340620 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -8,13 +8,15 @@
from mlir.passmanager import *
from mlir.execution_engine import *
+
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
print(*args, file=sys.stderr)
sys.stderr.flush()
-boilerplate = """
+
+matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
%v0 = constant 0.0 : f32
%v1 = constant 1.0 : f32
@@ -27,7 +29,7 @@ def log(*args):
linalg.fill(%B, %v2) : memref<16x8xf32>, f32
linalg.fill(%C, %v0) : memref<4x8xf32>, f32
- call @matmul_on_buffers(%A, %B, %C) :
+ call @matmul_on_buffers(%A, %B, %C) :
(memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
%c0 = constant 0 : index
@@ -38,7 +40,23 @@ def log(*args):
}
"""
-def transform(module):
+fill_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+ %O = memref.alloc() : memref<4x16xi32>
+
+ call @fill_on_buffers(%O) :
+ (memref<4x16xi32>) -> ()
+
+ %c0 = constant 0 : index
+ %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>
+
+ // TODO: FFI-based solution to allow testing and printing with python code.
+ return %0 : i32
+}
+"""
+
+
+def transform(module, boilerplate):
import mlir.conversions
import mlir.dialects.linalg.passes
import mlir.transforms
@@ -46,26 +64,27 @@ def transform(module):
# TODO: Allow cloning functions from one module to another.
# Atm we have to resort to string concatenation.
mod = Module.parse(
- str(module.operation.regions[0].blocks[0].operations[0].operation) +
- boilerplate)
- pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
- "convert-vector-to-llvm," +
- "convert-std-to-llvm")
+ str(module.operation.regions[0].blocks[0].operations[0].operation) +
+ boilerplate)
+ pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
+ "convert-vector-to-llvm," + "convert-std-to-llvm")
pm.run(mod)
return mod
-def test_builtin():
+
+def test_matmul_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
- @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32))
def matmul_on_buffers(lhs, rhs, out):
linalg.matmul(lhs, rhs, outs=[out])
-
- execution_engine = ExecutionEngine(transform(module))
+
+ execution_engine = ExecutionEngine(transform(module, matmul_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
# Prepare arguments: one result f32.
@@ -74,23 +93,26 @@ def matmul_on_buffers(lhs, rhs, out):
res = c_float_p(-1.)
execution_engine.invoke("main", res)
- log('RESULT: ', res[0])
+ log("RESULT: ", res[0])
# CHECK: RESULT: 32.0
-test_builtin()
-def test_generic():
+test_matmul_builtin()
+
+
+def test_matmul_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
- @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32))
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+ MemRefType.get((4, 8), f32))
def matmul_on_buffers(lhs, rhs, out):
linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
-
- execution_engine = ExecutionEngine(transform(module))
+
+ execution_engine = ExecutionEngine(transform(module, matmul_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
# Prepare arguments: one result f32.
@@ -99,7 +121,62 @@ def matmul_on_buffers(lhs, rhs, out):
res = c_float_p(-1.)
execution_engine.invoke("main", res)
- log('RESULT: ', res[0])
+ log("RESULT: ", res[0])
# CHECK: RESULT: 32.0
-test_generic()
+
+test_matmul_generic()
+
+
+def test_fill_builtin():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
+ def fill_on_buffers(out):
+ linalg.fill_rng_2d(outs=[out])
+
+ execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -480
+
+
+test_fill_builtin()
+
+
+def test_fill_generic():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
+ def fill_on_buffers(out):
+ linalg.fill_rng_2d(outs=[out])
+
+ execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: -480
+
+
+test_fill_generic()
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 53a5807bd1797..1c5a20b80f538 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -51,7 +51,7 @@ struct LinalgYAMLContext {
struct LinalgOpMetadata {
std::string name;
- std::string cppOpName;
+ std::string cppClassName;
Optional<std::string> doc;
SmallVector<std::string> implements;
};
@@ -102,6 +102,8 @@ struct ScalarSymbolicCast {
struct ScalarExpression {
Optional<std::string> arg;
+ Optional<std::string> constant;
+ Optional<int64_t> index;
Optional<ScalarApply> apply;
Optional<ScalarSymbolicCast> symbolicCast;
};
@@ -208,7 +210,7 @@ template <>
struct MappingTraits<LinalgOpMetadata> {
static void mapping(IO &io, LinalgOpMetadata &info) {
io.mapRequired("name", info.name);
- io.mapRequired("cpp_op_name", info.cppOpName);
+ io.mapRequired("cpp_class_name", info.cppClassName);
io.mapOptional("doc", info.doc);
io.mapOptional("implements", info.implements);
}
@@ -247,6 +249,8 @@ template <>
struct MappingTraits<ScalarExpression> {
static void mapping(IO &io, ScalarExpression &info) {
io.mapOptional("scalar_arg", info.arg);
+ io.mapOptional("scalar_const", info.constant);
+ io.mapOptional("scalar_index", info.index);
io.mapOptional("scalar_apply", info.apply);
io.mapOptional("symbolic_cast", info.symbolicCast);
}
@@ -370,12 +374,26 @@ findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
return None;
}
-static Optional<int>
-findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+// Try to map the TypeVar to a predefined or an argument type.
+static Optional<std::string>
+findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+ // Handle all predefined types.
+ if (typeVar == "I32")
+ return std::string("helper.getIntegerType(32)");
+ if (typeVar == "I64")
+ return std::string("helper.getIntegerType(64)");
+ if (typeVar == "F32")
+ return std::string("helper.getFloat32Type()");
+ if (typeVar == "F64")
+ return std::string("helper.getFloat64Type()");
+
+ // Search all argument types.
for (auto it : llvm::enumerate(args)) {
if (it.value().elementTypeVar == typeVar)
- return it.index();
+ return llvm::formatv("block.getArgument({0}).getType()", it.index())
+ .str();
}
+
return None;
}
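For reference, the lookup above reduces to this standalone sketch (a simplified analogue, not part of the patch: std::optional stands in for llvm::Optional and a plain struct for LinalgTensorDef):

// Simplified analogue of findTypeValue; assumed-equivalent behavior.
#include <optional>
#include <string>
#include <vector>

struct ArgDef {
  std::string elementTypeVar; // type variable bound to this argument
};

std::optional<std::string>
findTypeValueSketch(const std::string &typeVar,
                    const std::vector<ArgDef> &args) {
  // Predefined type variables resolve directly to helper calls.
  if (typeVar == "I32")
    return "helper.getIntegerType(32)";
  if (typeVar == "I64")
    return "helper.getIntegerType(64)";
  if (typeVar == "F32")
    return "helper.getFloat32Type()";
  if (typeVar == "F64")
    return "helper.getFloat64Type()";
  // Otherwise resolve to the type of the block argument bound to typeVar.
  for (std::size_t i = 0; i < args.size(); ++i)
    if (args[i].elementTypeVar == typeVar)
      return "block.getArgument(" + std::to_string(i) + ").getType()";
  return std::nullopt; // free type variables remain unsupported
}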
@@ -563,10 +581,10 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
- os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName,
- opConfig.metadata->name, interfaceNameList, doc, attrList,
- opConfig.structuredOp->args.size(), attrBuilder,
- attrMethods);
+ os << llvm::formatv(
+ structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName,
+ opConfig.metadata->name, interfaceNameList, doc, attrList,
+ opConfig.structuredOp->args.size(), attrBuilder, attrMethods);
return success();
}
@@ -578,7 +596,7 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
return success();
raw_ostream &os = genContext.defns();
- StringRef className = opConfig.metadata->cppOpName;
+ StringRef className = opConfig.metadata->cppClassName;
// Implementation banner.
std::string bannerComment = llvm::formatv("Implementation of {0}", className);
@@ -734,12 +752,15 @@ std::string {0}::getLibraryCallName() {{
{
// Generates a regionBuilder method. Parameters.
// {0}: Class name
- // {1}: Statements
+ // {1}: Number of args
+ // {2}: Statements
static const char structuredOpRegionBuilderFormat[] = R"FMT(
void {0}::regionBuilder(Block &block, ValueRange captures) {{
- RegionBuilderHelper helper(block);
+ assert({1} > 0 && block.getNumArguments() == {1} &&
+ "{0} regionBuilder expects {1} (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
SmallVector<Value> yields;
- {1}
+ {2}
helper.yieldOutputs(yields);
}
)FMT";
@@ -769,12 +790,27 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
if (!argIndex) {
emitError(genContext.getLoc())
- << "scalar argument not defined on the op: " << arg.name;
+ << "scalar argument not defined on the op: " << *expression.arg;
return None;
}
return std::string(
llvm::formatv("block.getArgument({0})", *argIndex));
- } else if (expression.apply) {
+ }
+ if (expression.constant) {
+ std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+ stmts.push_back(
+ llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
+ cppIdent, expression.constant));
+ return cppIdent;
+ }
+ if (expression.index) {
+ // Access an iteration index.
+ std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+ stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
+ cppIdent, *expression.index));
+ return cppIdent;
+ }
+ if (expression.apply) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
@@ -790,7 +826,8 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
expression.apply->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
- } else if (expression.symbolicCast) {
+ }
+ if (expression.symbolicCast) {
// Symbolic cast.
// Operands must be arity 1.
if (expression.symbolicCast->operands.size() != 1) {
@@ -803,29 +840,23 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
if (!operandCppValue)
return None;
- // Try to map the TypeVar to an arg index (which map to block arg
- // indices), since we can just get that type directly.
- // TODO: Handle free type variables which do not map to an argument.
- Optional<int> typeArgIndex =
- findTypeVarArgIndex(expression.symbolicCast->typeVar, args);
- if (!typeArgIndex) {
+ Optional<std::string> typeCppValue =
+ findTypeValue(expression.symbolicCast->typeVar, args);
+ if (!typeCppValue) {
emitError(genContext.getLoc())
<< "type variable " << expression.symbolicCast->typeVar
- << ", used in a symbolic cast must map to an argument but it "
- << "does not";
+ << ", used in a symbolic cast must map to a predefined or "
+ << "an argument type but it does not";
return None;
}
- std::string typeCppValue =
- llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex);
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
- cppIdent, typeCppValue,
+ cppIdent, typeCppValue.getValue(),
*operandCppValue));
return cppIdent;
- } else {
- emitError(genContext.getLoc()) << "unknown ScalarExpression type";
- return None;
}
+ emitError(genContext.getLoc()) << "unknown ScalarExpression type";
+ return None;
};
Optional<std::string> cppValue = generateExpression(assignment->value);
if (!cppValue)
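Putting the expression cases together: the recursion emits operand statements before the applying function, following the value{N} naming scheme. A hedged sketch of input and output (the fn_name, constant, and yields line are assumptions for illustration, not taken from the patch):

// Hypothetical ScalarExpression input:
//   scalar_apply:
//     fn_name: add
//     operands:
//     - !ScalarExpression
//       scalar_index: 1
//     - !ScalarExpression
//       scalar_const: '42 : i64'
// Assumed emitted statements:
Value value1 = helper.index(1);
Value value2 = helper.constant("42 : i64");
Value value3 = helper.applyfn__add(value1, value2);
yields.push_back(value3);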
@@ -837,7 +868,8 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
return emitError(genContext.getLoc())
<< "mismatched number of assignments vs output arguments";
- os << llvm::formatv(structuredOpRegionBuilderFormat, className,
+ int64_t numOfArgs = args.size();
+ os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
interleaveToString(stmts, "\n "));
}
@@ -937,7 +969,7 @@ int main(int argc, char **argv) {
}
genContext.setLoc(NameLoc::get(
- Identifier::get(opConfig.metadata->cppOpName, &mlirContext)));
+ Identifier::get(opConfig.metadata->cppClassName, &mlirContext)));
if (failed(generateOp(opConfig, genContext))) {
return 1;
}