[Mlir-commits] [mlir] 68330ee - [mlir][vector] Relax transfer_read/transfer_write restriction on memref operand

Thomas Raoux llvmlistbot at llvm.org
Mon Aug 10 08:58:08 PDT 2020


Author: Thomas Raoux
Date: 2020-08-10T08:57:48-07:00
New Revision: 68330ee0a977926d2f2857c62420b7729f4e45d3

URL: https://github.com/llvm/llvm-project/commit/68330ee0a977926d2f2857c62420b7729f4e45d3
DIFF: https://github.com/llvm/llvm-project/commit/68330ee0a977926d2f2857c62420b7729f4e45d3.diff

LOG: [mlir][vector] Relax transfer_read/transfer_write restriction on memref operand

Relax the verifier for the transfer_read/transfer_write operations so that
they can take a memref with a different element type than the vector being
read/written. The new constraint is that the bitwidth of the vector's minor
1-D dimension must be an integral multiple of the bitwidth of the memref
element type (or of the memref's minor 1-D vector, when the memref itself
has a vector element type).

This is based on the discourse discussion:
https://llvm.discourse.group/t/memref-cast/1514
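
For example, the following read of a vector<4xi8> from an i32 memref is now
accepted, since the 32-bit minor vector dimension is an integral multiple of
the 32-bit memref element (a sketch mirroring the new tests; the function name
is illustrative):

  func @read_i8_from_i32(%A : memref<?xi32>, %base : index) -> vector<4xi8> {
    // The padding value must now match the memref element type (i32) rather
    // than the vector element type.
    %pad = constant 0 : i32
    %v = vector.transfer_read %A[%base], %pad {masked = [false]}
        : memref<?xi32>, vector<4xi8>
    return %v : vector<4xi8>
  }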

Differential Revision: https://reviews.llvm.org/D85244

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1e92b80d830f..522a45965722 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1147,6 +1147,26 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     Type i64Type = rewriter.getIntegerType(64);
     MemRefType memRefType = xferOp.getMemRefType();
 
+    if (auto memrefVectorElementType =
+            memRefType.getElementType().dyn_cast<VectorType>()) {
+      // Memref has vector element type.
+      if (memrefVectorElementType.getElementType() !=
+          xferOp.getVectorType().getElementType())
+        return failure();
+      // Check that memref vector type is a suffix of 'vectorType'.
+      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+      unsigned resultVecRank = xferOp.getVectorType().getRank();
+      assert(memrefVecEltRank <= resultVecRank);
+      // TODO: Move this to isSuffix in Vector/Utils.h.
+      unsigned rankOffset = resultVecRank - memrefVecEltRank;
+      auto memrefVecEltShape = memrefVectorElementType.getShape();
+      auto resultVecShape = xferOp.getVectorType().getShape();
+      for (unsigned i = 0; i < memrefVecEltRank; ++i)
+        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
+               "memref vector element shape should match suffix of vector "
+               "result shape.");
+    }
+
     // 1. Get the source/dst address as an LLVM vector pointer.
     //    The vector pointer would always be on address space 0, therefore
     //    addrspacecast shall be used when source/dst memrefs are not on

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 7c715bfdb6d0..1c0a5ceb8d86 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1500,37 +1500,33 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
     // Memref has vector element type.
 
-    // Check that 'memrefVectorElementType' and vector element types match.
-    if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+    unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
+                             memrefVectorElementType.getShape().back();
+    unsigned resultVecSize =
+        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+    if (resultVecSize % memrefVecSize != 0)
       return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+          "requires the bitwidth of the minor 1-D vector to be an integral "
+          "multiple of the bitwidth of the minor 1-D vector of the memref");
 
-    // Check that memref vector type is a suffix of 'vectorType.
     unsigned memrefVecEltRank = memrefVectorElementType.getRank();
     unsigned resultVecRank = vectorType.getRank();
     if (memrefVecEltRank > resultVecRank)
       return op->emitOpError(
           "requires memref vector element and vector result ranks to match.");
-    // TODO: Move this to isSuffix in Vector/Utils.h.
     unsigned rankOffset = resultVecRank - memrefVecEltRank;
-    auto memrefVecEltShape = memrefVectorElementType.getShape();
-    auto resultVecShape = vectorType.getShape();
-    for (unsigned i = 0; i < memrefVecEltRank; ++i)
-      if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
-        return op->emitOpError(
-            "requires memref vector element shape to match suffix of "
-            "vector result shape.");
     // Check that permutation map results match 'rankOffset' of vector type.
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
   } else {
     // Memref has scalar element type.
-
-    // Check that memref and vector element types match.
-    if (memrefType.getElementType() != vectorType.getElementType())
+    unsigned resultVecSize =
+        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+    if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
       return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+          "requires the bitwidth of the minor 1-D vector to be an integral "
+          "multiple of the bitwidth of the memref element type");
 
     // Check that permutation map results match rank of vector type.
     if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1560,7 +1556,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vector, Value memref, ValueRange indices,
                            AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
-  Type elemType = vector.cast<VectorType>().getElementType();
+  Type elemType = memref.getType().cast<MemRefType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
@@ -1673,9 +1669,9 @@ static LogicalResult verify(TransferReadOp op) {
       return op.emitOpError("requires valid padding vector elemental type");
 
     // Check that padding type and vector element types match.
-    if (paddingType != vectorType.getElementType())
+    if (paddingType != memrefElementType)
       return op.emitOpError(
-          "requires formal padding and vector of the same elemental type");
+          "requires formal padding and memref of the same elemental type");
   }
 
   return verifyPermutationMap(permutationMap,
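
To make the new divisibility check concrete for a memref with a vector element
type, here is a sketch using the shapes from the updated tests (the function
wrapper and SSA value names are illustrative):

  func @bitwidth_multiple(%m : memref<?x?xvector<4x3xi32>>) -> vector<5x24xi8> {
    %c0 = constant 0 : index
    %i0 = constant 0 : i32
    // The padding value matches the memref element type, vector<4x3xi32>.
    %pad = splat %i0 : vector<4x3xi32>
    // Minor memref vector: 3 x 32 = 96 bits; minor result vector: 24 x 8 =
    // 192 bits. 192 is an integral multiple of 96, so the op verifies.
    %v = vector.transfer_read %m[%c0, %c0], %pad
        : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
    return %v : vector<5x24xi8>
  }

By contrast, reading a vector<3xf32> from a memref<?x?xvector<6xf32>> is
rejected, since 3 x 32 = 96 bits is not an integral multiple of 6 x 32 = 192
bits; this is exactly the case exercised in the updated invalid.mlir test.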

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d91d4db06106..d35c7fa645b7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -906,6 +906,24 @@ func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17
 // 2. Rewrite as a load.
 //       CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<17 x float>>
 
+func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
+  %c0 = constant 0: i32
+  %v = vector.transfer_read %A[%base], %c0 {masked = [false]} :
+    memref<?xi32>, vector<12xi8>
+  return %v: vector<12xi8>
+}
+// CHECK-LABEL: func @transfer_read_1d_cast
+//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<12 x i8>
+//
+// 1. Bitcast to vector form.
+//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm.ptr<i32>, !llvm.i64) -> !llvm.ptr<i32>
+//       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+//  CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vec<12 x i8>>
+//
+// 2. Rewrite as a load.
+//       CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<12 x i8>>
+
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e1d03ca480f4..71d2989661ae 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -339,16 +339,6 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
-  %c3 = constant 3 : index
-  %f0 = constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error @+1 {{requires memref and vector types of the same elemental type}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
-}
-
-// -----
-
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
@@ -359,12 +349,12 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
 
 // -----
 
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error @+1 {{ requires memref vector element shape to match suffix of vector result shape}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+  %vf0 = splat %f0 : vector<6xf32>
+  // expected-error @+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0381c88cc247..6f9990d5d97c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -4,12 +4,15 @@
 
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
-                          %arg1 : memref<?x?xvector<4x3xf32>>) {
+                          %arg1 : memref<?x?xvector<4x3xf32>>,
+                          %arg2 : memref<?x?xvector<4x3xi32>>) {
   // CHECK: %[[C3:.*]] = constant 3 : index
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
   %f0 = constant 0.0 : f32
+  %c0 = constant 0 : i32
   %vf0 = splat %f0 : vector<4x3xf32>
+  %v0 = splat %c0 : vector<4x3xi32>
 
   //
   // CHECK: vector.transfer_read
@@ -24,6 +27,9 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+  %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -33,6 +39,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+  vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
 
   return
 }


        

