[Mlir-commits] [mlir] ac2cf07 - [spirv] Fix legalize standard to spir-v for transfer ops

Thomas Raoux llvmlistbot at llvm.org
Wed Oct 21 13:56:25 PDT 2020


Author: Thomas Raoux
Date: 2020-10-21T13:56:01-07:00
New Revision: ac2cf07195b5833a888dc6878a9a3cb377ef59ac

URL: https://github.com/llvm/llvm-project/commit/ac2cf07195b5833a888dc6878a9a3cb377ef59ac
DIFF: https://github.com/llvm/llvm-project/commit/ac2cf07195b5833a888dc6878a9a3cb377ef59ac.diff

LOG: [spirv] Fix legalize standard to spir-v for transfer ops

Forward the missing attributes (permutation_map, padding, and masked) when
creating the new transfer op; otherwise the builder would use default values.

Differential Revision: https://reviews.llvm.org/D89907

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/legalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index a2e608dcb713..1cf3a326367c 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -67,7 +67,8 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
     vector::TransferReadOp loadOp, SubViewOp subViewOp,
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices);
+      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
+      loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr());
 }
 
 template <>
@@ -84,7 +85,8 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
       tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
-      sourceIndices);
+      sourceIndices, tranferWriteOp.permutation_map(),
+      tranferWriteOp.maskedAttr());
 }
 } // namespace
 

diff  --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
index acbda3540d22..c5c59613b56e 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
@@ -67,16 +67,17 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 :
 // CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
 func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
   // CHECK-NOT: subview
+  // CHECK: [[F1:%.*]] = constant 1.000000e+00 : f32
   // CHECK: [[C2:%.*]] = constant 2 : index
   // CHECK: [[C3:%.*]] = constant 3 : index
   // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
   // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
   // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
   // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
-  %f0 = constant 0.0 : f32
+  // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}, [[F1]] {masked = [false]}
+  %f1 = constant 1.0 : f32
   %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
-  %1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
+  %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {masked = [false]} : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
   return %1 : vector<4xf32>
 }
 
@@ -90,9 +91,9 @@ func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>,
   // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
   // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
   // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
-  // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} {masked = [false]}
   %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
     memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
-  vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
+  vector.transfer_write %arg5, %0[%arg3, %arg4] {masked = [false]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
   return
 }


        


More information about the Mlir-commits mailing list