[Mlir-commits] [mlir] [mlir][linalg] `LinalgOp`: Disallow mixed tensor/buffer semantics (PR #80660)

Matthias Springer llvmlistbot at llvm.org
Fri Feb 23 06:47:17 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/80660

From 225395d575108e719367759f3767fff68d511a97 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 23 Feb 2024 14:46:33 +0000
Subject: [PATCH] [mlir][linalg] `LinalgOp`: Disallow mixed tensor/buffer
 semantics

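The `LinalgOp` verifier now rejects ops that mix tensor and buffer (memref)
operands: an op with operands must have either pure tensor or pure buffer
semantics. Tests that relied on mixed semantics are converted to pure
semantics or removed.

Related discussion: https://github.com/llvm/llvm-project/pull/73908/files#r1414913030.

This change fixes #73547.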
---
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  5 ++
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 55 +++++--------------
 .../Linalg/fusion-elementwise-ops.mlir        | 40 --------------
 mlir/test/Dialect/Linalg/invalid.mlir         | 10 ++++
 4 files changed, 29 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7eed7928456d57..3627ff6617eda3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
+  // Mixed tensor/buffer operands are not allowed.
+  if (!linalgOp.hasPureTensorSemantics() &&
+      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
+    return op->emitOpError("expected to have pure tensor or buffer semantics");
+
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7adde3117deeaa..206d7e9f1ce8df 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
 // -----
 
 // CHECK-LABEL: func @linalg_effects(
-//  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
-//  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>
-//  CHECK-SAME:     %[[C:[a-z0-9]*]]: tensor<?x?xf32>
-func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+func.func @linalg_effects(
+    %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
+    %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
   // CHECK-NOT:   %{{.*}} = linalg.matmul
-  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   // CHECK:   linalg.matmul
-  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
-               outs(%b : memref<?x?xf32>)
+  linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%f : memref<?x?xf32>)
   return
 }
 
@@ -889,11 +888,11 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
 // -----
 
 #map = affine_map<(d0) -> (d0)>
-func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
   linalg.generic {
     indexing_maps = [#map, #map],
     iterator_types = ["parallel"]
-  } ins(%arg0 : tensor<?xf32>)
+  } ins(%arg0 : memref<?xf32>)
     outs(%arg1 : memref<?xf32>) {
   ^bb0(%arg2 : f32, %arg3 : f32):
     linalg.yield %arg2 : f32
@@ -901,14 +900,13 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
   return
 }
 
-// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
-// For now, check generic remained unchanged.
-// CHECK-LABEL: func @identity_mixed
-//  CHECK-SAME:     (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// Identity generic ops with buffer semantics write their output buffer; do not erase them.
+// CHECK-LABEL: func @identity_buffer
+//  CHECK-SAME:     (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
 //       CHECK:     linalg.generic {
 //  CHECK-SAME:    indexing_maps = [#map, #map],
 //  CHECK-SAME:    iterator_types = ["parallel"]
-//  CHECK-SAME:  } ins(%[[ARG1]] : tensor<?xf32>)
+//  CHECK-SAME:  } ins(%[[ARG1]] : memref<?xf32>)
 //  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {
 
 // -----
@@ -916,12 +914,12 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
 // Just make sure that we don't crash.
 
 // CHECK-LABEL: func @dedeplicate_regression_test
-func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
+func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
   %36 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
-    ins(%1, %1 : memref<4xf32>, memref<4xf32>)
+    ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
     outs(%0 : tensor<4xf32>) {
   ^bb0(%in: f32, %in_24: f32, %out: f32):
     linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
 
 // -----
 
-#map = affine_map<(d0) -> (d0)>
-func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
-  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
-  linalg.generic {
-    indexing_maps = [#map, #map],
-    iterator_types = ["parallel"]
-  } ins(%0 : tensor<?xf32>)
-    outs(%arg1 : memref<?xf32>) {
-  ^bb0(%arg2 : f32, %arg3 : f32):
-    linalg.yield %arg2 : f32
-  }
-  return
-}
-
-// We need a mixed linalg as a bridge between tensor and memref worlds.
-// CHECK-LABEL: func @cast_producer_mixed
-//  CHECK-SAME:     (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
-//       CHECK:     linalg.generic {
-//  CHECK-SAME:    indexing_maps = [#map, #map],
-//  CHECK-SAME:    iterator_types = ["parallel"]
-//  CHECK-SAME:  } ins(%[[ARG1]] : tensor<5xf32>)
-//  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {
-
-// -----
-
 // CHECK-LABEL: dead_softmax
 func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
   %0 = tensor.empty() : tensor<16x64x256xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49d8..15a4f6cdd3bbe4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1110,43 +1110,3 @@ module {
 //   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 //       CHECK:     linalg.yield %[[T3]] : f32
 //       CHECK:   return %[[GENERIC]]
-
-// -----
-
-// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @mixed_fusion
-func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%2 : tensor<?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
-      %4 = arith.addf %arg3, %arg4 : f32
-      linalg.yield %4 : f32
-  } -> tensor<?x?xf32>
-  // CHECK: linalg.generic {
-  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
-  linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg8 : memref<?x?xf32>) {
-    // CHECK: ^{{[a-zA-Z0-9_]*}}
-    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
-    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
-      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
-      // CHECK-NOT: linalg.yield
-      // CHECK: arith.mulf [[T1]], [[ARG2]]
-      // CHECK: linalg.yield
-      %5 = arith.mulf %arg5, %arg6 : f32
-      linalg.yield %5 : f32
-    }
-  return
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 916c04f33e9c67..44c81c31ace0f9 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                      -> tensor<8x8xf32>
     return %res : tensor<8x8xf32>
 }
+
+// -----
+
+func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
+  // expected-error @+1 {{expected to have pure tensor or buffer semantics}}
+  linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+               outs(%c: memref<?x?xf32>)
+  return
+}
+



More information about the Mlir-commits mailing list