[Mlir-commits] [mlir] andrzej/extend vec cont to op tests (PR #70039)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Oct 24 05:57:13 PDT 2023
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/70039
- Revert "[gn build] Add rules for crtbegin/end (#66012)"
- [YAMLTraits] Fix std::optional input on empty documents (#68947)
- [llvm-rc] Accept filenames provided as multiple string literals (#68881)
- [flang][runtime] Handle incomplete NAMELIST input derived type compon… (#66831)
- Revert "[SLP]Fix PR69196: Instruction does not dominate all uses"
- [libc] Add simple long double to printf float fuzz (#68449)
- [Clang][NFC] Use correct tool name for NVIDIA's 'nvlink'
- [libc++][test] Add `stop_token` benchmark (#69117)
- [clang-tidy][NFC] Clarify documentation for misc-definitions-in-headers
- [flang][runtime] Implement EX editing for input & output (#67208)
- [flang] Submodule names can clash only with submodule names (#67361)
- [RS4GC] Copy argument attributes from call to statepoint (#68475)
- [flang][runtime] Better non-repeatable RANDOM_INIT() (#67363)
- [clang-format] Allow default values for template parameters in lambda (#69052)
- [flang] Ensure component attributes affect characteristics (#67465)
- [mlir][sparse] remove sparse2sparse path in library (#69247)
- [hwasan][test] Fix regex so deep-recursion.c is unsupported on aarch64 targets (#69254)
- [docs] Add a new GlobalISel office hours session to the list.
- [docs] Fix google meet link
- [flang] Catch a dangerous ambiguity in standard Fortran (#67483)
- [flang] Avoid needless overflow when folding NORM2 (#67499)
- [HWASAN] Add bcmp interceptor (#69257)
- [TOSA] Add StatefulOps to TOSA Dialect (#66843)
- [flang][runtime] Fix edge cases with ROUND=UP/DOWN (#67508)
- [compiler-rt] Implement __extendxftf2 and __trunctfxf2 for x86_64 (#66918)
- llvm-gsymutil now handles empty linkage names correctly. (#68931)
- [flang] Fix CFI_CDESC_T for C++ interoperability (#67568)
- [flang] Remove IEEE_DENORM from IEEE_ALL (#67573)
- [RISCV] Pre-commit concat-vectors-constant-stride.ll
- [RISCV] Improve performCONCAT_VECTORCombine stride matching
- [compiler-rt] Fix build of builtins on Windows
- [flang][NFC] Speed up large DATA statement initializations (#67585)
- [flang] Handle separate module procedures with INTERFACE dummy arguments (#67608)
- [flang] Fix construct names on labeled DO (#67622)
- [clang] Implement C23 <stdckdint.h>
- [RISCV] Support STRICT_FP_ROUND and STRICT_FP_EXTEND when only have Zvfhmin (#68559)
- Revert "[clang][Sema] Use original template pattern when declaring implicit deduction guides for nested template classes (#68379)"
- [RISCV][GISel] Add legalizer for G_UMAX, G_UMIN, G_SMAX, G_SMIN (#69150)
- [M68k][NFC] Fix some unused variable warnings
- [clang][Interp] Check pointer inc/dec ops for null (#69168)
- [CI] Add Github actions job to build LLVM documentation (#69269)
- [Clang] Fix dependence handling of nttp for variable templates (#69075)
- [hwasan] Fix and re-enable deep-recursion.c (#69265)
- [CI] Fix documentation build CI job
- [MLIR][LLVM] Change addressof builders to use opaque pointers (#69215)
- Revert "[MLIR][LLVM] Change addressof builders to use opaque pointers (#69215)"
- [AArch64] Fix pairing different types of registers when computing CSRs. (#66642)
- [flang] Deallocate INTENT(OUT) dummy allocatable components (#69164)
- [flang][runtime] Fix another IsContiguous edge case (#69199)
- [flang][runtime] fix buildbot failure after #69199
- [flang][hlfir] Do not emit extra declare for dummy used in BLOCK (#69184)
- [AArch64] Fix -Wunused-variable in AArch64LowerHomogeneousPrologEpilog.cpp (NFC)
- [lldb][lldb-vscode] Add example configuration for connecting to a remote gdbserver (#68866)
- [compiler-rt][HWASAN] Add missing include in deep-recursion.c test
- [clang][Interp][NFC] Add thread_local tests
- [TableGen] Use buildConstant to emit apply pattern immediates (#66077)
- [TableGen] Handle duplicate rules in combiners (#69296)
- [clang][NFC] Replace TypeAlignment with alignof(T) (#69185)
- [ci] diff with main merge-base (#69308)
- Reland "[MLIR][LLVM] Change addressof builders to use opaque pointers" (#69292)
- [libc++] Eliminate extra allocations from `std::move(oss).str()` (#67294)
- [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (#68728)
- [AArch64] Allow only LSL to be folded into addressing mode (#69235)
- [mlir][transform] Fix new interpreter and library preloading passes. (#69190)
- [X86] vselect.ll - add vXi8 select-by-constant tests with repeated/broadcastable shuffle mask
- [HIP][Clang][CodeGen] Add CodeGen support for `hipstdpar`
- [mlir][nvvm] Support predicates in `BasicPtxBuilder` (#67102)
- [mlir][nvgpu] Fix packing accumlator matrix (#69316)
- [InstCombine] Don't mix X << Y / Z << Y with X << Y / X << Z (#69302)
- [mlir][nvvm] Add prefetch.tensormap (#67564)
- [MLIR][NVGPU] Test warpgroup matrix multiply 128x128x64 (#68817)
- [mlir][Tosa] Fix test failure when running with Asan.
- [gn] port 3694697003bb
- [gn] port 3694697003bb
- [Bazel] disable preload-library.mlir test
- [Bazel] fix typo
- [Bazel] Fix dependencies for clang codegen
- [AMDGPU][AsmParser] Eliminate custom predicates for named-bit operands. (#69243)
- [AMDGPU] Remove support for no-return buffer atomic intrinsics. NFC. (#69326)
- [TargetParser][AMDGPU] Fix getArchEntry(). (#69222)
- [CUDA][HIP] Fix init var diag in temmplate (#69081)
- [Clang][SVE2.1] Add svcntp prototype
- [ADT][DebugInfo][RemoveDIs] Add extra bits to ilist_iterator for debug-info
- [LoongArch] Precommit a test for atomic cmpxchg optmization
- [HIP][Clang][CodeGen] Simplify test for `hipstdpar`
- [flang] Fix constant subscript operations (#68352)
- Update documentation on x86 constraint codes (#68830)
- [gn build] Port 088d272e8325
- [X86] Enable bfloat type support in inline assembly constraints (#68469)
- [AArch64][SME] Remove immediate argument restriction for svldr and svstr (#68908)
- [flang][runtime] Fix SAME_TYPE_AS()/EXTENDS_TYPE_OF() for CLASS(*) (#67727)
- [AMDGPU] Simplify definition of SIbuffer_atomic_*. NFC.
- [RISCV] Use separate CCValAssign for both parts of f64 with ilp32. (#69129)
- [DAG] foldConstantFPMath - accept ArrayRef<SDValue> Ops instead of explicit N1/N2 ops
- [Github][OpenMP] Adding rule for OpenMP label (#65331)
- Reland: [AArch64][GlobalISel] Adopt dup(load) -> LD1R patterns from SelectionDAG
- Enable v for RISCV64 Android (#69261)
- [hwasan] Exclude bcmp interceptor test from Android
- [InstCombine] Create a class to lazily track computed known bits (#66611)
- [Clang][SVE2.1] Add svpext builtins
- [flang] Round derived type byte sizes up to alignment multiple (#67571)
- [X86, Peephole] Enable FoldImmediate for X86
- [mlir][sparse] complete migration to dim2lvl/lvl2dim in library (#69268)
- [DAG] SimplifyDemandedBits - fix isOperationLegal typo in D146121
- [YAMLParser] Improve plain scalar spec compliance (#68946)
- [AArch64] Add additional tests for fptosi/fptoui. NFC
- [AArch64] Convert negative constant aarch64_neon_sshl to VASHR (#68918)
- [SLP]Fix PR69196: Instruction does not dominate all uses
- [AMDGPU] support 64-bit immediates in SIInstrInfo::FoldImmediate (#69260)
- [mlir][sparse] implementating stageSparseOpPass as an interface (#69022)
- [libc] Implement the 'ungetc' function on the GPU (#69248)
- [NVPTX] Fixed few more corner cases for v4i8 lowering. (#69263)
- [unittest] Allow LLVM unit test to run under a wrapper program. (#66821)
- [VPlan] Insert Trunc/Exts for reductions directly in VPlan.
- [mlir][sparse] avoid tensor to memref conversion in sparse tensor rewri… (#69362)
- [clang-tidy] Add check to diagnose coroutine-hostile RAII objects (#68738)
- Correctly compute conversion seq for args to fn with reversed param order (#68999)
- [CodeGen] Avoid potential sideeffects from XOR (#67193)
- [mlgo] Fix tests post 760e7d0
- [gn] port dd64c82cbc9c6
- [gn build] Port 31512811b8c0
- [CodeGen] Temporary disable the unreachable
- [CodeGen][NFC] Fix formatting
- [ELF][test] Improve relocatable link & /DISCARD/ test
- [libcxx] [test] Add a test parameter for disabling memory intensive tests (#68214)
- [OpenMPIRBuilder] Added `if` clause for `teams` (#69139)
- [mlir][sparse] Populate lvlToDim (#68937)
- [ELF] Move demoteSymbols to Writer.cpp. NFC
- [MLIR][Doc] Prepend "Variadic of" in front of variadic operands (#69285)
- [ELF] Merge demoteSymbols and isPreemptible computation. NFC
- [CMake] Support per-target linker flags (#68393)
- [llvm-profdata] Do not create numerical strings for MD5 function names read from a Sample Profile. (#66164)
- [ELF] Demote symbols in /DISCARD/ discarded sections to Undefined (#69295)
- [flang][openacc] Accept scalar integer expression in the if clause (#69381)
- [NFC][SLP] Test case exposing gather nodes matching deficiency affecting cost. (#69382)
- [SLP][NFC] Try to cleanup and better document some isGatherShuffledEntry code. (#69384)
- [ELF] Remove unused setSymbolAndType after #69295. NFC
- [LLDB][NFC] Add a missing namespace
- [lldb] Scalar::GetValue() should take a Stream by reference (#69231)
- [LLDB][NFC] Move some constructors to their cpp file
- Turn an assert in mlir-tblgen into a runtime check to be more user friendly (NFC)
- Fix build: the dump() method is only available in Asserts build (NFC)
- [unittest] Add option to allow disabling sharding in unittest (#67063)
- [LangRef] "cc 10" -> "ghccc" (#69380)
- [sanitizer_common] Use 38-bit mmap range for Fuchsia (#69387)
- [Docs] Remove future extensions section from writing a pass docs (#69286)
- [M68k] Fix assertion build after cc6a5ea6e33d3febafc4334617230c528a0c4fa7
- Detect against invalid variant index for LibStdC++ std::variant data formatters (#69253)
- [Docs][NFC] fix URL
- [RISCV][GlobalISel] Select G_FRAME_INDEX (#68254)
- [CodeLayout] CDSortImpl: remove HotChains and remove linear-time erase_value from mergeChains (#69276)
- [PowerPC] Auto gen test checks for #69299. NFC.
- [Driver][NFC] Remove identifier with the comment (#68351)
- [DAG] Constant fold FMAD (#69324)
- [CodeGen] Remove unused function isMSInlineAsm (#69132)
- [ModuleInliner] Fix the heap maintenance (#69251)
- [X86] Add tests for transform `(icmp eq/ne (and X, C0), (shift X, C1))`; NFC
- [DAGCombiner] Transform `(icmp eq/ne (and X,C0),(shift X,C1))` to use rotate or to getter constants.
- [flang] Fold IS_CONTIGUOUS of component refs with non-contiguous base (#69327)
- Add RV64 constraint to SRLIW (#69416)
- [AArch64][GlobalISel] Precommit indexed sextload/zextload tests.
- [BOLT] Fix instrumentation test (#69383)
- nfc, address post commit comments related to code format for 581c64a
- [clang][USR] Encode full decl-context also for anon namespaces (#68325)
- [clang] [unittest] Add a test for Generic_GCC::GCCVersion::Parse (#69078)
- [gn build] Port b4b35a5d2b4e
- [compiler-rt] Only build SME ABI routines for baremetal or platforms that have sys/auxv.h (#69423)
- [SVE ACLE] Allow default zero initialisation for svcount_t. (#69321)
- [AMDGPU] Fix image intrinsic optimizer on loads from different resources (#69355)
- Revert "Detect against invalid variant index for LibStdC++ std::variant data formatters (#69253)"
- [MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (#69332)
- Revert "[AMDGPU] Remove Code Object V3 (#67118)"
- Add missing test from #68661
- [Clang] Run update_cc_test_checks across SVE acle tests.
- [lld] Sort code section chunks by range types on Arm64EC targets. (#69099)
- [clang][Interp][NFC] Remove from(Boolean) overload
- [BOLT][test] Update checkvma-large-section.test (#69419)
- [ARM] Correct v2i1 concat extract types.
- Revert "[clang] [unittest] Add a test for Generic_GCC::GCCVersion::Parse (#69078)"
- [gn build] Port 1072b94ed8e5
- [Support] Add KnownBits::computeForSubBorrow (#67788)
- [builtins] Convert more int to fp functions to use common implementation (#67540)
- [mlir][nvvm] Use NVVMMemorySpace instead of hardcoded values (nfc)
- [clang] Bail out if the result of function template instantiation is not a function type. (#69459)
- [mlir][LLVM] Improve function debug info import (#69446)
- [AMDGPU] Save/Restore SCC bit across waterfall loop. (#68363)
- [mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Builder (#68969)
- [Clang] Fill in documentation gaps for some attributes (#68967)
- [Driver] Link Flang runtime on Solaris (#65644)
- Revert "Correctly compute conversion seq for args to fn with reversed param order (#68999)"
- [mlir][python] Expose `PyInsertionPoint`'s reference operation (#69082)
- [Clang][SVE2p1] Add svpsel builtins
- [lldb] Fix linking to libtinfo (#69458)
- [InstCombine] Add aligned_alloc with pointer icmp as only use.
- [flang][openacc] Fixed private/reduction for combined constructs. (#69417)
- [LLVM] Add new attribute `optdebug` to optimize for debugging (#66632)
- [mlir] Add ContractionOpInterface utility functions for vector matrix multiplication (#68945)
- [ELF] Merge copyLocalSymbols and demoteLocalSymbolsInDiscardedSections (#69425)
- [MLIR][Doc] Clarify the cf.asssert doc that this is a runtime assertion
- [llvm][CMake] Check dependency cxx source compiles (#68549)
- [SLP] Improve gather tree nodes matching when users are PHIs. (#69392)
- [DebugInfo] Separate error generation from reporting in DWARFHeaderUnit::extract (#68242)
- [CONCEPTS]Corrected comparison of constraints with out of line CTD (#69244)
- Add a FIXME comment; NFC
- [Presburger] Fraction: resolve ambiguous overload in some cases
- [AMDGPU] Add legality check when folding short 64-bit literals (#69391)
- [Kaleidoscope] Switch to the new PassManager. (#69032)
- [clang-tidy][DOC] Fix list.rst
- [mlir] ADTExtras: include mlir/Support/LLVM.h (#69479)
- [clang-tidy][DOC] Fix syntax in coroutine-hostile-raii.rst
- Add missing include breaking the modules build
- [Sema] Add check for bitfield assignments to integral types (#69049)
- [SLP][NFC]Use MutableArrayRef instead of SmallVectorImpl& in param, NFC.
- [Libomptarget] Make the references to 'malloc' and 'free' weak. (#69356)
- Attributes (#69358)
- [SystemZ] Support builtin_{frame,return}_address() with non-zero argument (#69405)
- [TableGen] SubtargetEmitter must use std::nullopt (#69475)
- [MLIR] reverse int8 type's printing logic (#69361)
- [clang-tidy][DOC] Fix syntax in coroutine-hostile-raii.rst
- [mlir][nvgpu] Add predicate argument to NVGPU Ops (#69322)
- [hwasan] Fix rare false negative (zero tag) in stack-uar.c (#69374)
- [CodeExtractor] Allow to use 0 addr space for aggregate arg (#66998)
- [HIP] Document func ptr and virtual func (#68126)
- [CFI/MergeFunctions] Modify MergeFunctions to propagate type information (#68628)
- [VectorCombine] Add tests for unspeculatable VP binops. NFC
- [libc++][NFC] Reformat new.cpp and stdlib_new_delete.cpp
- [AMDGPU] Add missing test checks. NFC. (#69484)
- [libc++][NFC] Refactor the core logic of operator new into helper functions (#69407)
- [ARM] Lower i1 concat via MVETRUNC
- [CodeGen] -fsanitize=alignment: add cl::opt sanitize-alignment-builtin to disable memcpy instrumentation (#69240)
- [compiler-rt] Fix a warning
- [mlir][sparse] implement non-permutation MapRef encoding (#69406)
- Initialize sigset in asan_interceptors (#69502)
- [ModuleInliner] Use SmallVector::pop_back_val (NFC)
- [SLP][NFC]Use MutableArrayRef instead of SmallVectorImpl&, rename function, NFC.
- [ModuleInliner] Update a comment (NFC)
- [mlir][sparse] Update examples in Ops.td (#69499)
- [AMDGPU] Make S_MOV_B64_IMM_PSEUDO foldable (#69483)
- [RISCV] Don't let performBUILD_VECTORCombine form a division or remainder with undef elements. (#69482)
- workflows/release-lit: Fix dev suffix removal (#69397)
- workflows/release-lit: Pass correct build directory to pypa/gh-action-pypi-publish (#69438)
- [clang-format] Fix a bug in annotating TrailingReturnArrow (#69249)
- [Kaleidoscope] Register new dependencies introduced by #69032. (#69510)
- [DWARFLinker] Only extract unit DIEs when cloning clang modules (#69495)
- [flang][openacc] Avoid privatizing symbols during semantics (#69506)
- [Tosa] Rename variables to coding style guideline (#69509)
- [WebAssembly] add: hidden option to disable slow wasm pass (#67715)
- [RISCV][GISel] Add ISel supports for SHXADD from Zba extension (#67863)
- [clang] Expand invalid PCM diagnostic (#69489)
- [ELF] Set large section flag for globals with an explicit section (#69396)
- [X86][RFC] Support AVX10 options (#67278)
- [llvm] Use StringRef::contains (NFC)
- [clang-format][NFC] Take a constant conjunct out of a loop condition
- [llvm] Use StringRef::contains (NFC)
- [Driver][DragonFly] Fixes for linker path and command-line option handling (#69095)
- [RISCV][GISel] Support passing arguments through the stack. (#69289)
- [mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (#69540)
- [InstCombine] Refactor matchFunnelShift to allow more pattern (NFC) (#68474)
- [ELF][test] --emit-relocs: test ALLOC sections discarded by --gc-sections and referenced by non-ALLOC
- [LoongArch] Implement COPY instruction between CFRs (#69300)
- [LoongArch] Improve codegen for atomic cmpxchg ops (#69339)
- [Support] Use StringRef::contains_insensitive (NFC)
- [mlir] Add debug messages for failures of isValidIntOrFloat
- [mlir] Use the process (host) triple in MLIRTargetLLVMTests (#69538)
- [libc] Add simple features.h with implementation macro (#69402)
- [Memory] Call __clear_cache in InvalidateInstructionCache on LoongArch (#67285)
- nfc, add test case for llvm-symbolizer on XCOFF
- [libc++] Move the check-generated-files job to Github Actions (#68920)
- [libc++] Fix inconsistency between is_lock_free and is_always_lock_free (#68109)
- [OpenCL][RISCV] Support SPIR_KERNEL calling convention (#69282)
- [libc++][docs] Update contributing docs to reflect the move to GitHub (#69386)
- [libc++] Add assertions for potential OOB reads in std::nth_element (#67023)
- [libc++] Improve the tests for std::basic_stringbuf's constructors and assignment operators
- [mlir] Only attempt to vectorize conv if conv.
- [SPIR-V] Emit proper pointer type for OpenCL kernel arguments (#67726)
- [RISCV] Fix assertion failure from performBUILD_VECTORCombine when the binop is a shift. (#69349)
- [SPIR-V] Remove calls to deprecated PointerType methods (1/2) (#68336)
- [RISCV] Replace PostRAScheduler with PostMachineScheduler (#68696)
- [RISCV] Remove FrameIndex case in lui+addi MacroFusion (#68701)
- [Github] Make PR formatting job only run with C/C++ changes (#69556)
- [Github] Add steps to build clang docs to CI (#69550)
- [clang][Interp] IntegralAP zero-initializers (#68081)
- Revert "[VPlan] Insert Trunc/Exts for reductions directly in VPlan."
- Revert "[Github] Make PR formatting job only run with C/C++ changes (#69556)"
- Reapply "[clang analysis][thread-safety] Handle return-by-reference..… (#68572)
- [IR] Don't mark experimental.guard as willreturn (#69433)
- [GVN] Fix use-after-free in load PRE with select available value (#69314)
- [X86] Support -march=pantherlake,clearwaterforest (#69277)
- [clangd] Disable crashy unchecked-optional-access tidy check (#69427)
- [Tablegen] Add keyword `dump`. (#68793)
- [compiler-rt] Fix a warning
- [ReleaseNotes][TableGen] Add `dump` and `!repr`. (#68893)
- [TableGen] Update editor modes for new keywords and bang operators. (#68897)
- [RISCV] Combine (and (select cond, x, -1), c) to (select cond, x, (and x, c)) with Zicond. (#69563)
- [mlir][nvvm] Introduce `nvvm.stmatrix` Op (#69467)
- [FunctionAttrs] Add additional tests for writeonly (NFC)
- [Clang] Add __builtin_vectorelements to get number of elements in vector (#69010)
- [mlir][transform] Support symlinks in module loading. Reorganize tests. (#69329)
- [clang] Provide an SSE4.2 implementation of identifier token lexer (#68962)
- Revert "[mlir][transform] Support symlinks in module loading. Reorganize tests. (#69329)"
- [Clang][SVE2.1] Add pfalse builtin
- Reapply "[mlir][transform] Support symlinks in module loading. Reorganize tests. (#69329)"
- [libc] Fix accidental LIBC_NAMESPACE_syscall definition (#69548)
- Reapply "[dataflow] use true/false literals in formulas, rather than variables"
- [Flang][OpenMP][Sema] Add directive rewrite pass to support atomic_default_mem_order REQUIRES clause
- [clang][Interp][NFC] Add more tests for bitfield initializers
- [AMDGPU] PeepholeSDWA: Don't assume inst srcs are registers (#69576)
- Fix __builtin_vectorelements tests with REQUIRES (#69582)
- [ARM] fix "+fp.dp" in multilib selection (#67412)
- [clang][Interp][NFC] Use an APInt instead of APSint
- [Clang][SVE2.1] Add svwhile (predicate-as-counter) builtins
- [DAG] Expand vXi1 add/sub overflow operations as xor/and (#69191)
- [AMDGPU] Remove legality checks from imm folding in shrink. NFCI. (#69539)
- [NVPTX] Preserve v16i8 vector loads when legalizing
- [Clang] Actually fix tests for __builtin_vectorelements (#69589)
- ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering (#66924)
- Re-apply '[AArch64] Enable "sink-and-fold" in MachineSink by default (#67432)'
- Rename test to avoid overlapping with debug output
- [clang][Interp][NFC] Use a const reference in IncDecHelper
- [DAG] Add test coverage for Issue #66603
- [DAG] canCreateUndefOrPoison - remove AssertSext/AssertZext assumption that they never create undef/poison
- [clang][Interp] Create only globals when initializing a global variable
- [Libomptarget] Add a test for the `libc` implementation of assert (#69518)
- [libomptarget][OpenMP] Initial implementation of omp_target_memset() and omp_target_memset_async() (#68706)
- [InstCombine] Add additional aligned allocation tests for #69474.
- [AMDGPU] Constant fold FMAD_FTZ (#69443)
- [DIAG][msan] fix libc check string for dladdr1 call (#69359)
- [libc][math][NFC] Remove global scope constants declaration in math tests (#69558)
- [MemoryBuiltins] Simplify getAllocFnKind() implementation (NFC)
- Fixed some wmma store builtins that had non-const src param
- [flang] Put ISO_Fortran_binding.h where it can be easily used (#69121)
- [Clang][SVE2.1] Add builtins for Multi-vector load and store
- [TwoAddressInstruction] Handle physical registers with LiveIntervals (#66784)
- [flang][openacc] Warn about misplaced end loop directive and ignore it (#69512)
- InlineSpiller: Delete assert that implicit_def has no implicit operands (#69087)
- Let clang-cl support CUDA/HIP (#68921)
- [mlir][ODS] Add `OptionalTypesMatchWith` and remove a custom assemblyFormat (#68876)
- [Clang][SVE2.1] Add builtins for 2-way svdot (vectors, indexed)
- [SLP][NFC]Add avx2 test run, NFC.
- [hwasan] Fix rare false negative (zero tag) in two more test cases (#69491)
- [mlir][sparse] Update verifier for block sparsity and singleton (#69389)
- AMDGPU: Minor updates to program resource registers (#69525)
- [Clang][SVE2.1] Add builtins and intrinsics for SVBFMLSLB/T
- [VectorCombine] Use isSafeToSpeculativelyExecute to guard VP scalarization (#69494)
- [DebugInfo] Correctly report header parsing errors from DWARFContext::fixupIndex (#69505)
- [lldb] Rename lldb-vscode to lldb-dap (#69264)
- Fix test clang/test/Driver/cl-offload.cu
- Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)
- [gn build] Port
- [gn build] Port 01263c6c6fb4
- [docs][NewPM] Add comment about declaring analysis managers in the correct order
- [lldb] Fix ASCII art in CommandObjectSource.h (NFC)
- [unittest] Refactoring the gtest sharding option. (#69537)
- [InstCombine] Don't consider aligned_alloc removable if icmp uses result (#69474)
- [libc] Fix accidental LIBC_NAMESPACE_clock_freq (#69620)
- [AMDGPU] Allow lit() on operands which do not accept modifiers (#69527)
- [AArch64][GlobalISel] Fix miscompile on carry-in selection (#68840)
- [RISCV] Add getSameRatioLMUL (#69570)
- [ELF][test] Demonstrate --no-allow-shlib-undefined behavior with a hidden relocatable object file definition
- [libc][NFC] Fix features.h.def file header
- Disallow _BitInt as an underlying type for an enumeration
- [lldb] Remove CompileUnit::SetSupportFiles overload (NFC)
- [gn build] Port 460e84398a19
- [gn] port 01263c6c6fb495 (lldb-vscode -> lldb-dap)
- [libc][libm][GPU] Add missing vendor entrypoints to the GPU version of `libm` (#66034)
- [libcxx][test] Fix empty.gen selftest on windows (#69403)
- [LV] Add interleave only test case with reduction requiring casts.
- [mlir][drr] Set operand segment in rewrite
- [clang][index] Fix processing of CompoundAssignOperator at setting up reference roles (#69370)
- [lldb] Remove FileSpecList::GetFileSpecPointerAtIndex (NFC)
- [clang-format][NFC] Use UnwrappedLineParser::eof() for consistency
- [AMDGPU] Add doc updates for kernarg preloading (#67516)
- [mlir][sparse] use uint64_t type for dim/rank consistently (#69626)
- [NFC] Format some code in GlobalVariable.h
- [libc++][Android] Support libc++ testing on Android (#69274)
- [FunctionComparator] Differentiate instructions passing different MDStrings (#69543)
- [libc] Rework the 'fgets' implementation on the GPU (#69635)
- [libc] Partially implement 'rand' for the GPU (#66167)
- workflows/release-tasks: Fix release note artifact upload (#69522)
- [flang][openacc] Warn for num_gangs, num_workers and vector_length on acc serial (#69622)
- [lldb][NFCI] Remove duplicated code in DWARFParser (#69531)
- [libc++][Android] Add libcxx-builder-android Docker image (#69273)
- [mlir][python] remove mixins (#68853)
- [mlir][sparse] Remove old syntax (#69624)
- [Fuchsia] Add lldb-dap to LLDB distribution
- [RISCV] Apply `IsSignExtendingOpW = 1` on `fcvtmod.w.d` (#69633)
- [lldb] Remove FileSpecList::GetFilesMatchingPartialPath (NFC)
- [bazel][mlir] fixes for a2288a89
- [mlir][sparse] introduce sparse_tensor.crd_translate operation (#69630)
- [scudo] Add ConditionVariable in SizeClassAllocator64 (#69031)
- [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (#68495)
- [gn] port ab17ecd10767
- [mlir][python] simplify extensions (#69642)
- [analyzer] WebKit checkers: recognize dynamicDowncast as a safe function.
- [RISCV] Fix some GlobalISel tests using -march instead of -mtriple.
- [libc][NFC] Forcing data type in gettimeofday_test when comparing the diff. (#69652)
- [mlir][sparse] support BSR for cuSPARSE (libgen path only) (#69646)
- [flang][openacc] Do not error when bind symbol is defined later or external (#69657)
- [libc++][Android] Mark tests XFAIL/UNSUPPORTED (#69271)
- [libc++][Android] Don't list Android as supported yet (#69660)
- [Scalar] Use LLVMContext::MD_mem_parallel_loop_access directly (NFC) (#69549)
- Fix test clang/test/Driver/cl-offload.cu
- [RISCV] Support Xsfvqmaccdod and Xsfvqmaccqoq extensions (#68295)
- [LoongArch] Fix td pattern for CACOP LDPTE and LDDIR
- [ValueTracking] Implement sdiv/udiv support for isKnownNonNullFromDominatingCondition (#67282)
- [RISCV][NFC] Use !range bang operator (#66494)
- [mlir][scf] Implement getSingle... of LoopLikeOpInterface for scf::ForallOp (#67883)
- [Tablegen] Bugfix and refactor VarLenCodeEmitter HwModes. (#68795)
- [X86][AMX] remove related code of X86PreAMXConfigPass (#69569)
- [DWARF] Remove unused declaration verifyIndexes
- [MC][NFC] Allow MCInstrAnalysis to store state (#65479)
- [RISCV] Add more prefetch tests (#67644)
- [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (#67083)
- [BOLT] Use llvm::is_contained (NFC)
- [libc++] Fix uninitialized algorithms when using unconstrained comparison operators (#69373)
- [RISCV] Match prefetch address with offset (#66072)
- [clangd] Use llvm::erase_value (NFC)
- [mlir] Use llvm::erase_value (NFC)
- [RISCV][MC] Implement evaluateBranch for auipc+jalr pairs (#65480)
- [mlir][Bazel] Add missing dependency after d871daea8159c4b39b17b3ab8f3dd3adb1b51de3
- [CMake] Avoid build spam by switching to Debug message (#69497)
- [Driver] Use llvm::any_of (NFC)
- [FunctionAttrs] Only check ArgMem effects when inferring argument attrs (#69571)
- [llvm] Use llvm::find_if (NFC)
- [mlir][tosa] Update pass pipeline for TosaToLinalg (#69679)
- [mlir][scf] Implement getSingle... of LoopLikeOpinterface for scf::ParallelOp (#68511)
- [Transforms] Use llvm::erase_if (NFC)
- [DebugInfo] Use llvm::erase_if (NFC)
- [lldb] Use llvm::erase_if (NFC)
- [run-clang-tidy,clang-tidy-diff] Accept directory as value for -export-fixes (#69453)
- [llvm][llvm-readobj] Add AArch64 Tagged Address note type (#68568)
- [ExecutionEngine] Use llvm::is_contained (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in mlir-cat.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in StandaloneOps.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in StandalonePasses.cpp (NFC)
- Apply clang-tidy fixes for misc-unused-alias-decls in StandaloneExtension.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in standalone-opt.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in standalone-plugin.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in standalone-translate.cpp (NFC)
- [FunctionAttrs] Regenerate test checks (NFC)
- [BOLT] Filter itrace from perf script mmap & task events (#69585)
- [Driver][DragonFly][NFC] Some cleaning up
- [flang] Remove test from #69121 to fix gcc build with gcc < 10.0
- [Clang][SVE2.1] Add builtins for svrevd
- [mlir][ArmSME] Name arguments of SME intrinsics (NFC) (#69608)
- [mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (#69192)
- [mlir][Bazel] Add missing dependencies after aa0208d1bc52e45dc0032f41e58b50d3134d1089
- [flang][openmp] Update copyHostAssociateVar to use hlfir.assign for HLFIR (#69441)
- [PowerPC] Remove HTM instruction from P10 SchedModel (#69579)
- [NFC][LV] Add test for vectorizing fmuladd with another call (#68601)
- [flang][hlfir] Make the parent type the first component (#69348)
- [mlir][transform] Support for multiple top-level transform ops (#69615)
- [clangd] Don't run slow clang-tidy checks by default
- [clang-format] Add space in placement new expression
- [IR] Fix nested constant to instruction conversion (#69682)
- [mlir][SCF] Pass result of getAsOpFoldResult to getBoundedTileSize.
- [lld][NFC] Remove unnecessary else statements. (#69451)
- [MemCpyOpt] Remove unnecessary typed pointer handling (NFC)
- [mlir][Tosa] Fix Clamp verifier to handle quantized types.
- [mlir][tosa] Check for 0-ranked-tensors during fold (#68512)
- [mlir][SCF] Fix memory leak in LoopLikeSCFOpsTest.cpp
- [run-clang-tidy] Accept export directory if PyYAML is not installed (#69700)
- [LIT] Print discovered tests and percentages (#66057)
- [clang] Handle templated operators with reversed arguments (#69595)
- Revert "[LIT] Print discovered tests and percentages (#66057)" (#69715)
- [lldb][AArch64] Add release notes and documentation for SME (#66767)
- [AMDGPU] Segregate 16-bit fix-sgpr-copies tests. (#69353)
- [AA] Make LI and EphValues option in EarliestEscapeInfo (NFC)
- [GVN] Add tests for captured-before analysis (NFC)
- [Interpreter] Add initialization of array members (#66172)
- Recommit "[VPlan] Insert Trunc/Exts for reductions directly in VPlan."
- Diagnose use of VLAs in C++ by default
- Revert "Diagnose use of VLAs in C++ by default"
- [lldb] Remove more references to lldb-vscode (#69696)
- [libc++] mdspan - implement layout_stride (#69650)
- [gn build] Port 639a0986f3a3
- [clang-tidy]Add new check bugprone-casting-through-void (#69465)
- [gn build] Port 9a5c6f1760a3
- [libc][NFC] Attempt to deflake gettimeofday_test. (#69719)
- [llvm][AArch64][Assembly] Implement support to read/write FPMR (#69618)
- [Peephole] Check instructions from CopyMIs are still COPY (#69511)
- Change owner of Hexagon backend
- [mlir][index][spirv] Add conversion for index to spirv (#68085)
- [BOLT][RISCV] Handle CIE's produced by GNU as (#69578)
- [lit] Clean up internal shell parse errors with ScriptFatal (#68496)
- [MLIR][Presburger] Implement matrix inverse (#67382)
- Revert "[flang] Put ISO_Fortran_binding.h where it can be easily used (#69121)"
- Revert "[Intrinsics][ObjC] Mark objc_retain and friends as thisreturn."
- [clang][modules] Use file name as requested (#68957)
- [mlir][sparse] update COO buffer reader doc (#69664)
- [llvm] Use XMACROS for MachO platforms. (#69262)
- [Clang][OpenMP] Check if value is contained in array, not if it's contained in the first element (#69462)
- [RISCV] Use range-based for loops in RISCVOptWInstrs. NFC (#69647)
- clang-linker-wrapper/LinkerWrapperOpts.td: "--sysroot" => "--sysroot=" (#65313)
- [cmake] Option to create Ninja job pools depending on available resources (#65274)
- Diagnose use of VLAs in C++ by default
- [RISCV][CostModel] Recommit VPIntrinsics have same cost as their non-vp counterparts (#68752)
- Workaround for MSVC ARM64 build performance regression (#65215)
- Revert "[mlir][index][spirv] Add conversion for index to spirv (#68085)"
- [tsan][go]: add atomic or/and functions (#65695)
- [libc++][Android] Disable Android ABI list checking (#69666)
- Remove accidental merge conflict marker; NFC
- clarify tensor.pad docs for low/high config
- mlir/lib/Dialect/GPU/Transforms: improve context management in SerializeToCubin (#65779)
- Revert "[SLP] Improve gather tree nodes matching when users are PHIs. (#69392)"
- [RISCV][GISel] Disable call lowering for integers larger than 2*XLen. (#69144)
- [LVI] Handle icmp of ashr. (#68010)
- Revert "[RISCV][GISel] Disable call lowering for integers larger than 2*XLen. (#69144)"
- [RISCV] Use LMUL=1 for vmv_s_x_vl with non-undef passthru (#66659)
- Recommit "[RISCV][GISel] Disable call lowering for integers larger than 2*XLen. (#69144)"
- [RISCV][GISel] Support G_PTRTOINT and G_INTTOPTR (#69542)
- [RISCV][InsertVSETVLI] Make VL preserving vsetvli emission more explicit [nfc]
- [WebAssembly] Add exp10 libcall signatures (#69661)
- Fix MLIR gcc7 build: ambiguous overload from user conversion
- Fixed typo in GPU libm device library warning (#69752)
- [clang-tidy] modernize-avoid-bind only return for non-void function (#69207)
- [mlir][sparse] tiny cleanup making local 'using' explicit (#69740)
- [clang-format] Annotate do while while
- [Driver] Corrections for linker flags passed with relocatable linking on OpenBSD (#67254)
- [Libomptarget][NFC] Remove use of VLA in the AMDGPU plugin (#69761)
- [Modules] textual headers in submodules never resolve their `use`s (#69651)
- Reland [clang] [unittest] Add a test for Generic_GCC::GCCVersion::Parse (#69078)
- [gn build] Port 538b7ba2aacd
- [LLD] [COFF] Add a separate option for allowing duplicate weak symbols (#68077)
- [InstCombine][NFC] Precommit tests for https://reviews.llvm.org/D149918
- [mlir][linalg] regionBuilder for transpose, broadcast (#69742)
- [-Wunsafe-buffer-usage] Add AST info to the unclaimed DRE debug notes for analysis
- [MSVC] fix the build (#69634)
- [RISCV][llvm-mca] Vector Unit Stride Loads and stores use EEW and EMU… (#69409)
- [GISel] Add LookThroughInstrs for getIConstantVRegVal and getIConstan… (#68327)
- [OpenMP][mlir] Add translation for `if` in `omp.teams` (#69404)
- [Workflow] make code-format-helper.py mypy-safe (NFC) (#69691)
- [VPlan] Support scalable vectors in outer-loop vectorization
- Update SimplifyIndVar.cpp (#69760)
- [mlir][sparse] fix stack overflow due to memref.alloca in loops (#69786)
- [mlir][sparse] implement sparse_tensor.crd_translate operation (#69653)
- [RISCV][GISel] Minor refactoring of RISCVCallReturnHandler and RISCVIncomingValueHandler to match other targets (#69757)
- Fix build warning caused by mixed signed/unsigned compare (#69797)
- [mlir][sparse] support CSR/BSR conversion (#69800)
- [CI] Set minimal permission on libcxx-check-generated-file workflow (#69737)
- [InstCombine] Precommit tests for PR67216
- [InstCombine] optimize powi(X,Y)/X with Ofast (#67236)
- [MLIR][python bindings] invalidate ops after PassManager run (#69746)
- [clang][dataflow]Use cast_or_null instead of cast to prevent crash (#68510)
- [Github] Fetch number of commits in PR for docs action (#69763)
- [lldb][test] Turn ObjC string literals to C-style literals (NFC) (#69793)
- [OpenMPOpt][FIX] Properly track changes to NestedParallelism
- [Attributor][FIX] Interposable constants cannot be propagated
- [OpenMP][NFC] Move DebugKind to make it reusable from the host
- Revert "[mlir] Silence a few -Wunused-but-set-parameter warnings" (#68667)
- [MLIR][python bindings][fix] invalidate ops after PassManager run
- Add IR name to -print-pass-numbers output
- [AST] Use explicit type erasure in TypeSourceInfo constructor (#68435)
- Fix typos and formatting in GettingStarted.md (#68537)
- [mlir] Avoid including <alloca.h> on DragonFly
- Fix typos in Debug.h (#68761)
- [compiler-rt] Fix a warning
- [llvm][CMake] Respect LIBCXX_HARDENING_MODE on command-line (#68541)
- [compiler-rt] Switch LLD specific tests to a more precise option (#69781)
- [DebugInfo] Use llvm::erase_value (NFC)
- [flang] Use llvm::any_of (NFC)
- [polly] Use llvm::erase_value (NFC)
- [Serialization] Use llvm::is_contained (NFC)
- [MachineBasicBlock] Fix SlotIndexUpdater for insertion order (#69424)
- [Driver][NetBSD][NFC] Some cleaning up
- [Github] Remove CMake options from docs CI resetting defaults
- [clang-tidy][DOC] Fix 'table cell spanning'
- Clang: Define macro _MIPS_SPFPSET
- [Windows] Add git-clang-format wrapper bat file (#69228)
- [analyzer][NFC] Substitute operator() with lambda in StreamChecker
- [RISCV] Replace RISCV -> RISC-V in comments. NFC
- [lldb] Update qRegisterInfo docs to recommend target.xml (#69853)
- [Attributor][NFC] Precommit test
- [Attributor] Ignore different kernels for kernel lifetime objects
- Reland [LLD] [COFF] Don't try to detect MSVC installations in mingw mode
- [LV] Enforce order of reductions with intermediate stores in VPlan (NFC)
- [OpenMP][FIX] Ensure thread states do not crash on the GPU
- [OpenMP] Basic BumpAllocator for (AMD)GPUs (#69806)
- [OpenMP] Rewrite test to check the correct (CPU) result
- [Github] Add clang-tools-extra docs to CI (#69827)
- [Github] Add lldb docs step to Github docs action (#69832)
- [mlir][doc] Include ml_program passes in passes doc
- [mlir][doc] Add basic doc for extraTraitClassDeclaration.
- [Github] Fetch an additional commit for docs CI on PRs
- [lldb] improve dwo path in missing dwo error when relative (#69783)
- Apply clang-tidy fixes for llvm-qualified-auto in LowerToLLVM.cpp (NFC)
- Apply clang-tidy fixes for llvm-qualified-auto in LowerToLLVM.cpp (NFC)
- Apply clang-tidy fixes for llvm-qualified-auto in CallGraph.cpp (NFC)
- Apply clang-tidy fixes for llvm-qualified-auto in IRNumbering.cpp (NFC)
- Reland: "[mlir][index][spirv] Add conversion for index to spirv" (#69790)
- [docs] Fix suggested darker command in coding standards (#69860)
- [LIT] Print discovered tests and percentages (#66057) (#69831)
- [llvm] Stop including Endian.h (NFC)
- [Clang][OHOS] Keep ARM ABI selection logic in sync between Clang and LLVM (#68656)
- [AMDGPU] Set size to all SOP pseudos (#69756)
- [libc++][PSTL] Implement std::move
- [gn build] Port d2a46e6480f3
- [VPlan] Simplify redundant trunc (zext A) pairs to A.
- [VPlan] Make ExpandedSCEVs argument const (NFC).
- [clangd] Show alignment for records and fields decls (#67213)
- [mlir][minimal-opt] Fix typo
- [llvm] Stop including llvm/ADT/SmallString.h (NFC)
- [mlir] Remove an extraneous typename (NFC)
- [llvm] Use llvm::any_of (NFC)
- [Utils] Use std::remove_pointer_t (NFC)
- [CodeGen][Remarks] Add the function name to the stack size remark (#69346)
- [llvm] Stop including llvm/ADT/StringMap.h (NFC)
- [lldb] Remove an unused using decl (NFC)
- [llvm] Stop including llvm/ADT/DepthFirstIterator.h (NFC)
- [compiler-rt] Use std::clamp (NFC)
- [OpenMP][FIX] Ensure test runs correct with (at least) 2 threads
- [clang-format][NFC] Simplify the logic in a return statement
- Apply clang-tidy fixes for misc-include-cleaner in toyc.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in Dialect.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in MLIRGen.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in AST.cpp (NFC)
- Apply clang-tidy fixes for misc-include-cleaner in AST.cpp (NFC)
- [OpenMP][FIX] Fix memset oversight to partially unblock test
- [JITLink] Allow multiple relocations at same offset in EHFrameEdgeFixer (#68252)
- [mlir][DeadCodeAnalysis] Don't Require `RegionBranchTerminatorOpInterface` in `visitRegionTerminator()` (#69043)
- [llvm] Stop including llvm/ADT/iterator_range.h (NFC)
- [LLDB] Update breakpoint-command.test to use string instead of number. (#69796)
- [LegacyPM] Remove LowerExpectIntrinsicPass
- [CodeLayout] cache-directed sort: limit max chain size (#69039)
- [llvm-profgen] More tweaks to warnings (#68608)
- [mlir][bufferization] Ownership-based deallocation: Allow manual (de)allocs (#68648)
- [OpenMP][FIX] Enlarge thread state array, improve test and add second
- [C++20] [Modules] [Driver] Don't enable -fdelayed-template-parsing by default on windows with C++20 (#69431)
- [Clang][RISCV] Support CSRs in clobbered registers of inline assembly (#67646)
- [TableGen][NFC] Remove MultiClass argument and Scoper in QualifyName (#69297)
- Use llvm::count (NFC)
- [lldb] Use llvm::is_contained (NFC)
- [TextAPI] Use std::remove_reference_t (NFC)
- [RISCV] Disable hasAllNBitUsers for vector types.
- [mlir][SCF] Use getResult() instead of static_cast<Value>().
- [mlir][Bazel] Add missing dependency.
- [clang-format] Add a new style for the clang-format source code (#69814)
- [clang][dataflow] Remove `DataflowAnalysisContext::flowConditionIsTautology()`. (#69601)
- [BOLT][RISCV] Use target features from object file (#69836)
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (#68526)
- [flang] Do not stop on mismatched DATA substring length (#69336)
- [MemDep] Use EarliestEscapeInfo (#69727)
- [BOLT][RISCV] Set minimum function alignment to 2 for RVC (#69837)
- [flang] Place MIN/MAX A1/A2 first in semantic analysis (#69722)
- Typos: 'maxium', 'minium'
- [clang] Fix designated initializers inside templates (#69712)
- [mlir][tosa] Check for unranked tensors during validation (#68509)
- [mlir][SCF] Minor fixes in documentation examples (#69802)
- [libc++] <experimental/simd> Add operator value_type() of simd reference (#68960)
- [InstCombine] Remove scalable vector extracts to and from the same type (#69702)
- [STLExtras] Undo C++20 hack
- [Sema] Change order of displayed overloads in diagnostics
- [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (#69456)
- [AArch64] Allow SVE code generation for fixed-width vectors (#67122)
- [mlir][VectorOps] Support string literals in `vector.print`
- [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage` pass
- Always align SVE vectors to 16 bytes and predicates to 2 bytes
- [mlir][SVE] Add an e2e test for vector.contract
- [mlir][vector] Add scalable vectors to tests for vector.contract
From d56e1cb81e9e7dd9d1294cf1d12397df4bed72d2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:38:12 +0000
Subject: [PATCH 1/5] [mlir][VectorOps] Support string literals in
`vector.print`
Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:
```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")
func.func @entry() {
%0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}
```
So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
// Print a vector of characters ;)
vector.print str "Hello, World!"
return
}
```
Most of the logic for this is now shared with `cf.assert`, which already
lowers its failure message to a string global plus a call to the runtime
print function.
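The length dependence comes from the NUL terminator: an N-character string is backed by a global of type `!llvm.array<N+1 x i8>`, so "Hello, World!" needs `array<14 x i8>`. As a standalone sketch (plain C++, not MLIR; `toNulTerminatedBytes` is a hypothetical name), this mirrors how the shared lowering builds the byte representation before creating the global:

```cpp
#include <cassert>
#include <cstdint>
#include <string_view>
#include <vector>

// Hypothetical helper: build the zero-terminated byte buffer that backs a
// string global. The element count, and therefore the LLVM array type,
// depends on the string length, which is why hand-written IR for printing
// strings cannot be tucked into one shared helper.
std::vector<uint8_t> toNulTerminatedBytes(std::string_view str) {
  std::vector<uint8_t> bytes(str.begin(), str.end());
  bytes.push_back(0); // C-style terminator expected by the runtime printer.
  return bytes;
}
```

For "Hello, World!" this yields the 14 values checked later in `vector-to-llvm.mlir`: the ASCII codes 72, 101, ... followed by a trailing 0.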
---
.../Conversion/LLVMCommon/PrintCallHelper.h | 36 ++++++++++
.../mlir/Dialect/Vector/IR/VectorOps.td | 37 +++++++++--
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 49 +-------------
mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 +
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +-
.../VectorToLLVM/vector-to-llvm.mlir | 14 ++++
.../Dialect/Vector/CPU/test-hello-world.mlir | 10 +++
8 files changed, 168 insertions(+), 51 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+ StringRef symbolName, StringRef string,
+ const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 917b27a40f26f13..0da4ca617a94c3a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -2477,12 +2478,18 @@ def Vector_TransposeOp :
}
def Vector_PrintOp :
- Vector_Op<"print", []>,
+ Vector_Op<"print", [
+ PredOpTrait<
+ "`source` or `punctuation` are not set when printing strings",
+ CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+ >,
+ ]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
- "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+ "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+ OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
@@ -2521,6 +2528,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```
+
+ Additionally, to aid with debugging and testing, `vector.print` can also
+ print constant strings:
+
+ ```mlir
+ vector.print str "Hello, World!"
+ ```
}];
let extraClassDeclaration = [{
Type getPrintType() {
@@ -2529,11 +2543,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
- build($_builder, $_state, {}, punctuation);
+ build($_builder, $_state, {}, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source), [{
+ build($_builder, $_state, source, PrintPunctuation::NewLine);
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+ build($_builder, $_state, source, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::llvm::StringRef":$string), [{
+ build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];
- let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+ let assemblyFormat = [{
+ ($source^ `:` type($source))?
+ oilist(
+ `str` $stringLiteral
+ | `punctuation` $punctuation)
+ attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
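The `PredOpTrait` added to `Vector_PrintOp` encodes a simple rule: a string literal may only appear without a `source` operand and with the default newline punctuation. As a hedged illustration, here is the same predicate written as standalone C++ over a hypothetical plain struct (`PrintOpModel` and `isValid` are not the generated MLIR accessors, and the enum is a simplified stand-in):

```cpp
#include <cassert>
#include <optional>
#include <string>

// Simplified stand-in for vector::PrintPunctuation.
enum class PrintPunctuation { NoPunctuation, NewLine, Comma, Open, Close };

// Hypothetical plain-data model of the operands/attributes of vector.print.
struct PrintOpModel {
  bool hasSource = false;
  PrintPunctuation punctuation = PrintPunctuation::NewLine; // default value
  std::optional<std::string> stringLiteral;
};

// Mirrors the CPred: either no string literal is present, or there is no
// source operand and the punctuation is left at its default (NewLine).
bool isValid(const PrintOpModel &op) {
  return !op.stringLiteral ||
         (!op.hasSource && op.punctuation == PrintPunctuation::NewLine);
}
```

Printing a value with custom punctuation stays valid; only combinations involving `str` are restricted.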
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
- std::string prefix = "assert_msg_";
- int counter = 0;
- while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
- ++counter;
- return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef msg,
- const LLVMTypeConverter &typeConverter) {
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(moduleOp.getBody());
- MLIRContext *ctx = builder.getContext();
-
- // Create a zero-terminated byte representation and allocate global symbol.
- SmallVector<uint8_t> elementVals;
- elementVals.append(msg.begin(), msg.end());
- elementVals.push_back(0);
- auto dataAttrType = RankedTensorType::get(
- {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
- auto dataAttr =
- DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
- auto arrayTy =
- LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
- dataAttr);
-
- // Emit call to `printStr` in runtime library.
- builder.restoreInsertionPoint(ip);
- auto msgAddr = builder.create<LLVM::AddressOfOp>(
- loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
- SmallVector<LLVM::GEPArg> indices(1, 0);
- Value gep = builder.create<LLVM::GEPOp>(
- loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
- indices);
- Operation *printer = LLVM::lookupOrCreatePrintStrFn(
- moduleOp, typeConverter.useOpaquePointers());
- builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
- gep);
-}
-
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+ LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+ *getTypeConverter());
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
+ PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+ StringRef symbolName) {
+ int counter = 0;
+ std::string uniqueName = std::string(symbolName);
+ while (moduleOp.lookupSymbol(uniqueName)) {
+ uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+ }
+ return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+ ModuleOp moduleOp, StringRef symbolName,
+ StringRef string,
+ const LLVMTypeConverter &typeConverter) {
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(moduleOp.getBody());
+ MLIRContext *ctx = builder.getContext();
+
+ // Create a zero-terminated byte representation and allocate global symbol.
+ SmallVector<uint8_t> elementVals;
+ elementVals.append(string.begin(), string.end());
+ elementVals.push_back(0);
+ auto dataAttrType = RankedTensorType::get(
+ {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+ auto dataAttr =
+ DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+ auto arrayTy =
+ LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+ // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip);
+ auto msgAddr = builder.create<LLVM::AddressOfOp>(
+ loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+ SmallVector<LLVM::GEPArg> indices(1, 0);
+ Value gep = builder.create<LLVM::GEPOp>(
+ loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+ indices);
+ Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+ moduleOp, typeConverter.useOpaquePointers());
+ builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+ gep);
+}
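`ensureSymbolNameIsUnique` probes the requested name and then suffixed variants until the module has no symbol with that name. A minimal standalone sketch of the same probing, assuming a `std::set` in place of `ModuleOp::lookupSymbol` (`SymbolTable` and `makeUniqueSymbolName` are hypothetical names). The sketch keeps the counter local to each call, so the chosen name depends only on which symbols already exist:

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for the module's symbol table.
using SymbolTable = std::set<std::string>;

// Try `base`, then `base_0`, `base_1`, ... until a free name is found.
std::string makeUniqueSymbolName(const SymbolTable &symbols,
                                 const std::string &base) {
  std::string candidate = base;
  int counter = 0;
  while (symbols.count(candidate))
    candidate = base + "_" + std::to_string(counter++);
  return candidate;
}
```

Because the loop re-checks every candidate against the table, uniqueness holds regardless of where the counter starts; keeping it local additionally makes the output deterministic.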
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
auto punct = printOp.getPunctuation();
- if (punct != PrintPunctuation::NoPunctuation) {
+ if (auto stringLiteral = printOp.getStringLiteral()) {
+ LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ *stringLiteral, *getTypeConverter());
+ } else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
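In the `VectorPrintOpConversion` hunk, a string literal takes precedence over punctuation. The standalone C++ model below captures that branch order; it is not the actual pattern, and `trailingPrintCalls` and its call labels are hypothetical. Since the verifier only allows `str` together with the default newline punctuation, skipping the punctuation call presumably relies on the string printer ending the line itself: the `vector-to-llvm.mlir` test expects a call to `@puts`, and C's `puts` appends a newline.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Simplified stand-in for vector::PrintPunctuation.
enum class PrintPunctuation { NoPunctuation, NewLine, Comma, Open, Close };

// Hypothetical model of the tail of the vector.print lowering: returns labels
// for the runtime calls emitted after any source value has been printed.
std::vector<std::string>
trailingPrintCalls(const std::optional<std::string> &stringLiteral,
                   PrintPunctuation punct) {
  std::vector<std::string> calls;
  if (stringLiteral) {
    // The string branch wins; no separate punctuation call is emitted.
    calls.push_back("str:" + *stringLiteral);
  } else if (punct != PrintPunctuation::NoPunctuation) {
    switch (punct) {
    case PrintPunctuation::NewLine: calls.push_back("newline"); break;
    case PrintPunctuation::Comma: calls.push_back("comma"); break;
    case PrintPunctuation::Open: calls.push_back("open"); break;
    case PrintPunctuation::Close: calls.push_back("close"); break;
    case PrintPunctuation::NoPunctuation: break; // excluded by the if above
    }
  }
  return calls;
}
```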
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
// -----
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+ vector.print str "Hello, World!"
+ return
+}
+
+// -----
+
func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
return %0 : vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
new file mode 100644
index 000000000000000..c4076e65151ac72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ // CHECK: Hello, World!
+ vector.print str "Hello, World!"
+ return
+}
From 05122b64a403c714fc2cd6291a08e813adb0c195 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 18:53:05 +0000
Subject: [PATCH 2/5] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage`
pass
This patch adds a pass that ensures that loads, stores, and allocations
of SVE vector types will be legal in the LLVM backend. It does this at
the memref level, so this pass must be applied before lowering all the
way to LLVM.
This pass currently fixes two issues.
It is only legal to load/store predicate types equal to (or greater
than) a full predicate register, which in MLIR is `vector<[16]xi1>`.
Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted
to/from a full predicate type (referred to as a `svbool`) before and
after storing and loading respectively. This pass does this by widening
allocations and inserting conversion intrinsics.
For example:
```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```
Becomes:
```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```
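The type mapping behind this widening can be summarized in a few lines. In this standalone C++ sketch (hypothetical helper names), a scalable predicate `vector<[N]xi1>` with N in {1, 2, 4, 8} is stored as the full `vector<[16]xi1>`. Types that are already a whole number of predicate registers are assumed to be left alone, and anything else is reported as unhandled:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Minimum lane count of a full SVE predicate register: vector<[16]xi1>.
constexpr unsigned kSvboolMinLanes = 16;

// Hypothetical helper: lane count to use for the storage of vector<[N]xi1>.
// Returns std::nullopt for shapes this sketch does not handle.
std::optional<unsigned> predicateStorageLanes(unsigned minLanes) {
  switch (minLanes) {
  case 1:
  case 2:
  case 4:
  case 8:
    // Smaller than a full register: widen the allocation and insert
    // convert_to_svbool / convert_from_svbool around stores and loads.
    return kSvboolMinLanes;
  default:
    // Assumption: whole multiples of a full register are already legal.
    if (minLanes != 0 && minLanes % kSvboolMinLanes == 0)
      return minLanes;
    return std::nullopt;
  }
}

// Render the MLIR spelling of a scalable predicate type.
std::string predicateType(unsigned minLanes) {
  return "vector<[" + std::to_string(minLanes) + "]xi1>";
}
```

For the example above, `predicateType(*predicateStorageLanes(4))` gives `vector<[16]xi1>`, the element type of the widened `memref.alloca`.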
The storage for SVE vector types only needs to have an alignment that
matches the element type (for example, 4-byte alignment for `f32`s).
However, the LLVM backend currently defaults to aligning to `base size`
x `element size` bytes. For non-legal vector types like
`vector<[8]xf32>`, this results in 8 x 4 = 32-byte alignment, but the
backend only supports up to 16-byte alignment for SVE vectors on the
stack. Explicitly setting a smaller alignment prevents this issue.
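The alignment arithmetic in the previous paragraph can be checked with a tiny standalone C++ sketch. The helper names are hypothetical, and the 16-byte limit is the one stated in the commit message, not a value queried from LLVM:

```cpp
#include <cassert>
#include <cstdint>

// Largest stack alignment for SVE vectors, as stated in the commit message.
constexpr uint64_t kMaxSveStackAlignment = 16;

// Backend default described above: base (minimum) lane count x element size.
uint64_t defaultSveAlignment(uint64_t baseLanes, uint64_t elementBytes) {
  return baseLanes * elementBytes;
}

// What the storage actually needs: the element type's alignment.
uint64_t requiredSveAlignment(uint64_t elementBytes) { return elementBytes; }

bool isSupportedOnStack(uint64_t alignment) {
  return alignment <= kMaxSveStackAlignment;
}
```

For `vector<[8]xf32>`, `defaultSveAlignment(8, 4)` is 32 and exceeds the limit, while `requiredSveAlignment(4)` is 4 and fits; a legal type such as `vector<[4]xf32>` already stays within 16 bytes.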
---
.../mlir/Dialect/ArmSVE/CMakeLists.txt | 1 +
.../Dialect/ArmSVE/Transforms/CMakeLists.txt | 5 +
.../mlir/Dialect/ArmSVE/Transforms/Passes.h | 33 ++
.../mlir/Dialect/ArmSVE/Transforms/Passes.td | 67 ++++
mlir/include/mlir/InitAllPasses.h | 2 +
.../Dialect/ArmSVE/Transforms/CMakeLists.txt | 2 +
.../Transforms/LegalizeVectorStorage.cpp | 310 ++++++++++++++++++
.../ArmSVE/legalize-vector-storage.mlir | 160 +++++++++
.../ArmSVE/arrays-of-scalable-vectors.mlir | 121 +++++++
9 files changed, 701 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
create mode 100644 mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..7226642daf86172
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE)
+add_public_tablegen_target(MLIRArmSVEPassIncGen)
+
+add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
new file mode 100644
index 000000000000000..317fb9021b3c577
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::arm_sve {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+/// Pass to legalize loads, stores, and allocations of SVE vector types.
+std::unique_ptr<Pass> createLegalizeVectorStoragePass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+} // namespace mlir::arm_sve
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
new file mode 100644
index 000000000000000..35c49607181da0c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
@@ -0,0 +1,67 @@
+//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeVectorStorage
+ : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> {
+ let summary = "Ensures loads, stores, and allocations of SVE vector types will be legal";
+ let description = [{
+ This pass ensures that loads, stores, and allocations of SVE vector types
+ will be legal in the LLVM backend. It does this at the memref level, so this
+ pass must be applied before lowering all the way to LLVM.
+
+ This pass currently fixes two issues.
+
+ ## Loading and storing predicate types
+
+ It is only legal to load/store predicate types equal to (or greater than) a
+ full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
+ predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
+ predicate type (referred to as an `svbool`) before and after storing and
+ loading respectively. This pass does this by widening allocations and
+ inserting conversion intrinsics.
+
+ For example:
+
+ ```mlir
+ %alloca = memref.alloca() : memref<vector<[4]xi1>>
+ %mask = vector.constant_mask [4] : vector<[4]xi1>
+ memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+ %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+ ```
+ Becomes:
+ ```mlir
+ %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %mask = vector.constant_mask [4] : vector<[4]xi1>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+ memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
+ %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
+ %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
+ ```
+
+ ## Relax alignments for SVE vector allocas
+
+ The storage for SVE vector types only needs to have an alignment that
+ matches the element type (for example 4-byte alignment for `f32`s). However,
+ the LLVM backend currently defaults to aligning to `base size` x
+ `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this
+ results in 8 x 4 = 32-byte alignment, but the backend only supports up to
+ 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller
+ alignment prevents this issue.
+ }];
+ let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()";
+ let dependentDialects = ["func::FuncDialect",
+ "memref::MemRefDialect", "vector::VectorDialect",
+ "arm_sve::ArmSVEDialect"];
+}
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
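The alignment arithmetic in the pass description above can be checked with a small, self-contained C++ sketch. The helpers below are hypothetical (they are not MLIR or LLVM functions): `defaultSveAllocaAlign` encodes the quoted `base size` x `element size` default, `elementAlign` encodes the element-type alignment the description says is sufficient (mirroring the `std::max(1u, elementByteSize)` rule in `RelaxScalableVectorAllocaAlignment` in this revision), and the 16-byte limit is the stack limit stated in the text.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helpers modelling the alignment rules described above; they
// are not MLIR or LLVM functions.

// Largest stack alignment (in bytes) the description says the backend
// supports for SVE vectors.
constexpr uint64_t kSveStackAlignLimit = 16;

// Default alloca alignment per the description: `base size` x `element size`
// bytes, e.g. vector<[8]xf32> -> 8 x 4 = 32 bytes.
constexpr uint64_t defaultSveAllocaAlign(uint64_t baseSize,
                                         uint64_t elementBits) {
  return baseSize * (elementBits / 8);
}

// Alignment that matches the element type, clamped to at least 1 byte so
// that i1 (predicate) elements still get a valid alignment.
constexpr uint64_t elementAlign(uint64_t elementBits) {
  return std::max<uint64_t>(1, elementBits / 8);
}

// True if an alignment is within the stack limit quoted in the description.
constexpr bool fitsSveStackLimit(uint64_t alignBytes) {
  return alignBytes <= kSveStackAlignLimit;
}
```

Under these assumptions, `vector<[8]xf32>` gets a 32-byte default that exceeds the 16-byte limit, while its element alignment is only 4 bytes.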
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..7301905954f56d8 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
transform::registerTransformPasses();
vector::registerVectorPasses();
arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 2f1c43fae240d51..a70c489a51fea9a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
+ LegalizeVectorStorage.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
+ MLIRArmSVEPassIncGen
LINK_LIBS PUBLIC
MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
new file mode 100644
index 000000000000000..610eb38089c4c88
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -0,0 +1,310 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arm_sve {
+#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sve
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+constexpr StringLiteral kPassLabel("__arm_sve_legalize_vector_storage__");
+
+namespace {
+
+/// A (legal) SVE predicate mask that has a logical size, i.e. the number of
+/// bits matches the number of lanes it masks (such as vector<[4]xi1>), but is too
+/// small to be stored to memory.
+bool isLogicalSVEPredicateType(VectorType type) {
+ return type.getRank() > 0 && type.getElementType().isInteger(1) &&
+ type.getScalableDims().back() && type.getShape().back() < 16 &&
+ llvm::isPowerOf2_32(type.getShape().back()) &&
+ !llvm::is_contained(type.getScalableDims().drop_back(), true);
+}
+
+VectorType widenScalableMaskTypeToSvbool(VectorType type) {
+ assert(isLogicalSVEPredicateType(type));
+ return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
+}
+
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
+ TLegalizerCallback callback) {
+ // Clone the previous op to preserve any properties/attributes.
+ auto newOp = op.clone();
+ rewriter.insert(newOp);
+ rewriter.replaceOp(op, callback(newOp));
+}
+
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
+ TLegalizerCallback callback) {
+ replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
+ // Mark our `unrealized_conversion_casts` with a pass label.
+ return rewriter.create<UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{op.getResult().getType()},
+ ValueRange{callback(newOp)},
+ NamedAttribute(rewriter.getStringAttr(kPassLabel),
+ rewriter.getUnitAttr()));
+ });
+}
+
+/// Extracts the legal memref value from the `unrealized_conversion_casts` added
+/// by this pass.
+static FailureOr<Value> getLegalMemRef(Value illegalMemref) {
+ Operation *definingOp = illegalMemref.getDefiningOp();
+ if (!definingOp || !definingOp->hasAttr(kPassLabel))
+ return failure();
+ auto unrealizedConversion =
+ llvm::cast<UnrealizedConversionCastOp>(definingOp);
+ return unrealizedConversion.getOperand(0);
+}
+
+/// The default alignment of an alloca may exceed what the backend supports
+/// for SVE types, which causes a failure during stack frame allocation. This
+/// rewrite explicitly adds a reasonable alignment to allocas of scalable types.
+struct RelaxScalableVectorAllocaAlignment
+ : public OpRewritePattern<memref::AllocaOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
+ PatternRewriter &rewriter) const override {
+ auto elementType = allocaOp.getType().getElementType();
+ auto vectorType = llvm::dyn_cast<VectorType>(elementType);
+ if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
+ return failure();
+
+ unsigned elementByteSize =
+ vectorType.getElementType().getIntOrFloatBitWidth() / 8;
+
+ unsigned aligment = std::max(1u, elementByteSize);
+ allocaOp.setAlignment(aligment);
+
+ return success();
+ }
+};
+
+/// Replaces allocations of SVE predicates smaller than an svbool with a wider
+/// allocation and a tagged unrealized conversion.
+///
+/// Example
+/// ```
+/// %alloca = memref.alloca() : memref<vector<[4]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+/// %alloca = builtin.unrealized_conversion_cast %widened
+/// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
+/// {__arm_sve_legalize_vector_storage__}
+/// ```
+template <typename AllocLikeOp>
+struct LegalizeAllocLikeOpConversion : public OpRewritePattern<AllocLikeOp> {
+ using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
+ PatternRewriter &rewriter) const override {
+ auto vectorType =
+ llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
+
+ if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+ return failure();
+
+ // Replace this alloc-like op of an SVE mask with one of a (storable)
+ // svbool_t mask. A temporary unrealized_conversion_cast is added to the old
+ // type to allow local rewrites.
+ replaceOpWithUnrealizedConversion(
+ rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
+ newAllocLikeOp.getResult().setType(
+ llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
+ {}, widenScalableMaskTypeToSvbool(vectorType))));
+ return newAllocLikeOp;
+ });
+
+ return success();
+ }
+};
+
+/// Replaces vector.type_casts of unrealized conversions to illegal memref types
+/// with legal type casts, followed by unrealized conversions.
+///
+/// Example:
+/// ```
+/// %alloca = builtin.unrealized_conversion_cast %widened
+/// : memref<vector<3x[16]xi1>> to memref<vector<3x[8]xi1>>
+/// {__arm_sve_legalize_vector_storage__}
+/// %cast = vector.type_cast %alloca
+/// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %widened_cast = vector.type_cast %widened
+/// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+/// %cast = builtin.unrealized_conversion_cast %widened_cast
+/// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
+/// {__arm_sve_legalize_vector_storage__}
+/// ```
+struct LegalizeVectorTypeCastConversion
+ : public OpRewritePattern<vector::TypeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto resultType = typeCastOp.getResultMemRefType();
+ auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
+
+ if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+ return failure();
+
+ auto legalMemref = getLegalMemRef(typeCastOp.getMemref());
+ if (failed(legalMemref))
+ return failure();
+
+ // Replace this vector.type_cast with one of a (storable) svbool_t mask.
+ replaceOpWithUnrealizedConversion(
+ rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
+ newTypeCast.setOperand(*legalMemref);
+ newTypeCast.getResult().setType(
+ llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
+ {}, widenScalableMaskTypeToSvbool(vectorType))));
+ return newTypeCast;
+ });
+
+ return success();
+ }
+};
+
+/// Replaces stores to unrealized conversions to illegal memref types with
+/// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
+///
+/// Example:
+/// ```
+/// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
+/// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
+/// ```
+struct LegalizeMemrefStoreConversion
+ : public OpRewritePattern<memref::StoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto loc = storeOp.getLoc();
+
+ Value valueToStore = storeOp.getValueToStore();
+ auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+ if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+ return failure();
+
+ auto legalMemref = getLegalMemRef(storeOp.getMemref());
+ if (failed(legalMemref))
+ return failure();
+
+ auto legalMaskType = widenScalableMaskTypeToSvbool(
+ llvm::cast<VectorType>(valueToStore.getType()));
+ auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
+ loc, legalMaskType, valueToStore);
+ // Replace this store with a conversion to a storable svbool_t mask,
+ // followed by a wider store.
+ replaceOpWithLegalizedOp(rewriter, storeOp,
+ [&](memref::StoreOp newStoreOp) {
+ newStoreOp.setOperand(0, convertToSvbool);
+ newStoreOp.setOperand(1, *legalMemref);
+ return newStoreOp;
+ });
+
+ return success();
+ }
+};
+
+/// Replaces loads from unrealized conversions to illegal memref types with
+/// (legal) wider loads, followed by `arm_sve.convert_from_svbool`s.
+///
+/// Example:
+/// ```
+/// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+/// ```
+/// is rewritten into:
+/// ```
+/// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
+/// %reload = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
+/// ```
+struct LegalizeMemrefLoadConversion : public OpRewritePattern<memref::LoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto loc = loadOp.getLoc();
+
+ Value loadedMask = loadOp.getResult();
+ auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
+
+ if (!vectorType || !isLogicalSVEPredicateType(vectorType))
+ return failure();
+
+ auto legalMemref = getLegalMemRef(loadOp.getMemref());
+ if (failed(legalMemref))
+ return failure();
+
+ auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
+ // Replace this load with a legal load of an svbool_t type, followed by a
+ // conversion back to the original type.
+ replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
+ newLoadOp.setMemRef(*legalMemref);
+ newLoadOp.getResult().setType(legalMaskType);
+ return rewriter.create<arm_sve::ConvertFromSvboolOp>(
+ loc, loadedMask.getType(), newLoadOp);
+ });
+
+ return success();
+ }
+};
+
+struct LegalizeVectorStorage
+ : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<RelaxScalableVectorAllocaAlignment,
+ LegalizeAllocLikeOpConversion<memref::AllocaOp>,
+ LegalizeAllocLikeOpConversion<memref::AllocOp>,
+ LegalizeVectorTypeCastConversion,
+ LegalizeMemrefStoreConversion, LegalizeMemrefLoadConversion>(
+ patterns.getContext());
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ ConversionTarget target(getContext());
+ target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+ [](UnrealizedConversionCastOp unrealizedConversion) {
+ return !unrealizedConversion->hasAttr(kPassLabel);
+ });
+ // This detects if we failed to completely legalize the IR.
+ if (failed(applyPartialConversion(getOperation(), target, {})))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
+ return std::make_unique<LegalizeVectorStorage>();
+}
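The type checks performed by `isLogicalSVEPredicateType` and `widenScalableMaskTypeToSvbool` can be mirrored in plain C++ to make the accepted shapes explicit. This is a sketch under assumptions: `VecTy` is a hypothetical stand-in for `mlir::VectorType` (at most four dimensions, element type reduced to an "is i1" flag), and none of the names below are MLIR APIs.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for mlir::VectorType (NOT the MLIR API): up to four
// dimensions, each with a size and a "scalable" flag, plus an i1 flag.
struct VecTy {
  std::size_t rank;
  std::array<int64_t, 4> shape;
  std::array<bool, 4> scalable;
  bool isI1;
};

constexpr bool isPowerOf2(int64_t x) { return x > 0 && (x & (x - 1)) == 0; }

// Mirrors isLogicalSVEPredicateType: an i1 vector whose trailing dimension is
// scalable, a power of two and smaller than 16, and whose leading dimensions
// are all fixed-size.
constexpr bool isLogicalSvePredicate(const VecTy &t) {
  if (t.rank == 0 || !t.isI1)
    return false;
  const std::size_t last = t.rank - 1;
  if (!t.scalable[last] || t.shape[last] >= 16 || !isPowerOf2(t.shape[last]))
    return false;
  for (std::size_t i = 0; i < last; ++i)
    if (t.scalable[i])
      return false;
  return true;
}

// Mirrors widenScalableMaskTypeToSvbool: only the trailing size changes, to 16.
constexpr VecTy widenToSvbool(VecTy t) {
  t.shape[t.rank - 1] = 16;
  return t;
}
```

Under this model `vector<[4]xi1>` and `vector<3x[8]xi1>` are widened to `vector<[16]xi1>` and `vector<3x[16]xi1>`, while `vector<[16]xi1>` (already an svbool) and the fixed-size `vector<4xi1>` are left alone.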
diff --git a/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
new file mode 100644
index 000000000000000..fda8e2e0fab9618
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s
+
+/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>)
+func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[1]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
+ return %reload : vector<[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>)
+func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[2]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
+ return %reload : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>)
+func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[4]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
+ return %reload : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>)
+func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
+ return %reload : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_2d_mask_and_reload_slice(
+// CHECK-SAME: %[[MASK:.*]]: vector<3x[8]xi1>)
+func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
+ // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<3x[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
+ // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+ %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
+ // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+ %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
+ // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @set_sve_alloca_alignment
+func.func @set_sve_alloca_alignment() {
+ // CHECK-COUNT-6: alignment = 1
+ %a1 = memref.alloca() : memref<vector<[32]xi8>>
+ %a2 = memref.alloca() : memref<vector<[16]xi8>>
+ %a3 = memref.alloca() : memref<vector<[8]xi8>>
+ %a4 = memref.alloca() : memref<vector<[4]xi8>>
+ %a5 = memref.alloca() : memref<vector<[2]xi8>>
+ %a6 = memref.alloca() : memref<vector<[1]xi8>>
+
+ // CHECK-COUNT-6: alignment = 2
+ %b1 = memref.alloca() : memref<vector<[32]xi16>>
+ %b2 = memref.alloca() : memref<vector<[16]xi16>>
+ %b3 = memref.alloca() : memref<vector<[8]xi16>>
+ %b4 = memref.alloca() : memref<vector<[4]xi16>>
+ %b5 = memref.alloca() : memref<vector<[2]xi16>>
+ %b6 = memref.alloca() : memref<vector<[1]xi16>>
+
+ // CHECK-COUNT-6: alignment = 4
+ %c1 = memref.alloca() : memref<vector<[32]xi32>>
+ %c2 = memref.alloca() : memref<vector<[16]xi32>>
+ %c3 = memref.alloca() : memref<vector<[8]xi32>>
+ %c4 = memref.alloca() : memref<vector<[4]xi32>>
+ %c5 = memref.alloca() : memref<vector<[2]xi32>>
+ %c6 = memref.alloca() : memref<vector<[1]xi32>>
+
+ // CHECK-COUNT-6: alignment = 8
+ %d1 = memref.alloca() : memref<vector<[32]xi64>>
+ %d2 = memref.alloca() : memref<vector<[16]xi64>>
+ %d3 = memref.alloca() : memref<vector<[8]xi64>>
+ %d4 = memref.alloca() : memref<vector<[4]xi64>>
+ %d5 = memref.alloca() : memref<vector<[2]xi64>>
+ %d6 = memref.alloca() : memref<vector<[1]xi64>>
+
+ // CHECK-COUNT-6: alignment = 4
+ %e1 = memref.alloca() : memref<vector<[32]xf32>>
+ %e2 = memref.alloca() : memref<vector<[16]xf32>>
+ %e3 = memref.alloca() : memref<vector<[8]xf32>>
+ %e4 = memref.alloca() : memref<vector<[4]xf32>>
+ %e5 = memref.alloca() : memref<vector<[2]xf32>>
+ %e6 = memref.alloca() : memref<vector<[1]xf32>>
+
+ // CHECK-COUNT-6: alignment = 8
+ %f1 = memref.alloca() : memref<vector<[32]xf64>>
+ %f2 = memref.alloca() : memref<vector<[16]xf64>>
+ %f3 = memref.alloca() : memref<vector<[8]xf64>>
+ %f4 = memref.alloca() : memref<vector<[4]xf64>>
+ %f5 = memref.alloca() : memref<vector<[2]xf64>>
+ %f6 = memref.alloca() : memref<vector<[1]xf64>>
+
+ "prevent.dce"(
+ %a1, %a2, %a3, %a4, %a5, %a6,
+ %b1, %b2, %b3, %b4, %b5, %b6,
+ %c1, %c2, %c3, %c4, %c5, %c6,
+ %d1, %d2, %d3, %d4, %d5, %d6,
+ %e1, %e2, %e3, %e4, %e5, %e6,
+ %f1, %f2, %f3, %f4, %f5, %f6)
+ : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
+ memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
+ memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
+ memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
+ memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
+ memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>) -> ()
+ return
+}
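The store/reload tests above only check that each predicate goes through `arm_sve.convert_to_svbool` on the way to memory and `arm_sve.convert_from_svbool` on the way back. Here is a scalar C++ sketch of why that round trip is lossless. It assumes vscale = 1, and it also assumes (this is not stated in the patch) that lane `i` of an N-lane predicate maps to svbool bit `i * (16 / N)` with all other bits cleared. `toSvbool` and `fromSvbool` are hypothetical helpers, not the dialect's implementation.

```cpp
#include <cstdint>

// Hypothetical scalar model of arm_sve.convert_to_svbool and
// arm_sve.convert_from_svbool for vscale = 1 (a single 128-bit granule).
// ASSUMPTION: lane i of an N-lane predicate lives in svbool bit i * (16 / N),
// and every other svbool bit is zero. `lanes` must be 1, 2, 4, 8 or 16.

constexpr uint16_t toSvbool(uint16_t pred, unsigned lanes) {
  const unsigned stride = 16 / lanes; // svbool bits per predicate lane
  uint16_t svbool = 0;                // bits not written below stay zero
  for (unsigned i = 0; i < lanes; ++i)
    if (pred & (1u << i))
      svbool |= static_cast<uint16_t>(1u << (i * stride));
  return svbool;
}

constexpr uint16_t fromSvbool(uint16_t svbool, unsigned lanes) {
  const unsigned stride = 16 / lanes;
  uint16_t pred = 0;
  for (unsigned i = 0; i < lanes; ++i)
    if (svbool & (1u << (i * stride)))
      pred |= static_cast<uint16_t>(1u << i);
  return pred;
}
```

For a 4-lane predicate `0b1011`, the widened value has bits 0, 4 and 12 set (`0x1011`), and converting back recovers `0b1011`.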
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir
new file mode 100644
index 000000000000000..260cd1eabe7dae1
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+/// This tests basic functionality of arrays of scalable vectors, which in MLIR
+/// are vectors with a single trailing scalable dimension. This test requires
+/// the -arm-sve-legalize-vector-storage pass to ensure the loads/stores done
+/// here are legal for the LLVM backend.
+
+func.func @read_and_print_2d_vector(%memref: memref<3x?xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim = memref.dim %memref, %c1 : memref<3x?xf32>
+ %mask = vector.create_mask %c2, %dim : vector<3x[8]xi1>
+ %vector = vector.transfer_read %memref[%c0,%c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[8]xf32>
+
+ /// TODO: Support vector.print for arrays of scalable vectors.
+ %row0 = vector.extract %vector[0] : vector<[8]xf32> from vector<3x[8]xf32>
+ %row1 = vector.extract %vector[1] : vector<[8]xf32> from vector<3x[8]xf32>
+ %row2 = vector.extract %vector[2] : vector<[8]xf32> from vector<3x[8]xf32>
+
+ /// Print each of the vectors. This only checks the first eight elements (which
+ /// works for all vscale >= 1).
+
+ // CHECK-LABEL: TEST 1
+ vector.print str "TEST 1 (read and print 2D arrays of scalable vectors)"
+ // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8
+ vector.print %row0 : vector<[8]xf32>
+ // CHECK: ( 8, 8, 8, 8, 8, 8, 8, 8
+ vector.print %row1 : vector<[8]xf32>
+ /// This last row is all zero due to our mask.
+ // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0
+ vector.print %row2 : vector<[8]xf32>
+
+ return
+}
+
+func.func @print_1x2xVSCALExf32(%vector: vector<1x2x[4]xf32>) {
+ /// TODO: Support vector.print for arrays of scalable vectors.
+ %slice0 = vector.extract %vector[0, 0] : vector<[4]xf32> from vector<1x2x[4]xf32>
+ %slice1 = vector.extract %vector[0, 1] : vector<[4]xf32> from vector<1x2x[4]xf32>
+ vector.print %slice0 : vector<[4]xf32>
+ vector.print %slice1 : vector<[4]xf32>
+ return
+}
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+ %dim_b = memref.dim %b, %c2 : memref<1x2x?xf32>
+ %mask_a = vector.create_mask %c1, %c2, %dim_a : vector<1x2x[4]xi1>
+ %mask_b = vector.create_mask %c1, %c2, %dim_b : vector<1x2x[4]xi1>
+
+ vector.print str "TEST 2 (reading and adding two 3D arrays of scalable vectors)"
+
+ /// Print each of the vectors. This only checks the first four elements (which
+ /// works for all vscale >= 1).
+
+ // CHECK-LABEL: Vector A
+ // CHECK-NEXT: ( 5, 5, 5, 5
+ // CHECK-NEXT: ( 5, 5, 5, 5
+ vector.print str "\nVector A"
+ %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+ func.call @print_1x2xVSCALExf32(%vector_a) : (vector<1x2x[4]xf32>) -> ()
+
+ vector.print str "\nVector B"
+ // CHECK-LABEL: Vector B
+ // CHECK-NEXT: ( 4, 4, 4, 4
+ // CHECK-NEXT: ( 4, 4, 4, 4
+ %vector_b = vector.transfer_read %b[%c0, %c0, %c0], %cst, %mask_b {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+ func.call @print_1x2xVSCALExf32(%vector_b) : (vector<1x2x[4]xf32>) -> ()
+
+ %sum = arith.addf %vector_a, %vector_b : vector<1x2x[4]xf32>
+ // CHECK-LABEL: Sum
+ // CHECK-NEXT: ( 9, 9, 9, 9
+ // CHECK-NEXT: ( 9, 9, 9, 9
+ vector.print str "\nSum"
+ func.call @print_1x2xVSCALExf32(%sum) : (vector<1x2x[4]xf32>) -> ()
+
+ return
+}
+
+func.func @entry() {
+ %vscale = vector.vscale
+
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %f32_8 = arith.constant 8.0 : f32
+ %f32_5 = arith.constant 5.0 : f32
+ %f32_4 = arith.constant 4.0 : f32
+
+ vector.print str "\n====================\n"
+
+ %test_1_memref_size = arith.muli %vscale, %c8 : index
+ %test_1_memref = memref.alloca(%test_1_memref_size) : memref<3x?xf32>
+
+ linalg.fill ins(%f32_8 : f32) outs(%test_1_memref : memref<3x?xf32>)
+
+ func.call @read_and_print_2d_vector(%test_1_memref) : (memref<3x?xf32>) -> ()
+
+ vector.print str "\n====================\n"
+
+ %test_2_memref_size = arith.muli %vscale, %c4 : index
+ %test_2_memref_a = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32>
+ %test_2_memref_b = memref.alloca(%test_2_memref_size) : memref<1x2x?xf32>
+
+ linalg.fill ins(%f32_5 : f32) outs(%test_2_memref_a : memref<1x2x?xf32>)
+ linalg.fill ins(%f32_4 : f32) outs(%test_2_memref_b : memref<1x2x?xf32>)
+
+ func.call @add_arrays_of_scalable_vectors(
+ %test_2_memref_a, %test_2_memref_b) : (memref<1x2x?xf32>, memref<1x2x?xf32>) -> ()
+
+ vector.print str "\n====================\n"
+
+ return
+}
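The expected output of TEST 1 follows from how `vector.create_mask %c2, %dim` masks the `vector.transfer_read`: only lanes in rows below 2 and columns below `%dim` are read, and the rest take the padding value `%cst` (0.0). Below is a plain C++ model, assuming vscale = 1 (so `vector<3x[8]xf32>` has 3 x 8 lanes). `createMask2D`, `maskedRead` and `filled` are hypothetical helpers standing in for `vector.create_mask`, a masked `vector.transfer_read` and `linalg.fill`; they are not MLIR APIs.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kRows = 3;
constexpr std::size_t kCols = 8; // 8 * vscale, assuming vscale = 1

using Mask = std::array<std::array<bool, kCols>, kRows>;
using Tile = std::array<std::array<float, kCols>, kRows>;

// Models vector.create_mask %rows, %cols: lane (r, c) is active iff
// r < rows and c < cols.
constexpr Mask createMask2D(std::size_t rows, std::size_t cols) {
  Mask m{};
  for (std::size_t r = 0; r < kRows; ++r)
    for (std::size_t c = 0; c < kCols; ++c)
      m[r][c] = r < rows && c < cols;
  return m;
}

// Models a masked vector.transfer_read: active lanes take the memory value,
// inactive lanes take the padding value.
constexpr Tile maskedRead(const Tile &mem, const Mask &mask, float pad) {
  Tile out{};
  for (std::size_t r = 0; r < kRows; ++r)
    for (std::size_t c = 0; c < kCols; ++c)
      out[r][c] = mask[r][c] ? mem[r][c] : pad;
  return out;
}

// Models linalg.fill: every element is set to `value`.
constexpr Tile filled(float value) {
  Tile t{};
  for (std::size_t r = 0; r < kRows; ++r)
    for (std::size_t c = 0; c < kCols; ++c)
      t[r][c] = value;
  return t;
}
```

With the memref filled with 8.0 and a mask of `(2, kCols)`, rows 0 and 1 read back as 8 and row 2 as 0, matching the CHECK lines of TEST 1.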
>From 74f47fa3d827dfcd4866879aeac0863553b4fa2c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 23 Oct 2023 12:59:03 +0000
Subject: [PATCH 3/5] Always align SVE vectors to 16 bytes and predicates to 2
bytes
---
.../Transforms/LegalizeVectorStorage.cpp | 10 ++++-----
.../ArmSVE/legalize-vector-storage.mlir | 22 +++++++++----------
2 files changed, 15 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index 610eb38089c4c88..e7edcf6ec789235 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -82,15 +82,13 @@ struct RelaxScalableVectorAllocaAlignment
LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
PatternRewriter &rewriter) const override {
- auto elementType = allocaOp.getType().getElementType();
- auto vectorType = llvm::dyn_cast<VectorType>(elementType);
+ auto memrefElementType = allocaOp.getType().getElementType();
+ auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
return failure();
- unsigned elementByteSize =
- vectorType.getElementType().getIntOrFloatBitWidth() / 8;
-
- unsigned aligment = std::max(1u, elementByteSize);
+ // Set alignment based on the defaults for SVE vectors and predicates.
+ unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
allocaOp.setAlignment(aligment);
return success();
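The alignment rule introduced by this patch is small enough to state as a hypothetical C++ helper (not an MLIR API). It returns 2 bytes when the scalable vector's element type is `i1` (an SVE predicate) and 16 bytes otherwise; the element type is reduced to an "element is i1" flag for illustration.

```cpp
#include <cstdint>

// Hypothetical model of the alignment this patch sets on allocas of scalable
// vectors: 2 bytes for predicates (i1 elements), 16 bytes for everything else.
constexpr uint64_t sveAllocaAlignment(bool elementIsI1) {
  return elementIsI1 ? 2 : 16;
}
```

This is why the predicate tests updated in this patch now expect `{alignment = 2 : i64}` instead of `{alignment = 1 : i64}`.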
diff --git a/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
index fda8e2e0fab9618..61879c48712f4d2 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir
@@ -7,7 +7,7 @@
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>)
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
- // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
%alloca = memref.alloca() : memref<vector<[1]xi1>>
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -24,7 +24,7 @@ func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vecto
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>)
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
- // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
%alloca = memref.alloca() : memref<vector<[2]xi1>>
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -41,7 +41,7 @@ func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vecto
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>)
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
- // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
%alloca = memref.alloca() : memref<vector<[4]xi1>>
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -58,7 +58,7 @@ func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vecto
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>)
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
- // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
%alloca = memref.alloca() : memref<vector<[8]xi1>>
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -77,7 +77,7 @@ func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vecto
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<3x[16]xi1>>
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
%alloca = memref.alloca() : memref<vector<3x[8]xi1>>
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
@@ -95,7 +95,7 @@ func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]
// CHECK-LABEL: @set_sve_alloca_alignment
func.func @set_sve_alloca_alignment() {
- // CHECK-COUNT-6: alignment = 1
+ // CHECK-COUNT-6: alignment = 16
%a1 = memref.alloca() : memref<vector<[32]xi8>>
%a2 = memref.alloca() : memref<vector<[16]xi8>>
%a3 = memref.alloca() : memref<vector<[8]xi8>>
@@ -103,7 +103,7 @@ func.func @set_sve_alloca_alignment() {
%a5 = memref.alloca() : memref<vector<[2]xi8>>
%a6 = memref.alloca() : memref<vector<[1]xi8>>
- // CHECK-COUNT-6: alignment = 2
+ // CHECK-COUNT-6: alignment = 16
%b1 = memref.alloca() : memref<vector<[32]xi16>>
%b2 = memref.alloca() : memref<vector<[16]xi16>>
%b3 = memref.alloca() : memref<vector<[8]xi16>>
@@ -111,7 +111,7 @@ func.func @set_sve_alloca_alignment() {
%b5 = memref.alloca() : memref<vector<[2]xi16>>
%b6 = memref.alloca() : memref<vector<[1]xi16>>
- // CHECK-COUNT-6: alignment = 4
+ // CHECK-COUNT-6: alignment = 16
%c1 = memref.alloca() : memref<vector<[32]xi32>>
%c2 = memref.alloca() : memref<vector<[16]xi32>>
%c3 = memref.alloca() : memref<vector<[8]xi32>>
@@ -119,7 +119,7 @@ func.func @set_sve_alloca_alignment() {
%c5 = memref.alloca() : memref<vector<[2]xi32>>
%c6 = memref.alloca() : memref<vector<[1]xi32>>
- // CHECK-COUNT-6: alignment = 8
+ // CHECK-COUNT-6: alignment = 16
%d1 = memref.alloca() : memref<vector<[32]xi64>>
%d2 = memref.alloca() : memref<vector<[16]xi64>>
%d3 = memref.alloca() : memref<vector<[8]xi64>>
@@ -127,7 +127,7 @@ func.func @set_sve_alloca_alignment() {
%d5 = memref.alloca() : memref<vector<[2]xi64>>
%d6 = memref.alloca() : memref<vector<[1]xi64>>
- // CHECK-COUNT-6: alignment = 4
+ // CHECK-COUNT-6: alignment = 16
%e1 = memref.alloca() : memref<vector<[32]xf32>>
%e2 = memref.alloca() : memref<vector<[16]xf32>>
%e3 = memref.alloca() : memref<vector<[8]xf32>>
@@ -135,7 +135,7 @@ func.func @set_sve_alloca_alignment() {
%e5 = memref.alloca() : memref<vector<[2]xf32>>
%e6 = memref.alloca() : memref<vector<[1]xf32>>
- // CHECK-COUNT-6: alignment = 8
+ // CHECK-COUNT-6: alignment = 16
%f1 = memref.alloca() : memref<vector<[32]xf64>>
%f2 = memref.alloca() : memref<vector<[16]xf64>>
%f3 = memref.alloca() : memref<vector<[8]xf64>>
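For reference, the alignment rule changed in `RelaxScalableVectorAllocaAlignment` can be modelled in plain C++. This is a hypothetical sketch, not MLIR code: `ElemKind`, `oldAlignment` and `newAlignment` are illustrative names. The real pattern checks `isInteger(1)` on the element type, and modelling that as "bit width 1" is a simplification. The model contrasts the previous rule (element byte size, clamped to at least 1) with the new one (2 for `i1` predicates, 16 for data vectors). The updated CHECK lines expect exactly these values.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical stand-in for the element type of a scalable vector.
// Only the bit width matters here; width 1 models an i1 predicate.
struct ElemKind {
  unsigned bitWidth;
};

// Old rule: element size in bytes, clamped to at least 1.
// For i1 this computes 1 / 8 == 0, which is then clamped to 1.
unsigned oldAlignment(ElemKind e) {
  unsigned elementByteSize = e.bitWidth / 8;
  return std::max(1u, elementByteSize);
}

// New rule: 2 bytes for SVE predicates (i1), 16 bytes for data vectors.
unsigned newAlignment(ElemKind e) {
  return e.bitWidth == 1 ? 2u : 16u;
}
```

Under the old rule, the `xf32` and `xf64` cases produced 4 and 8 (the removed CHECK lines). Under the new rule, every data vector gets 16.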
From 54387a97081535abde5cd4b9d308bcca8c65c1a0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 20 Oct 2023 16:13:49 +0000
Subject: [PATCH 4/5] [mlir][SVE] Add an e2e test for vector.contract
Adds an end-to-end test for `vector.contract` that targets SVE (i.e.
scalable vectors). Note that this requires relaxing a restriction in the
`vector.outerproduct` verifier (`vector.contract` is lowered to
`vector.outerproduct`) that would reject the following as invalid:
```
vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32>
```
Such a mix of fixed-width and scalable operands is in fact valid, as the
end-to-end test demonstrates (at least when compiling for SVE).
Depends on #68794
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +-
...r-contract-to-outerproduct-transforms.mlir | 59 ++++++--
.../Vector/vector-scalable-outerproduct.mlir | 5 +-
.../Vector/CPU/ArmSVE/test-contraction.mlir | 137 ++++++++++++++++++
4 files changed, 191 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
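The relaxed verifier check below can be summarised as a small truth table. The following is a hypothetical C++ model of just the scalability part of `OuterProductOp::verify()` for the two-vector (outer-product) case. `checkOuterProductScalability` and `oldRuleAccepts` are illustrative names, not MLIR APIs. The model returns the new error message when the LHS is scalable but the RHS is not, and an empty string otherwise.

```cpp
#include <cassert>
#include <string>

// Hypothetical model of the new scalability check for the outer-product
// case (both operands are 1-D vectors). An empty string means "accepted".
std::string checkOuterProductScalability(bool lhsScalable, bool rhsScalable) {
  if (lhsScalable && !rhsScalable)
    return "expected either both or only #2 operand dim to be scalable";
  return "";
}

// The rule before this patch: both or neither operand had to be scalable.
bool oldRuleAccepts(bool lhsScalable, bool rhsScalable) {
  return lhsScalable == rhsScalable;
}
```

The case from the commit message, `vector<3xf32>, vector<[2]xf32>`, corresponds to `(false, true)`. The old rule rejected it, and the new rule accepts it.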
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8de132daf3c6a5d..d77476c10908395 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
- if (vRHS.isScalable() != vLHS.isScalable())
- return emitOpError("expected either all or none of vector operands #1 "
- "and #2 to be scalable");
+ if (vLHS.isScalable() && !vRHS.isScalable()) {
+ // This restriction reflects what's currently supported in terms of
+ // scalable vectors. However, we could relax this if there's a use case.
+ return emitOpError(
+ "expected either both or only #2 operand dim to be scalable");
+ }
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 1e92fcff64dea57..a5cd9b8f39173b3 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
}
// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
return %0 : vector<3x7xf32>
}
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+// Note that only the J dimension is scalable in this example. In theory, all
+// dimensions could be scalable, but there is no target yet for which this
+// would make sense.
+func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x[7]xf32>,
+ %arg2: vector<3x[7]xf32>,
+ %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+ return %0 : vector<3x[7]xf32>
+}
+
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 7d9923e036660c9..3b4e24da92aaacc 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
- // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+ // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+
+ return
}
+
// -----
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
new file mode 100644
index 000000000000000..12187dfd7787155
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
@@ -0,0 +1,137 @@
+// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
+// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm
+// DEFINE: %{entry} =
+// DEFINE: %{run} = %mcr_aarch64_cmd -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+
+// REDEFINE: %{entry} = entry_i32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=I32
+
+// REDEFINE: %{entry} = entry_f32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=F32
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func.func @entry_i32() {
+ %vscale = vector.vscale
+
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c5 = arith.constant 5 : index
+ %n_rows = arith.muli %vscale, %c2 : index
+
+ %cst = arith.constant 0: i32
+ %i32_123 = arith.constant 123 : i32
+ %i32_314 = arith.constant 314 : i32
+
+ // Allocate and initialize matrix A
+ %A_alloc = memref.alloca() : memref<3x5xi32>
+ linalg.fill ins(%i32_123 : i32) outs(%A_alloc :memref<3x5xi32>)
+ %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+ %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xi32>, vector<3x5xi32>
+
+ // Allocate and initialize matrix B
+ %B_alloc = memref.alloca(%n_rows) : memref<5x?xi32>
+ linalg.fill ins(%i32_123 : i32) outs(%B_alloc :memref<5x?xi32>)
+ %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
+ %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xi32>, vector<5x[2]xi32>
+
+ // Allocate and initialize matrix C
+ %C_alloc = memref.alloca(%n_rows) : memref<3x?xi32>
+ linalg.fill ins(%i32_314 : i32) outs(%C_alloc :memref<3x?xi32>)
+ %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
+ %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xi32>, vector<3x[2]xi32>
+
+ // Matmul
+ %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
+ %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32> } : vector<3x[2]x5xi1> -> vector<3x[2]xi32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32: ( 75959, 75959, 75959, 75959
+ vector.print %slice1 : vector<[2]xi32>
+ %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959, 75959, 75959
+ vector.print %slice2 : vector<[2]xi32>
+ %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959, 75959, 75959
+ vector.print %slice3 : vector<[2]xi32>
+
+ // I32: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+func.func @entry_f32() {
+ %vscale = vector.vscale
+
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c5 = arith.constant 5 : index
+ %n_rows = arith.muli %vscale, %c2 : index
+
+ %cst = arith.constant 0.0: f32
+ %f32_123 = arith.constant 1.23 : f32
+ %f32_314 = arith.constant 3.14 : f32
+
+ // Allocate and initialize matrix A
+ %A_alloc = memref.alloca() : memref<3x5xf32>
+ linalg.fill ins(%f32_123 : f32) outs(%A_alloc :memref<3x5xf32>)
+ %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+ %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xf32>, vector<3x5xf32>
+
+ // Allocate and initialize matrix B
+ %B_alloc = memref.alloca(%n_rows) : memref<5x?xf32>
+ linalg.fill ins(%f32_123 : f32) outs(%B_alloc :memref<5x?xf32>)
+ %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
+ %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xf32>, vector<5x[2]xf32>
+
+ // Allocate and initialize matrix C
+ %C_alloc = memref.alloca(%n_rows) : memref<3x?xf32>
+ linalg.fill ins(%f32_314 : f32) outs(%C_alloc :memref<3x?xf32>)
+ %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
+ %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[2]xf32>
+
+ // Matmul
+ %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
+ %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32> } : vector<3x[2]x5xi1> -> vector<3x[2]xf32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice1 : vector<[2]xf32>
+ %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice2 : vector<[2]xf32>
+ %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice3 : vector<[2]xf32>
+
+ // F32: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+}
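The expected values in the I32/F32 lines follow directly from the fill constants. `A` is `3x5`, so the reduction dimension is `K = 5`, and every result element is `fillC + K * fillA * fillB`: `314 + 5 * 123 * 123 = 75959` and `3.14 + 5 * 1.23 * 1.23 = 10.7045`. The expected lines omit the closing `)`, so they only pin down the leading elements of each row; the row length itself depends on `vscale`. A small, hypothetical C++ check of that arithmetic (`expectedI32`, `expectedF32` and `approxEqual` are illustrative helpers, not part of the test) follows. It accumulates into the `C` value first, as the outer-product lowering does.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Each element of the 3 x [2] result is C[i][j] + sum_k A[i][k] * B[k][j].
// With constant fills this reduces to fillC + k * fillA * fillB.
int32_t expectedI32(int32_t fillA, int32_t fillB, int32_t fillC, int k) {
  int32_t acc = fillC;
  for (int p = 0; p < k; ++p)
    acc += fillA * fillB;
  return acc;
}

float expectedF32(float fillA, float fillB, float fillC, int k) {
  float acc = fillC;
  for (int p = 0; p < k; ++p)
    acc += fillA * fillB;
  return acc;
}

// vector.print shows a rounded value, so compare floats with a tolerance.
bool approxEqual(float a, float b, float tol) {
  return std::fabs(a - b) <= tol;
}
```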
From b5f879aa7c8a51f2b167dc3ad157b83e62983c55 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 24 Oct 2023 09:24:34 +0000
Subject: [PATCH 5/5] [mlir][vector] Add scalable vectors to tests for
vector.contract
Extend the remaining matrix multiplication tests in:
* vector-contract-to-outerproduct-transforms.mlir
with cases for scalable vectors. Note that at the moment these tests
only exercise scalability in one dimension (*). More specifically, only
the 2nd parallel dimension ("N") is made scalable.
It's not clear whether making the other dimensions scalable (i.e. "M"
and/or "K") would ever make sense, but this can be revisited in the
future. For 2-d scalable vectors (future work), both parallel dimensions
will be made scalable.
(*) In terms of Arm extensions, this means targeting SVE rather than
SME.
---
.../Vector/Transforms/LowerVectorContract.cpp | 2 +-
...r-contract-to-outerproduct-transforms.mlir | 140 ++++++++++++++++++
2 files changed, 141 insertions(+), 1 deletion(-)
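The new `*_scalable` CHECK sequences all follow the same shape. `A` is transposed, and then for each reduction index, one row of `A^T` and one row of `B` feed a `vector.outerproduct` that accumulates into `C`. The following is a hypothetical plain C++ model of that decomposition, checked against a naive triple loop. `Matrix`, `matmulReference` and `matmulOuterProducts` are illustrative names, not MLIR code.

In this model, the `N` width appears only inside each outer product, while the loop over the reduction dimension is the one that gets unrolled. This is presumably why a scalable `N` fits this lowering naturally, whereas a scalable `K` would leave the unroll count unknown at compile time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Naive reference: C[i][j] += sum_p A[i][p] * B[p][j].
Matrix matmulReference(const Matrix &a, const Matrix &b, Matrix c) {
  size_t m = a.size(), k = b.size(), n = b[0].size();
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < k; ++p)
        c[i][j] += a[i][p] * b[p][j];
  return c;
}

// Outer-product strategy mirroring the CHECK lines: transpose A, then for
// every reduction index p accumulate outer(row p of A^T, row p of B) into C.
Matrix matmulOuterProducts(const Matrix &a, const Matrix &b, Matrix c) {
  size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix at(k, std::vector<float>(m));
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      at[p][i] = a[i][p];
  for (size_t p = 0; p < k; ++p)  // one vector.outerproduct per p
    for (size_t i = 0; i < m; ++i)
      for (size_t j = 0; j < n; ++j)
        c[i][j] += at[p][i] * b[p][j];
  return c;
}
```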
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 5463a7bd8f4c840..6dbe36e605e9a78 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
return v;
Type promotedType = dstElementType;
if (vecType)
- promotedType = VectorType::get(vecType.getShape(), promotedType);
+ promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
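Switching from `VectorType::get(vecType.getShape(), promotedType)` to `vecType.clone(promotedType)` matters for mixed-precision inputs. Rebuilding the type from the shape alone drops the scalable flags, so `vector<[3]xf16>` would be promoted to the fixed-width `vector<3xf32>`. `clone` keeps the flags, which gives the `vector<[3]xf16> to vector<[3]xf32>` extension checked in `matmul_0_mixed_scalable`. The following is a hypothetical, minimal C++ model of that difference. `VecTy`, `getFromShape` and `cloneWithElementType` are illustrative stand-ins, not the MLIR `VectorType` API.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a vector type: shape, per-dimension
// scalability flags and an element type name.
struct VecTy {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elementType;
};

// Models the old code path: only the shape is carried over, so every
// dimension defaults to fixed-width.
VecTy getFromShape(const std::vector<int64_t> &shape, const std::string &elt) {
  return VecTy{shape, std::vector<bool>(shape.size(), false), elt};
}

// Models the new code path: copy the whole type, swap the element type.
VecTy cloneWithElementType(const VecTy &v, const std::string &elt) {
  VecTy result = v;
  result.elementType = elt;
  return result;
}
```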
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index a5cd9b8f39173b3..c6eadab3212785c 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
+//
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: return %[[c3]] : vector<2x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
// CHECK-LABEL: func @matmul_0
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
// CHECK-LABEL: func @matmul_0_mixed
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
+// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_3_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_4 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -313,6 +436,23 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}
+// CHECK-LABEL: func @matmul_4_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// CHECK: return %[[c0]] : vector<3x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
+-> vector<3x[2]xf32>
+{
+ %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+ : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
+ return %0 : vector<3x[2]xf32>
+}
+
#matmat_accesses_5 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,