[Mlir-commits] [mlir] 83df43f - [mlir] use strided layouts in vector transfer on memrefs

Alex Zinenko llvmlistbot at llvm.org
Fri Sep 16 23:11:42 PDT 2022


Author: Alex Zinenko
Date: 2022-09-17T08:11:30+02:00
New Revision: 83df43f3a204c529167caeedc26edae5faebbd31

URL: https://github.com/llvm/llvm-project/commit/83df43f3a204c529167caeedc26edae5faebbd31
DIFF: https://github.com/llvm/llvm-project/commit/83df43f3a204c529167caeedc26edae5faebbd31.diff

LOG: [mlir] use strided layouts in vector transfer on memrefs

One of the vector transfer transformation patterns (DropInnerMostUnitDims)
has been indiscriminately converting memref layouts to affine maps.
Leverage the strided layout form when possible.
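
For context, a strided layout carries the same information as the
affine-map form but keeps the offset and the per-dimension strides
explicit, which is what lets the pattern drop trailing strides directly.
A minimal illustration of the equivalence (types chosen for this note,
not taken from the patch):

  // Illustrative declarations: both memref types describe the same
  // layout, strides [8, 1] with a dynamic offset.
  func.func private @strided_form(memref<4x8xf32, strided<[8, 1], offset: ?>>)
  func.func private @affine_form(
      memref<4x8xf32, affine_map<(d0, d1)[s0] -> (d0 * 8 + d1 + s0)>>)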

Reviewed By: nicolasvasilache, dcaballe

Differential Revision: https://reviews.llvm.org/D134047

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fded10629b307..18f6c5a154e55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
@@ -2631,22 +2632,32 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
                         targetType.getElementType());
 
     MemRefType resultMemrefType;
-    if (srcType.getLayout().getAffineMap().isIdentity()) {
+    MemRefLayoutAttrInterface layout = srcType.getLayout();
+    if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
       resultMemrefType = MemRefType::get(
           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          {}, srcType.getMemorySpaceAsInt());
+          nullptr, srcType.getMemorySpace());
     } else {
-      AffineMap map = srcType.getLayout().getAffineMap();
-      int numSymbols = map.getNumSymbols();
-      for (size_t i = 0; i < dimsToDrop; ++i) {
-        int dim = srcType.getRank() - i - 1;
-        map = map.replace(rewriter.getAffineDimExpr(dim),
-                          rewriter.getAffineConstantExpr(0),
-                          map.getNumDims() - 1, numSymbols);
+      MemRefLayoutAttrInterface updatedLayout;
+      if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
+        auto strides =
+            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                               strided.getOffset(), strides);
+      } else {
+        AffineMap map = srcType.getLayout().getAffineMap();
+        int numSymbols = map.getNumSymbols();
+        for (size_t i = 0; i < dimsToDrop; ++i) {
+          int dim = srcType.getRank() - i - 1;
+          map = map.replace(rewriter.getAffineDimExpr(dim),
+                            rewriter.getAffineConstantExpr(0),
+                            map.getNumDims() - 1, numSymbols);
+        }
+        updatedLayout = AffineMapAttr::get(map);
       }
       resultMemrefType = MemRefType::get(
           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          map, srcType.getMemorySpaceAsInt());
+          updatedLayout, srcType.getMemorySpace());
     }
 
     auto loc = readOp.getLoc();

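In effect, the pattern now rank-reduces the memref through a subview whose
strided layout simply drops the trailing strides, instead of rewriting an
affine map. A hand-written sketch of the rewrite on a hypothetical input
(illustrative only, not part of the patch):

  // Before: a transfer read with a trailing unit dimension
  // (hypothetical input).
  %v = vector.transfer_read %src[%c0, %c0], %pad
      : memref<8x1xf32, strided<[1, 1], offset: ?>>, vector<4x1xf32>

  // After: the unit dimension is dropped from the memref and the vector,
  // and the original vector shape is restored with a shape_cast.
  %sv = memref.subview %src[0, 0] [8, 1] [1, 1]
      : memref<8x1xf32, strided<[1, 1], offset: ?>>
      to memref<8xf32, strided<[1], offset: ?>>
  %v0 = vector.transfer_read %sv[%c0], %pad
      : memref<8xf32, strided<[1], offset: ?>>, vector<4xf32>
  %v1 = vector.shape_cast %v0 : vector<4xf32> to vector<4x1xf32>
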
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 25f38452e3b4a..ef0bd9ddf8abf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -8,9 +8,9 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
 }
 //      CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
 //      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
-// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32
+// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
 //      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
-// CHECK-SAME:    memref<1x1x8xf32, {{.*}}>, vector<1x8xf32>
+// CHECK-SAME:    memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
 //      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
 //      CHECK:   return %[[RESULT]]
 
@@ -34,8 +34,8 @@ func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -
 func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>
-  %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>, vector<4x1xf32>
+  %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>>
+  %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, strided<[1, 1], offset: ?>>, vector<4x1xf32>
   return %1 : vector<4x1xf32>
 }
 //      CHECK: func @contiguous_inner_most_dim_bounds(%[[SRC:.+]]: memref<1000x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1xf32>
@@ -50,8 +50,8 @@ func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %i
 func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:index, %ii:index) -> (vector<4x1x1xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
-  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>, vector<4x1x1xf32>
+  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>
+  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>, vector<4x1x1xf32>
   return %1 : vector<4x1x1xf32>
 }
 //      CHECK: func @contiguous_inner_most_dim_bounds_2d(%[[SRC:.+]]: memref<1000x1x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1x1xf32>
