[Mlir-commits] [mlir] 7744253 - [mlir][Linalg] Drop check for output indexing maps.

Mahesh Ravishankar llvmlistbot at llvm.org
Fri Aug 26 09:16:19 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-08-26T16:15:55Z
New Revision: 7744253f12a84879aa58522346e919df8b65364e

URL: https://github.com/llvm/llvm-project/commit/7744253f12a84879aa58522346e919df8b65364e
DIFF: https://github.com/llvm/llvm-project/commit/7744253f12a84879aa58522346e919df8b65364e.diff

LOG: [mlir][Linalg] Drop check for output indexing maps.

The current check on the form of the output indexing maps disallows
generic ops that return both reduced and unreduced values. Such an op
seems like it could fall within the scope of a Structured op, so drop
the check. The only load-bearing place this was found to cause issues
was during vectorization, and the fix for that is simple.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D132716
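
For illustration, a minimal sketch of the kind of op the relaxed verifier now
accepts: one output keeps the reduction dimension d1 (unreduced), the other
drops it (reduced). The shapes and names here are illustrative only; the
roundtrip and vectorization tests added below exercise the same pattern with
two inputs and dynamic shapes.

func.func @mixed_results(%in : tensor<4x8xf32>, %out0 : tensor<4x8xf32>,
    %out1 : tensor<4xf32>) -> (tensor<4x8xf32>, tensor<4xf32>) {
  %0:2 = linalg.generic {
      // Output #0 is indexed by the reduction dimension d1; output #1 is not.
      // The dropped verifier check used to reject this mix.
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<4x8xf32>)
      outs(%out0, %out1 : tensor<4x8xf32>, tensor<4xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      // Yield the unreduced value for output #0 and the accumulated sum for
      // output #1.
      %1 = arith.addf %b0, %b2 : f32
      linalg.yield %b0, %1 : f32, f32
  } -> (tensor<4x8xf32>, tensor<4xf32>)
  return %0#0, %0#1 : tensor<4x8xf32>, tensor<4xf32>
}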

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b940e418df6c3..d2d800309e870 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -683,27 +683,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
-  // Output tensor indexing map may not depend on reduction indices.
-  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
-    for (AffineExpr expr : indexingMap.getResults()) {
-      for (unsigned pos : redDims) {
-        if (expr.isFunctionOfDim(pos)) {
-          std::string exprStr;
-          {
-            llvm::raw_string_ostream os(exprStr);
-            os << expr;
-          }
-          return op->emitOpError(
-                     "unexpected output tensor expression in indexing map #")
-                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
-                 << " a.k.a '" << exprStr
-                 << "' is function of reduction iterator 'd" << pos << "'";
-        }
-      }
-    }
-  }
-
   if (!linalgOp.getShapesToLoopsMap())
     return op->emitOpError("expected the shape-to-loops map to be non-null");
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6975937e622c5..94f338a0d1de8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -539,6 +539,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
     return failure();
   }
   for (OpOperand *opOperand : op.getOutputOperands()) {
+    AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
+    if (indexingMap.isPermutation())
+      continue;
+
     Operation *reduceOp = matchLinalgReduction(opOperand);
     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
       LDBG("reduction precondition failed: reduction detection failed");

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 6881dce30a9fd..a7fe2f09e533f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -269,21 +269,6 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
 
 // -----
 
-func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
-                                 %arg1: tensor<?xf32>) {
-  // expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd0' is function of reduction iterator 'd0'}}
-  %0 = linalg.generic {
-    indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
-    iterator_types = ["reduction"]}
-       ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
-      outs(%arg1 : tensor<?xf32>) {
-    ^bb(%i: f32, %j: f32):
-      linalg.yield %i: f32
-  } -> tensor<?xf32>
-}
-
-// -----
-
 func.func @generic(%arg0: memref<?x?xf32>) {
   // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}}
   linalg.generic  {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ba21a1481aa64..70857b7e8e4a7 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -350,3 +350,24 @@ func.func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?x
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?x?xf32>, %arg3 : tensor<?x?xf32>) ->
+    (tensor<?x?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2, %arg3 : tensor<?x?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      %1 = arith.mulf %b0, %b1 : f32
+      %2 = arith.addf %1, %b3 : f32
+      linalg.yield %1, %2 : f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @mixed_parallel_reduced_results
+//       CHECK:     linalg.generic

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbc36b12556ed..229530587fdc5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1102,3 +1102,35 @@ func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf
     } -> tensor<6x6x3x3xf32>
   return %result : tensor<6x6x3x3xf32>
 }
+
+// -----
+
+// Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
+func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>,
+    %arg1 : tensor<2x4xf32>, %arg2 : tensor<2x4x8xf32>, %arg3 : tensor<2x4xf32>) ->
+    (tensor<2x4x8xf32>, tensor<2x4xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : tensor<2x4x8xf32>, tensor<2x4xf32>)
+      outs(%arg2, %arg3 : tensor<2x4x8xf32>, tensor<2x4xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      %1 = arith.mulf %b0, %b1 : f32
+      %2 = arith.addf %1, %b3 : f32
+      linalg.yield %1, %2 : f32, f32
+  } -> (tensor<2x4x8xf32>, tensor<2x4xf32>)
+  return %0#0, %0#1 : tensor<2x4x8xf32>, tensor<2x4xf32>
+}
+// CHECK-LABEL: func @mixed_parallel_reduced_results(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x4x8xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<2x4x8xf32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//   CHECK-DAG:   %[[V0:.+]] = vector.transfer_read %[[ARG0]]
+//   CHECK-DAG:   %[[V1:.+]] = vector.transfer_read %[[ARG1]]
+//   CHECK-DAG:   %[[V2:.+]] = vector.transfer_read %[[ARG3]]
+//   CHECK-DAG:   %[[MUL:.+]] = arith.mulf %[[V0]], %[[V1]]
+//   CHECK-DAG:   %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
+//   CHECK-DAG:   vector.transfer_write %[[MUL]], %[[ARG2]]
+//   CHECK-DAG:   vector.transfer_write %[[ADD]], %[[ARG3]]


        

