[Mlir-commits] [mlir] 670455c - [mlir][spirv] Legalize subviewop when used with vector transfer

Thomas Raoux llvmlistbot at llvm.org
Fri Jun 19 17:46:38 PDT 2020


Author: Thomas Raoux
Date: 2020-06-19T17:33:15-07:00
New Revision: 670455c77d4b2ee3bcf90fb454f62ae69ec47239

URL: https://github.com/llvm/llvm-project/commit/670455c77d4b2ee3bcf90fb454f62ae69ec47239
DIFF: https://github.com/llvm/llvm-project/commit/670455c77d4b2ee3bcf90fb454f62ae69ec47239.diff

LOG: [mlir][spirv] Legalize subviewop when used with vector transfer

Subview operations are not natively supported downstream in the SPIR-V path.
This change allows removing a subview when it is used by a vector transfer op,
the same way we already do when it is used by LoadOp/StoreOp.

Differential Revision: https://reviews.llvm.org/D82106

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/legalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 3acd5959834d..0d949f74c191 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -15,28 +15,41 @@
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 
 using namespace mlir;
 
 namespace {
-/// Merges subview operation with load operation.
-class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+/// Merges subview operation with load/transferRead operation.
+template <typename OpTy>
+class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<LoadOp>::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(LoadOp loadOp,
+  LogicalResult matchAndRewrite(OpTy loadOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  void replaceOp(OpTy loadOp, SubViewOp subViewOp,
+                 ArrayRef<Value> sourceIndices,
+                 PatternRewriter &rewriter) const;
 };
 
-/// Merges subview operation with store operation.
-class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+/// Merges subview operation with store/transferWriteOp operation.
+template <typename OpTy>
+class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<StoreOp>::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(StoreOp storeOp,
+  LogicalResult matchAndRewrite(OpTy storeOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
+                 ArrayRef<Value> sourceIndices,
+                 PatternRewriter &rewriter) const;
 };
 } // namespace
 
@@ -85,13 +98,14 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
 }
 
 //===----------------------------------------------------------------------===//
-// Folding SubViewOp and LoadOp.
+// Folding SubViewOp and LoadOp/TransferReadOp.
 //===----------------------------------------------------------------------===//
 
+template <typename OpTy>
 LogicalResult
-LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
-                                       PatternRewriter &rewriter) const {
-  auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
+LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
+                                             PatternRewriter &rewriter) const {
+  auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -100,19 +114,36 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
                                   loadOp.indices(), sourceIndices)))
     return failure();
 
+  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
+  return success();
+}
+
+template <>
+void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
+                                              SubViewOp subViewOp,
+                                              ArrayRef<Value> sourceIndices,
+                                              PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
                                       sourceIndices);
-  return success();
+}
+
+template <>
+void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
+    vector::TransferReadOp loadOp, SubViewOp subViewOp,
+    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices);
 }
 
 //===----------------------------------------------------------------------===//
-// Folding SubViewOp and StoreOp.
+// Folding SubViewOp and StoreOp/TransferWriteOp.
 //===----------------------------------------------------------------------===//
 
+template <typename OpTy>
 LogicalResult
-StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
-                                        PatternRewriter &rewriter) const {
-  auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
+StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
+                                              PatternRewriter &rewriter) const {
+  auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -121,9 +152,25 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
                                   storeOp.indices(), sourceIndices)))
     return failure();
 
+  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
+  return success();
+}
+
+template <>
+void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
+    StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
+    PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
                                        subViewOp.source(), sourceIndices);
-  return success();
+}
+
+template <>
+void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
+    vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
+    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
+      sourceIndices);
 }
 
 //===----------------------------------------------------------------------===//
@@ -132,7 +179,10 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
 
 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+  patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
+                  LoadOpOfSubViewFolder<vector::TransferReadOp>,
+                  StoreOpOfSubViewFolder<StoreOp>,
+                  StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
index d3b339e82a88..acbda3540d22 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
@@ -62,3 +62,37 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 :
   store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
   return
 }
+
+// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
+func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
+  // CHECK-NOT: subview
+  // CHECK: [[C2:%.*]] = constant 2 : index
+  // CHECK: [[C3:%.*]] = constant 3 : index
+  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+  // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  %f0 = constant 0.0 : f32
+  %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  %1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32>
+func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) {
+  // CHECK-NOT: subview
+  // CHECK: [[C2:%.*]] = constant 2 : index
+  // CHECK: [[C3:%.*]] = constant 3 : index
+  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+  // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
+    memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
+  return
+}


        


More information about the Mlir-commits mailing list