Metal.jl 1.4: Improved random numbers

By: Christian Guinard

Re-posted from: https://juliagpu.org/post/2024-10-07-metal-1.4/index.html

Metal.jl 1.4 adds higher-quality random number generators from the Metal Performance Shaders library. Some limitations apply, with a fallback to the current implementation in those situations.

Metal.rand and friends

Using functionality provided by the Metal Performance Shaders (MPS) library, Metal.jl now comes with much improved GPU random number generators. Uniform distributions using Metal.rand (and its in-place variant Metal.rand!) are available for all Metal-supported integer types and Float32. However, due to Metal API limitations, arrays of 8-bit and 16-bit integers may fall back to the lower-quality GPUArrays.jl random number generator if their size in bytes is not a multiple of 4. Normally distributed Float32 values can be generated with Metal.randn and Metal.randn!. Float16 is not supported by the MPS library and will always fall back to the GPUArrays.jl implementation.
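The multiple-of-4 rule for small integer types can be checked by hand. The sketch below is plain Julia and does not touch the GPU; the helper name byte_size_is_multiple_of_4 is hypothetical and not part of Metal.jl, and it assumes that "size in bytes" refers to the total size of the array being filled. It only illustrates the arithmetic and is not how Metal.jl decides which generator to use internally.

```julia
# Hypothetical helper (not part of Metal.jl): does an array with element
# type T and the given dimensions occupy a multiple of 4 bytes?
# Assumption: the rule applies to the total byte size of the array.
byte_size_is_multiple_of_4(::Type{T}, dims::Integer...) where {T} =
    (sizeof(T) * prod(dims)) % 4 == 0

byte_size_is_multiple_of_4(UInt8, 8)     # 8 bytes -> true
byte_size_is_multiple_of_4(UInt8, 6)     # 6 bytes -> false
byte_size_is_multiple_of_4(Int16, 3)     # 6 bytes -> false
byte_size_is_multiple_of_4(Int16, 2, 2)  # 8 bytes -> true
```

Under that reading, padding a small-integer array to a length whose byte size is divisible by 4 would be one way to stay on the MPS generator.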

The easiest way to use these generators is to call the Metal convenience functions Metal.rand[n][!] as you would the usual functions from the Random.jl standard library:

julia> a = Metal.rand(Float32, 2)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 0.95755994
 0.7110207

julia> Metal.randn!(a)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 1.7230463
 0.55636907

However, the Random.jl methods can also be used by providing the appropriate RNG either from MPS.default_rng() or MPS.RNG() to the standard Random.rand[n][!] functions:

julia> using Random

julia> rng = MPS.RNG();

julia> Random.rand(rng, 2)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 0.8941469
 0.67628527

Seeding is done by calling Metal.seed! for the global RNG, or Random.seed! when working with an explicit RNG object.
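A minimal sketch of both seeding paths follows. It assumes a Mac with a Metal-capable GPU, and that Metal.seed! and Random.seed! accept an integer seed; the seed value 1234 and the variable names are arbitrary choices for illustration.

```julia
using Metal, Random

# Global RNG: reseed, draw, reseed with the same value, draw again.
Metal.seed!(1234)
a = Array(Metal.rand(Float32, 4))
Metal.seed!(1234)
b = Array(Metal.rand(Float32, 4))

# Explicit RNG object: the same pattern, using Random.seed!.
rng = MPS.RNG()
Random.seed!(rng, 1234)
c = Array(Random.rand(rng, 4))
Random.seed!(rng, 1234)
d = Array(Random.rand(rng, 4))

# If seeding works as described, each pair should hold identical values.
@show a == b
@show c == d
```

Copying the results to the host with Array makes the comparison independent of how GPU arrays implement ==.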

Other improvements since the last blog post

  • Since v0.5: MtlArray storage mode has been parameterized, allowing one to create a shared storage MtlArray by calling MtlArray{eltype, ndims, Metal.SharedStorage}(...).

  • Since v0.3: MPS-accelerated decompositions were added.

  • Various performance improvements.

  • Many bug fixes.
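The storage-mode parameter from the first item above can be tried out as follows. This is a sketch assuming a Metal-capable Mac; the variable names are illustrative, and the undef constructor is used here as one plausible way to fill in the (...) from the text.

```julia
using Metal

# Allocate an uninitialized 4-element Float32 vector in shared storage,
# following the MtlArray{eltype, ndims, Metal.SharedStorage}(...) pattern.
shared = MtlArray{Float32, 1, Metal.SharedStorage}(undef, 4)

# Fill it on the device, then copy the contents back to a host Array.
fill!(shared, 1f0)
host = Array(shared)

@show typeof(shared)
@show host
```

Arrays created without an explicit storage parameter, such as the Metal.rand results shown earlier, report Metal.PrivateStorage in their type.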

Future work

Although Metal.jl is now in v1, there is still work to be done to make it as fast and feature-complete as possible. In particular:

  • Metal.jl is now using native ObjectiveC FFI for wrapping Metal APIs. However, these wrappers have to be written manually for every piece of Objective-C code. We are looking for help with improving Clang.jl and ObjectiveC.jl to enable the automatic generation of these wrappers;

  • The MPS wrappers are incomplete, and automatic wrapper generation would greatly help with full MPS support;

  • To implement a full-featured KernelAbstractions.jl back-end, Metal atomic operations need to be hooked up to Atomix;

  • Full support for BFloat16 values, which have been supported since Metal 3.1 (macOS 14), is not yet available in Metal.jl. There is, however, a draft PR in the works. Check it out if you're interested in helping out;

  • Some functionality present in CUDA.jl could be ported to Metal.jl to improve usability.