[Mlir-commits] [mlir] 96ca2d9 - [mlir][ArmSVE] Add basic load/store operations
Javier Setoain
llvmlistbot at llvm.org
Wed Jun 9 08:04:25 PDT 2021
Author: Javier Setoain
Date: 2021-06-09T15:53:40+01:00
New Revision: 96ca2d92b52bd97fcdce4c0ba2723399b005e0a9
URL: https://github.com/llvm/llvm-project/commit/96ca2d92b52bd97fcdce4c0ba2723399b005e0a9
DIFF: https://github.com/llvm/llvm-project/commit/96ca2d92b52bd97fcdce4c0ba2723399b005e0a9.diff
LOG: [mlir][ArmSVE] Add basic load/store operations
ArmSVE-specific memory operations are needed to generate end-to-end
code for as long as MLIR core doesn't support scalable vectors. These
operations will eventually become unnecessary; for now they're required
for more complex testing.
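In practice the two new ops pair with arm_sve.vector_scale to walk a
buffer in hardware-sized steps, as the memcpy.mlir test added below does.
A minimal usage sketch (the function and value names here are
illustrative, not part of the patch):

func @copy_element_block(%src : memref<?xf32>, %dst : memref<?xf32>, %i : index) {
  // Load one scalable vector of f32 (4 elements scaled by the hardware
  // vector length) starting at index %i of the source memref.
  %v = arm_sve.load %src[%i] : !arm_sve.vector<4xf32> from memref<?xf32>
  // Store that vector at the same offset in the destination memref.
  arm_sve.store %v, %dst[%i] : !arm_sve.vector<4xf32> to memref<?xf32>
  return
}

On lowering to LLVM, each op becomes a strided-element GEP, a bitcast of
the element pointer to a scalable-vector pointer, and a plain
llvm.load/llvm.store, as the patterns below show.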
Differential Revision: https://reviews.llvm.org/D103535
Added:
mlir/test/Dialect/ArmSVE/memcpy.mlir
Modified:
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Target/LLVMIR/arm-sve.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 7114fdb9425a9..6e858db54ce19 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -356,6 +356,37 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
"attr-dict `:` type($res)";
}
+def ScalableLoadOp : ArmSVE_Op<"load">,
+ Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base, Index:$index)>,
+ Results<(outs ScalableVectorOf<[AnyType]>:$result)> {
+ let summary = "Load scalable vector from memory";
+ let description = [{
+ Load a slice of memory into a scalable vector.
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ }];
+ let assemblyFormat = "$base `[` $index `]` attr-dict `:` "
+ "type($result) `from` type($base)";
+}
+
+def ScalableStoreOp : ArmSVE_Op<"store">,
+ Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base, Index:$index,
+ ScalableVectorOf<[AnyType]>:$value)> {
+ let summary = "Store scalable vector into memory";
+ let description = [{
+    Store a scalable vector to a slice of memory.
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ }];
+ let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` "
+ "type($value) `to` type($base)";
+}
def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index e43511a87a47b..ba84a955d66dd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -111,6 +111,80 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
+// Load operation is lowered to code that obtains a pointer to the indexed
+// element and loads from it.
+struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
+ using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto type = loadOp.getMemRefType();
+ if (!isConvertibleAndHasIdentityMaps(type))
+ return failure();
+
+ ScalableLoadOp::Adaptor transformed(operands);
+ LLVMTypeConverter converter(loadOp.getContext());
+
+ auto resultType = loadOp.result().getType();
+ LLVM::LLVMPointerType llvmDataTypePtr;
+ if (resultType.isa<VectorType>()) {
+ llvmDataTypePtr =
+ LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
+ } else if (resultType.isa<ScalableVectorType>()) {
+ llvmDataTypePtr = LLVM::LLVMPointerType::get(
+ convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
+ converter)
+ .getValue());
+ }
+ Value dataPtr =
+ getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
+ transformed.index(), rewriter);
+ Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
+ loadOp.getLoc(), llvmDataTypePtr, dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
+ return success();
+ }
+};
+
+// Store operation is lowered to code that obtains a pointer to the indexed
+// element, and stores the given value to it.
+struct ScalableStoreOpLowering
+ : public ConvertOpToLLVMPattern<ScalableStoreOp> {
+ using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto type = storeOp.getMemRefType();
+ if (!isConvertibleAndHasIdentityMaps(type))
+ return failure();
+
+ ScalableStoreOp::Adaptor transformed(operands);
+ LLVMTypeConverter converter(storeOp.getContext());
+
+ auto resultType = storeOp.value().getType();
+ LLVM::LLVMPointerType llvmDataTypePtr;
+ if (resultType.isa<VectorType>()) {
+ llvmDataTypePtr =
+ LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
+ } else if (resultType.isa<ScalableVectorType>()) {
+ llvmDataTypePtr = LLVM::LLVMPointerType::get(
+ convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
+ converter)
+ .getValue());
+ }
+ Value dataPtr =
+ getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
+ transformed.index(), rewriter);
+ Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
+ storeOp.getLoc(), llvmDataTypePtr, dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
+ bitCastedPtr);
+ return success();
+ }
+};
+
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
@@ -191,6 +265,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering>(converter);
+ patterns.add<ScalableLoadOpLowering,
+ ScalableStoreOpLowering>(converter);
// clang-format on
populateBasicSVEArithmeticExportPatterns(converter, patterns);
populateSVEMaskGenerationExportPatterns(converter, patterns);
@@ -226,7 +302,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
- ScalableMaskedDivFOp>();
+ ScalableMaskedDivFOp,
+ ScalableLoadOp,
+ ScalableStoreOp>();
// clang-format on
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir
new file mode 100644
index 0000000000000..4b19b1ff2851b
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/memcpy.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
+
+// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
+func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %vs = arm_sve.vector_scale : index
+ %step = muli %c4, %vs : index
+
+ // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
+ scf.for %i0 = %c0 to %size step %step {
+ // CHECK: [[SRCMRS:%[0-9]+]] = llvm.mlir.cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+ // CHECK: [[SRCIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64
+ // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
+ // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+ // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
+ %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32>
+ // CHECK: [[DSTMRS:%[0-9]+]] = llvm.mlir.cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+ // CHECK: [[DSTIDX:%[0-9]+]] = llvm.mlir.cast [[LOOPIDX]] : index to i64
+ // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
+ // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DSTIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+ // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
+ arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32>
+ }
+
+ return
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 2dde6c32a665e..f60cd44f94e29 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -119,6 +119,17 @@ func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
return %0 : !arm_sve.vector<4xi1>
}
+func @arm_sve_memory(%v: !arm_sve.vector<4xi32>,
+ %m: memref<?xi32>)
+ -> !arm_sve.vector<4xi32> {
+ %c0 = constant 0 : index
+ // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref<?xi32>
+ %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32>
+ // CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref<?xi32>
+ arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 1e857275ec0f3..5dcec26a975c6 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -190,6 +190,84 @@ llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
llvm.return %6 : !llvm.vec<? x 4 x i32>
}
+// CHECK-LABEL: define void @memcopy
+llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
+ %arg2: i64, %arg3: i64, %arg4: i64,
+ %arg5: !llvm.ptr<f32>, %arg6: !llvm.ptr<f32>,
+ %arg7: i64, %arg8: i64, %arg9: i64,
+ %arg10: i64) {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>, array<1 x i64>)>
+ %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %6 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ %12 = llvm.mlir.constant(0 : index) : i64
+ %13 = llvm.mlir.constant(4 : index) : i64
+ // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64()
+ %14 = "arm_sve.vscale"() : () -> i64
+ // CHECK: mul i64 [[VL]], 4
+ %15 = llvm.mul %14, %13 : i64
+ llvm.br ^bb1(%12 : i64)
+^bb1(%16: i64):
+ %17 = llvm.icmp "slt" %16, %arg10 : i64
+ llvm.cond_br %17, ^bb2, ^bb3
+^bb2:
+ // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
+ %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+  // CHECK: getelementptr float, float*
+ %19 = llvm.getelementptr %18[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>*
+ %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+ // CHECK: load <vscale x 4 x float>, <vscale x 4 x float>*
+ %21 = llvm.load %20 : !llvm.ptr<vec<? x 4 x f32>>
+ // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
+ %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<1 x i64>,
+ array<1 x i64>)>
+ // CHECK: getelementptr float, float* %32
+ %23 = llvm.getelementptr %22[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK: bitcast float* %33 to <vscale x 4 x float>*
+ %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+ // CHECK: store <vscale x 4 x float> %{{[0-9]+}}, <vscale x 4 x float>* %{{[0-9]+}}
+ llvm.store %21, %24 : !llvm.ptr<vec<? x 4 x f32>>
+ %25 = llvm.add %16, %15 : i64
+ llvm.br ^bb1(%25 : i64)
+^bb3:
+ llvm.return
+}
+
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()