[Mlir-commits] [mlir] 3b2f26a - [mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 13:25:21 PDT 2020
Author: MaheshRavishankar
Date: 2020-04-13T13:23:01-07:00
New Revision: 3b2f26ab05a80ffb3fcee62fd690da2e6d39c4a3
URL: https://github.com/llvm/llvm-project/commit/3b2f26ab05a80ffb3fcee62fd690da2e6d39c4a3
DIFF: https://github.com/llvm/llvm-project/commit/3b2f26ab05a80ffb3fcee62fd690da2e6d39c4a3.diff
LOG: [mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops
The inversePermutation method does not return a nullptr anymore, but
rather returns an empty map for the scalar case. Update the check in
LinalgToLoops to reflect this.
Also add test case for generating scalar code.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/test/Dialect/Linalg/loops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 9717bb874345..6be0bd8ea204 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -652,8 +652,8 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps =
functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
- auto invertedMap = inversePermutation(concatAffineMaps(maps));
- if (!invertedMap) {
+ AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
+ if (invertedMap.isEmpty()) {
LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
{}, linalgOp);
return LinalgLoops();
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index a4d3acd91c38..48e4b6ecd10d 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -913,3 +913,46 @@ func @generic_const_init(%arg0: memref<?xf32>) {
// CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32
// CHECKPARALLEL: loop.parallel (%[[i:.*]])
// CHECKPARALLEL: store %[[CONST]], %[[ARG0]]
+
+#scalar_access = [
+ affine_map<() -> ()>,
+ affine_map<() -> ()>,
+ affine_map<() -> ()>
+]
+#scalar_trait = {
+ args_in = 2,
+ args_out = 1,
+ iterator_types = [],
+ indexing_maps = #scalar_access,
+ library_call = "some_external_fn"
+}
+func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
+{
+ linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
+ ^bb(%a : f32, %b : f32, %c : f32) :
+ %0 = addf %a, %b : f32
+ linalg.yield %0 : f32
+ } : memref<f32>, memref<f32>, memref<f32>
+ return
+}
+// CHECKLOOP-LABEL: @scalar_code
+// CHECKLOOP-SAME: %[[ARG0]]: memref<f32>
+// CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
+// CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
+// CHECKLOOP-NOT: loop.for
+// CHECKLOOP-DAG: load %[[ARG0]][]
+// CHECKLOOP-DAG: load %[[ARG1]][]
+// CHECKLOOP-DAG: load %[[ARG2]][]
+// CHECKLOOP: addf
+// CHECKLOOP: store %{{.*}}, %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @scalar_code
+// CHECKPARALLEL-SAME: %[[ARG0]]: memref<f32>
+// CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
+// CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
+// CHECKPARALLEL-NOT: loop.for
+// CHECKPARALLEL-DAG: load %[[ARG0]][]
+// CHECKPARALLEL-DAG: load %[[ARG1]][]
+// CHECKPARALLEL-DAG: load %[[ARG2]][]
+// CHECKPARALLEL: addf
+// CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]
More information about the Mlir-commits
mailing list