[Mlir-commits] [mlir] 335d2df - [mlir][Python][Linalg] Add missing attributes to linalg ops
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Apr 1 01:21:00 PDT 2021
Author: Nicolas Vasilache
Date: 2021-04-01T08:16:50Z
New Revision: 335d2df5335f95d49c864ecdba4fd5731c7c3e89
URL: https://github.com/llvm/llvm-project/commit/335d2df5335f95d49c864ecdba4fd5731c7c3e89
DIFF: https://github.com/llvm/llvm-project/commit/335d2df5335f95d49c864ecdba4fd5731c7c3e89.diff
LOG: [mlir][Python][Linalg] Add missing attributes to linalg ops
This revision tightens up the handling of attributes for both named
and generic linalg ops.
To demonstrate that the emitted IR is valid, a working end-to-end (e2e)
Linalg example is added.
Differential Revision: https://reviews.llvm.org/D99430
Added:
mlir/test/Bindings/Python/dialects/linalg/opsrun.py
Modified:
mlir/include/mlir-c/AffineMap.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Bindings/Python/IRAffine.cpp
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/lib/CAPI/IR/AffineMap.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/Bindings/Python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
index de4f42f09f249..e35b7cc6b51d5 100644
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -169,6 +169,17 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
MLIR_CAPI_EXPORTED MlirAffineMap
mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
+/// Returns the simplified affine maps resulting from dropping the symbols
+/// that do not appear in any of the individual maps in `affineMaps`.
+/// Asserts that all maps in `affineMaps` are normalized to the same number of
+/// dims and symbols.
+/// Takes a callback `populateResult` to fill the `res` container with value
+/// `m` at entry `idx`. This indirection lets the caller own the result
+/// storage, so no ownership needs to be transferred across the C API.
+MLIR_CAPI_EXPORTED void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m));
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index abc3e1b4a6fe7..2463452a8f06d 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -340,6 +340,11 @@ AffineMap simplifyAffineMap(AffineMap map);
/// Drop the dims that are not used.
AffineMap compressUnusedDims(AffineMap map);
+/// Drop the dims that are not used by any of the individual maps in `maps`.
+/// Asserts that all maps in `maps` are normalized to the same number of
+/// dims and symbols.
+SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
+
/// Drop the dims that are not listed in `unusedDims`.
AffineMap compressDims(AffineMap map,
                       const llvm::SmallDenseSet<unsigned> &unusedDims);
@@ -347,6 +352,11 @@ AffineMap compressDims(AffineMap map,
/// Drop the symbols that are not used.
AffineMap compressUnusedSymbols(AffineMap map);
+/// Drop the symbols that are not used by any of the individual maps in `maps`.
+/// Asserts that all maps in `maps` are normalized to the same number of
+/// dims and symbols.
+SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
+
/// Drop the symbols that are not listed in `unusedSymbols`.
AffineMap compressSymbols(AffineMap map,
                          const llvm::SmallDenseSet<unsigned> &unusedSymbols);
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 73a57d95e1586..5d3b790b35d0e 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -538,6 +538,23 @@ void mlir::python::populateIRAffine(py::module &m) {
          printAccum.parts.append(")");
          return printAccum.join();
        })
+      .def_static("compress_unused_symbols",
+                  [](py::list affineMaps, DefaultingPyMlirContext context) {
+                    SmallVector<MlirAffineMap> maps;
+                    pyListToVector<PyAffineMap, MlirAffineMap>(
+                        affineMaps, maps, "attempting to create an AffineMap");
+                    std::vector<MlirAffineMap> compressed(affineMaps.size());
+                    auto populate = [](void *result, intptr_t idx,
+                                       MlirAffineMap m) {
+                      static_cast<MlirAffineMap *>(result)[idx] = m;
+                    };
+                    mlirAffineMapCompressUnusedSymbols(
+                        maps.data(), maps.size(), compressed.data(), populate);
+                    std::vector<PyAffineMap> res;
+                    for (auto m : compressed)
+                      res.push_back(PyAffineMap(context->getRef(), m));
+                    return res;
+                  })
      .def_property_readonly(
          "context",
          [](PyAffineMap &self) { return self.getContext().getObject(); },
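For reference, a minimal usage sketch of the new `compress_unused_symbols`
static method (a hypothetical standalone example, not part of this commit,
assuming the existing `AffineMap.get` and `AffineDimExpr.get` bindings):

    from mlir.ir import AffineDimExpr, AffineMap, Context

    with Context():
      d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
      # Two maps normalized to the same signature (2 dims, 2 symbols), with
      # neither map using a symbol.
      maps = [AffineMap.get(2, 2, [d0]), AffineMap.get(2, 2, [d1])]
      # Drops s0 and s1 from both maps consistently, e.g.
      # (d0, d1)[s0, s1] -> (d0) becomes (d0, d1) -> (d0).
      compressed = AffineMap.compress_unused_symbols(maps, Context.current)
      print(compressed[0], compressed[1])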
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2395a422e354b..682f191387017 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -19,6 +19,13 @@
"emit_named_structured_op",
]
+def isa(cls : Type, ty : Type):
+ try:
+ cls(ty)
+ return True
+ except ValueError:
+ return False
+
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
outs: Value):
@@ -37,6 +44,8 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
  outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                           out_arg_defs, outs)
+  result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
  # Extract type vars for input/output based types.
  type_mapping = dict()  # type: Dict[str, Type]
  for arg_def, arg_element_type in zip(
@@ -48,30 +57,37 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
  # Emit the generic op.
  # TODO: Support emission of pure memref form.
  indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am) for am in op_config.indexing_maps])
+      [AffineMapAttr.get(am)
+       # TODO: linalg verification does not currently allow symbols.
+       # Compress them for now.
+       for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
+                                                   Context.current)])
  iterator_types_attr = ArrayAttr.get(
      [StringAttr.get(s) for s in op_config.iterator_types])
+  sparse_attr = ArrayAttr.get(
+      [BoolAttr.get(False)
+       for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)])
+  if len(sparse_attr) == 0:
+    sparse_attr = None

-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr)
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
+          type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                               *ins: Value,
                               outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr = \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
     prepare_common_structured_op(op_config, *ins, outs = outs)
  generic_op = linalg.GenericOp(
-      result_tensors=out_types,
+      result_tensors=result_types,
      inputs=ins,
      outputs=outs,
      indexing_maps=indexing_maps_attr,
      iterator_types=iterator_types_attr,
      doc=None,  # TODO: Make optional.
      library_call=None,  # TODO: Make optional.
-      sparse=BoolAttr.get(False))  # TODO: Make optional.
+      sparse=sparse_attr)  # TODO: Make optional.
  # Construct the body.
  block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
@@ -84,7 +100,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
    body_builder.assign(assignment)
  body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))

-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
    return generic_op.result
  else:
    return generic_op.results
@@ -95,8 +111,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
                             op_class_name: str,
                             *ins: Value,
                             outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr = \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
     prepare_common_structured_op(op_config, *ins, outs = outs)

  # If we get here, there must exist a builtin class `op_class_name`.
@@ -107,11 +123,16 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
    raise NotImplementedError(
        f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")

-  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
  linalgDialect = ctx.get_dialect_descriptor("linalg")
  fill_builtin_region(linalgDialect, named_op.operation)
+  # Note: mlir-linalg-ods-yaml-gen.cpp emits a special
+  # linalg.memoized_indexing_maps attribute that the non-yaml path does not
+  # use. The non-yaml path hardcodes the indexing_maps in C++ directly.
+  named_op.operation.attributes[
+      "linalg.memoized_indexing_maps"] = indexing_maps_attr
+  # iterator_types are hardcoded in C++ in both the yaml and non-yaml paths.

-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
    return named_op.result
  else:
    return named_op.results
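For reference, a minimal sketch of driving the updated emitter end to end (a
hypothetical standalone example, not part of this commit; it assumes the
in-tree `linalg.matmul` opdsl op and the `emit_generic` flag exercised by the
tests below):

    from mlir.ir import *
    from mlir.dialects import builtin, linalg

    with Context(), Location.unknown():
      module = Module.create()
      f32 = F32Type.get()
      with InsertionPoint(module.body):

        @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                     RankedTensorType.get((16, 8), f32),
                                     RankedTensorType.get((4, 8), f32))
        def generic_matmul(lhs, rhs, init):
          # emit_generic=True exercises emit_generic_structured_op: the
          # resulting linalg.generic carries symbol-free indexing maps and
          # result types derived only from the ranked-tensor outs.
          return linalg.matmul(lhs, rhs, outs=[init], emit_generic=True)

      # Printing the module shows the compressed affine maps.
      print(module)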
diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
index f532d5dae72e0..e0c07afc3b75e 100644
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -137,3 +137,14 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
                                          intptr_t numResults) {
  return wrap(unwrap(affineMap).getMinorSubMap(numResults));
}
+
+void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {
+  SmallVector<AffineMap> maps;
+  for (intptr_t idx = 0; idx < size; ++idx)
+    maps.push_back(unwrap(affineMaps[idx]));
+  intptr_t idx = 0;
+  for (auto m : mlir::compressUnusedSymbols(maps))
+    populateResult(result, idx++, wrap(m));
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index dc9a5c54c7ff7..4e25f8c4deea5 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -543,6 +543,41 @@ AffineMap mlir::compressUnusedDims(AffineMap map) {
  return compressDims(map, unusedDims);
}
+
+static SmallVector<AffineMap>
+compressUnusedImpl(ArrayRef<AffineMap> maps,
+                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+  if (maps.empty())
+    return SmallVector<AffineMap>();
+  SmallVector<AffineExpr> allExprs;
+  allExprs.reserve(maps.size() * maps.front().getNumResults());
+  unsigned numDims = maps.front().getNumDims(),
+           numSymbols = maps.front().getNumSymbols();
+  for (auto m : maps) {
+    assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
+           "expected maps with same num dims and symbols");
+    llvm::append_range(allExprs, m.getResults());
+  }
+  AffineMap unifiedMap = compressionFun(
+      AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
+  unsigned unifiedNumDims = unifiedMap.getNumDims(),
+           unifiedNumSymbols = unifiedMap.getNumSymbols();
+  ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
+  SmallVector<AffineMap> res;
+  res.reserve(maps.size());
+  for (auto m : maps) {
+    res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
+                                 unifiedResults.take_front(m.getNumResults()),
+                                 m.getContext()));
+    unifiedResults = unifiedResults.drop_front(m.getNumResults());
+  }
+  return res;
+}
+
+SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl(maps,
+                            [](AffineMap m) { return compressUnusedDims(m); });
+}
+
AffineMap
mlir::compressSymbols(AffineMap map,
                      const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
@@ -576,6 +611,11 @@ AffineMap mlir::compressUnusedSymbols(AffineMap map) {
  return compressSymbols(map, unusedSymbols);
}
+
+SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl(
+      maps, [](AffineMap m) { return compressUnusedSymbols(m); });
+}
+
AffineMap mlir::simplifyAffineMap(AffineMap map) {
  SmallVector<AffineExpr, 8> exprs;
  for (auto e : map.getResults()) {
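For intuition, the unify-compress-split trick in `compressUnusedImpl` can be
sketched in Python (an illustration only, not the MLIR API; a "map" is modeled
as a plain (num_dims, num_symbols, results) tuple, and `compression_fun`
compresses a single such map):

    def compress_unused_impl(maps, compression_fun):
      if not maps:
        return []
      num_dims, num_symbols, _ = maps[0]
      # Unify: concatenate every map's results into one map, so that the
      # used/unused analysis sees all maps at once.
      all_exprs = [e for (_, _, results) in maps for e in results]
      u_dims, u_syms, u_results = compression_fun(
          (num_dims, num_symbols, all_exprs))
      # Split: slice the compressed results back into per-map pieces; every
      # map ends up normalized to the same compressed dims and symbols.
      out, taken = [], 0
      for (_, _, results) in maps:
        out.append((u_dims, u_syms, u_results[taken:taken + len(results)]))
        taken += len(results)
      return out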
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index f27f79a4fb037..5445daefa49f0 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -37,9 +37,9 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
# Note that these all have the same indexing maps. We verify the first and
# then do more permutation tests on casting and body generation
# behavior.
- # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
# CHECK-LABEL: func @test_matmul_mono
# CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index 489aa63f47fdf..afcb5820a2216 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -94,6 +94,7 @@ def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# First check the named form with custom format
# CHECK: linalg.matmul
+ # CHECK-NOT: linalg.memoized_indexing_maps
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
# CHECK-SAME: -> tensor<4x8xf32>
@@ -118,7 +119,7 @@ def named_form(lhs, rhs):
# CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+ # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
new file mode 100644
index 0000000000000..c46863ba90360
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
@@ -0,0 +1,105 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import ctypes
+import sys
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+from mlir.passmanager import *
+from mlir.execution_engine import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+boilerplate = """
+func @main() -> f32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0.0 : f32
+  %v1 = constant 1.0 : f32
+  %v2 = constant 2.0 : f32
+
+  %A = memref.alloc() : memref<4x16xf32>
+  %B = memref.alloc() : memref<16x8xf32>
+  %C = memref.alloc() : memref<4x8xf32>
+  linalg.fill(%A, %v1) : memref<4x16xf32>, f32
+  linalg.fill(%B, %v2) : memref<16x8xf32>, f32
+  linalg.fill(%C, %v0) : memref<4x8xf32>, f32
+
+  call @matmul_on_buffers(%A, %B, %C) :
+    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : f32
+}
+"""
+
+def transform(module):
+  import mlir.conversions
+  import mlir.dialects.linalg.passes
+  import mlir.transforms
+
+  # TODO: Allow cloning functions from one module to another.
+  # Atm we have to resort to string concatenation.
+  mod = Module.parse(
+      str(module.operation.regions[0].blocks[0].operations[0].operation) +
+      boilerplate)
+  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
+                         "convert-vector-to-llvm," +
+                         "convert-std-to-llvm")
+  pm.run(mod)
+  return mod
+
+def test_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])
+    # CHECK: RESULT: 32.0
+
+test_builtin()
+
+def test_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])
+    # CHECK: RESULT: 32.0
+
+test_generic()