[Mlir-commits] [mlir] dd53244 - [mlir] Disallow broadcast dimensions on TransferWriteOp.

Matthias Springer llvmlistbot at llvm.org
Tue Apr 20 15:44:05 PDT 2021


Author: Matthias Springer
Date: 2021-04-21T07:43:45+09:00
New Revision: dd5324467d1d42a3d5d4e7492778ce4997d0bc57

URL: https://github.com/llvm/llvm-project/commit/dd5324467d1d42a3d5d4e7492778ce4997d0bc57
DIFF: https://github.com/llvm/llvm-project/commit/dd5324467d1d42a3d5d4e7492778ce4997d0bc57.diff

LOG: [mlir] Disallow broadcast dimensions on TransferWriteOp.

The current verifier accepts TransferWriteOps whose permutation map contains broadcast dimensions (constant 0 results), which do not have sensible semantics for a write. For example, a broadcast dimension could write an entire vector into a single (scalar) memory location, which is effectively the same as writing only the last element of the vector.

Differential Revision: https://reviews.llvm.org/D100842

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index da79ca7fa2098..580c6985f5ad7 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,21 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns true if at least one of the dimensions in the
+                  permutation map is a broadcast (a constant 0 result).}],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasBroadcastDim",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::any_of(
+            $_op.permutation_map().getResults(),
+            [](AffineExpr e) {
+                auto constExpr = e.dyn_cast<AffineConstantExpr>();
+                return constExpr && constExpr.getValue() == 0; });
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/"Return the `in_bounds` boolean ArrayAttr.",
       /*retTy=*/"Optional<ArrayAttr>",

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 934ea611c4324..d654819551559 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2683,6 +2683,11 @@ static LogicalResult verify(TransferWriteOp op) {
   if (llvm::size(op.indices()) != shapedType.getRank())
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
+  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
+  // as the semantics is unclear. This can be revisited later if necessary.
+  if (op.hasBroadcastDim())
+    return op.emitOpError("should not have broadcast dimensions");
+
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.in_bounds() ? *op.in_bounds() : ArrayAttr())))

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index b4cf222bd7d7a..827241315dfe0 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -461,6 +461,17 @@ func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>) {
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{should not have broadcast dimensions}}
+  vector.transfer_write %arg1, %arg0[%c3]
+      {permutation_map = affine_map<(d0) -> (0)>}
+      : vector<7xf32>, memref<?xf32>
+}
+
+// -----
+
 func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
  // expected-error@+1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>


        


More information about the Mlir-commits mailing list