[Mlir-commits] [mlir] 96ca2d9 - [mlir][ArmSVE] Add basic load/store operations

Javier Setoain llvmlistbot at llvm.org
Wed Jun 9 08:04:25 PDT 2021


Author: Javier Setoain
Date: 2021-06-09T15:53:40+01:00
New Revision: 96ca2d92b52bd97fcdce4c0ba2723399b005e0a9

URL: https://github.com/llvm/llvm-project/commit/96ca2d92b52bd97fcdce4c0ba2723399b005e0a9
DIFF: https://github.com/llvm/llvm-project/commit/96ca2d92b52bd97fcdce4c0ba2723399b005e0a9.diff

LOG: [mlir][ArmSVE] Add basic load/store operations

ArmSVE-specific memory operations are needed to generate end-to-end
code for as long as MLIR core doesn't support scalable vectors. These
instructions will eventually become unnecessary; for now, they are
required for more complex testing.
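
For reference, the new operations read and write a scalable vector at a
dynamic index of a memref. A minimal sketch, mirroring the roundtrip
test added below:

  %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32>
  arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32>

During legalization for LLVM export, each op becomes a getelementptr on
the memref's aligned pointer, a bitcast to a scalable-vector pointer,
and a plain llvm.load/llvm.store (see LegalizeForLLVMExport.cpp below).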

Differential Revision: https://reviews.llvm.org/D103535

Added: 
    mlir/test/Dialect/ArmSVE/memcpy.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Target/LLVMIR/arm-sve.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 7114fdb9425a9..6e858db54ce19 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -356,6 +356,37 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
     "attr-dict `:` type($res)";
 }
 
+def ScalableLoadOp : ArmSVE_Op<"load">,
+    Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base, Index:$index)>,
+    Results<(outs ScalableVectorOf<[AnyType]>:$result)> {
+  let summary = "Load scalable vector from memory";
+  let description = [{
+    Load a slice of memory into a scalable vector.
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+  }];
+  let assemblyFormat = "$base `[` $index `]` attr-dict `:` "
+    "type($result) `from` type($base)";
+}
+
+def ScalableStoreOp : ArmSVE_Op<"store">,
+    Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base, Index:$index,
+               ScalableVectorOf<[AnyType]>:$value)> {
+  let summary = "Store scalable vector into memory";
+  let description = [{
+    Store a scalable vector into a slice of memory.
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+  }];
+  let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` "
+    "type($value) `to` type($base)";
+}
 
 def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
 

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index e43511a87a47b..ba84a955d66dd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -111,6 +111,80 @@ using ScalableMaskedDivFOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                  ScalableMaskedDivFIntrOp>;
 
+// Load operation is lowered to code that obtains a pointer to the indexed
+// element and loads from it.
+struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
+  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = loadOp.getMemRefType();
+    if (!isConvertibleAndHasIdentityMaps(type))
+      return failure();
+
+    ScalableLoadOp::Adaptor transformed(operands);
+    LLVMTypeConverter converter(loadOp.getContext());
+
+    auto resultType = loadOp.result().getType();
+    LLVM::LLVMPointerType llvmDataTypePtr;
+    if (resultType.isa<VectorType>()) {
+      llvmDataTypePtr =
+          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
+    } else if (resultType.isa<ScalableVectorType>()) {
+      llvmDataTypePtr = LLVM::LLVMPointerType::get(
+          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
+                                          converter)
+              .getValue());
+    }
+    Value dataPtr =
+        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
+                             transformed.index(), rewriter);
+    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
+        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
+    return success();
+  }
+};
+
+// Store operation is lowered to code that obtains a pointer to the indexed
+// element, and stores the given value to it.
+struct ScalableStoreOpLowering
+    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
+  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = storeOp.getMemRefType();
+    if (!isConvertibleAndHasIdentityMaps(type))
+      return failure();
+
+    ScalableStoreOp::Adaptor transformed(operands);
+    LLVMTypeConverter converter(storeOp.getContext());
+
+    auto resultType = storeOp.value().getType();
+    LLVM::LLVMPointerType llvmDataTypePtr;
+    if (resultType.isa<VectorType>()) {
+      llvmDataTypePtr =
+          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
+    } else if (resultType.isa<ScalableVectorType>()) {
+      llvmDataTypePtr = LLVM::LLVMPointerType::get(
+          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
+                                          converter)
+              .getValue());
+    }
+    Value dataPtr =
+        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
+                             transformed.index(), rewriter);
+    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
+        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
+                                               bitCastedPtr);
+    return success();
+  }
+};
+
 static void
 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns) {
@@ -191,6 +265,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedSDivIOpLowering,
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering>(converter);
+  patterns.add<ScalableLoadOpLowering,
+               ScalableStoreOpLowering>(converter);
   // clang-format on
   populateBasicSVEArithmeticExportPatterns(converter, patterns);
   populateSVEMaskGenerationExportPatterns(converter, patterns);
@@ -226,7 +302,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedMulFOp,
                       ScalableMaskedSDivIOp,
                       ScalableMaskedUDivIOp,
-                      ScalableMaskedDivFOp>();
+                      ScalableMaskedDivFOp,
+                      ScalableLoadOp,
+                      ScalableStoreOp>();
   // clang-format on
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)

diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir
new file mode 100644
index 0000000000000..4b19b1ff2851b
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/memcpy.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s 
+
+// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
+func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %vs = arm_sve.vector_scale : index
+  %step = muli %c4, %vs : index
+
+  // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} 
+  scf.for %i0 = %c0 to %size step %step {
+    // CHECK: [[SRCMRS:%[0-9]+]] = llvm.mlir.cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+    // CHECK: [[SRCIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64
+    // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
+    // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+    // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
+    %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32>
+    // CHECK: [[DSTMRS:%[0-9]+]] = llvm.mlir.cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+    // CHECK: [[DSTIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64
+    // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
+    // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DSTIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+    // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
+    arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32>
+  }
+
+  return
+}

diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 2dde6c32a665e..f60cd44f94e29 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -119,6 +119,17 @@ func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
   return %0 : !arm_sve.vector<4xi1>
 }
 
+func @arm_sve_memory(%v: !arm_sve.vector<4xi32>,
+                     %m: memref<?xi32>)
+                     -> !arm_sve.vector<4xi32> {
+  %c0 = constant 0 : index
+  // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref<?xi32>
+  %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32>
+  // CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref<?xi32>
+  arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index

diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 1e857275ec0f3..5dcec26a975c6 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -190,6 +190,84 @@ llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
   llvm.return %6 : !llvm.vec<? x 4 x i32>
 }
 
+// CHECK-LABEL: define void @memcopy
+llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
+                   %arg2: i64, %arg3: i64, %arg4: i64,
+                   %arg5: !llvm.ptr<f32>, %arg6: !llvm.ptr<f32>,
+                   %arg7: i64, %arg8: i64, %arg9: i64,
+                   %arg10: i64) {
+  %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                       array<1 x i64>, array<1 x i64>)>
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                        array<1 x i64>,
+                                                        array<1 x i64>)>
+  %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                        array<1 x i64>,
+                                                        array<1 x i64>)>
+  %6 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                       array<1 x i64>,
+                                       array<1 x i64>)>
+  %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                     array<1 x i64>,
+                                                     array<1 x i64>)>
+  %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                         array<1 x i64>,
+                                                         array<1 x i64>)>
+  %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                         array<1 x i64>,
+                                                         array<1 x i64>)>
+  %12 = llvm.mlir.constant(0 : index) : i64
+  %13 = llvm.mlir.constant(4 : index) : i64
+  // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64()
+  %14 = "arm_sve.vscale"() : () -> i64
+  // CHECK: mul i64 [[VL]], 4
+  %15 = llvm.mul %14, %13  : i64
+  llvm.br ^bb1(%12 : i64)
+^bb1(%16: i64):
+  %17 = llvm.icmp "slt" %16, %arg10 : i64
+  llvm.cond_br %17, ^bb2, ^bb3
+^bb2:
+  // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
+  %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                array<1 x i64>,
+                                                array<1 x i64>)>
+  // CHECK: getelementptr float, float*
+  %19 = llvm.getelementptr %18[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>*
+  %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+  // CHECK: load <vscale x 4 x float>, <vscale x 4 x float>*
+  %21 = llvm.load %20 : !llvm.ptr<vec<? x 4 x f32>>
+  // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
+  %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                                 array<1 x i64>,
+                                                 array<1 x i64>)>
+  // CHECK: getelementptr float, float* %{{[0-9]+}}
+  %23 = llvm.getelementptr %22[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>*
+  %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+  // CHECK: store <vscale x 4 x float> %{{[0-9]+}}, <vscale x 4 x float>* %{{[0-9]+}}
+  llvm.store %21, %24 : !llvm.ptr<vec<? x 4 x f32>>
+  %25 = llvm.add %16, %15  : i64
+  llvm.br ^bb1(%25 : i64)
+^bb3:
+  llvm.return
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()


        


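To exercise the lowering end to end, the new patterns can be driven the
same way as the RUN line of the memcpy test above; a sketch, assuming
the usual mlir-translate driver for the Target/LLVMIR checks:

  mlir-opt -convert-vector-to-llvm="enable-arm-sve" memcpy.mlir \
    | mlir-translate -mlir-to-llvmir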