'''
Helper functions for testing and writing parallel code with the romtools.linalg module.
'''
import numpy as np
[docs]
def assert_axis_is_none_or_within_rank(a, axis):
    '''
    Assert that the given axis is valid for the given array.

    Inputs:
        a: The array to check against; only its rank (a.ndim) is used.
        axis: Either None, or an int in the range [-a.ndim, a.ndim).
              Negative values are accepted as long as they stay within the
              rank of the array.

    Raises:
        AssertionError: if axis is neither an int nor None, or if it lies
                        outside the range [-a.ndim, a.ndim).
    '''
    assert isinstance(axis, int) or axis is None, "axis must be an int or None"
    if axis is not None:
        assert axis < a.ndim, "axis must be < rank of the array"
        # Without a lower bound, an axis such as -5 on a rank-2 array would
        # slip through here and only fail later inside the array operation.
        assert axis >= -a.ndim, "axis must be >= -(rank of the array)"
[docs]
def distribute_array_impl(global_array, comm, dist_axis=0):
    '''
    Splits an np.array into one piece per MPI process, as evenly as possible,
    and hands each process its piece.

    The split is done with np.array_split, so when the length along dist_axis
    is not divisible by the number of processes, the leading pieces are one
    element longer. For example, 6 rows over 3 processes gives 2 rows each;
    7 rows over 3 processes gives 3 rows on the first process and 2 rows on
    each of the other two.

    Only rank 0 performs the split: it keeps the first piece and sends the
    remaining pieces to ranks 1..n-1 with comm.send. Every other rank obtains
    its piece with comm.recv(source=0).

    Inputs:
        global_array: The global np.array to be distributed.
        comm: The MPI communicator.
        dist_axis: The axis along which to split the input array. By default,
                   splits along the first axis (rows).

    Returns:
        local_array: The subset of global_array belonging to the current MPI
                     process. If global_array has no elements, an empty 1-D
                     array (np.empty(0)) is returned and nothing is
                     communicated.
    '''
    world_size = comm.Get_size()
    my_rank = comm.Get_rank()

    # Nothing to distribute: every rank returns the same empty placeholder.
    if global_array.size == 0:
        return np.empty(0)

    # Non-root ranks simply wait for their piece from the root.
    if my_rank != 0:
        return comm.recv(source=0)

    # Root: cut the array into one chunk per rank, ship chunks 1..n-1 out,
    # and keep chunk 0 for itself.
    chunks = np.array_split(global_array, world_size, axis=dist_axis)
    for dest_rank, chunk in enumerate(chunks[1:], start=1):
        comm.send(chunk, dest=dest_rank)
    return chunks[0]
[docs]
def generate_random_local_and_global_arrays_impl(shape, comm, use_int=False):
    '''
    Randomly generates a global array of the specified shape and distributes it to all
    available MPI processes.

    Rank 0 draws the random values; the other ranks allocate a buffer of the same shape
    and dtype, which is then filled by comm.Bcast from rank 0. The global array is split
    along axis 0 for rank-1 and rank-2 shapes, and along axis 1 for rank-3 shapes.

    Inputs:
        shape: Sequence of ints giving the shape of the global array (at most 3 entries).
               An empty shape yields an empty global array.
        comm: The MPI communicator.
        use_int: If True, the array holds random int64 values in [0, 100);
                 otherwise it holds float64 values drawn with np.random.rand.

    Returns:
        local_arr: The piece of the global array held by the current MPI process.
        global_arr: The full global array (identical on all processes after the broadcast).

    Raises:
        ValueError: if shape has more than 3 entries.
    '''
    # Exclusive upper bound for the random integers. np.random.randint needs a
    # positive bound: randint(0, size=...) samples from the empty range [0, 0)
    # and raises ValueError.
    int_upper_bound = 100

    # Get comm info
    rank = comm.Get_rank()

    # Create global_array
    if len(shape) == 0:
        global_arr = np.empty(0)
    elif len(shape) <= 3:
        if rank == 0:
            if use_int:
                global_arr = np.random.randint(0, int_upper_bound, size=shape, dtype=np.int64)
            else:
                global_arr = np.random.rand(*shape)
        else:
            # The receive buffer must have the same dtype as the root's array,
            # otherwise the broadcast bytes would be reinterpreted.
            global_arr = np.empty(shape, dtype=np.int64 if use_int else np.float64)
    else:
        raise ValueError(f"This function only supports arrays up to rank 3 (received rank {len(shape)})")

    # Broadcast global_array and create local_array
    comm.Bcast(global_arr, root=0)
    dist_axis = 0 if len(shape) < 3 else 1
    # Distribute a copy so the returned local piece is not a view into global_arr.
    local_arr = distribute_array_impl(global_arr.copy(), comm, dist_axis)
    return local_arr, global_arr
[docs]
def generate_local_and_global_arrays_from_example_impl(rank, slices, example: int):
    '''
    Builds one of the fixed example tensors and extracts this rank's portion of it.

    Inputs:
        rank: Index into slices selecting the (start, stop) bounds to use.
        slices: Indexable collection where slices[rank][0] and slices[rank][1] are the
                start and stop of the local range along the distributed axis.
        example: Which example to build:
                 1 -> vector of length 7, sliced along axis 0
                 2 -> 7 x 3 matrix, sliced along axis 0 (rows)
                 3 -> 3 x 7 x 2 tensor, sliced along axis 1

    Returns:
        local_arr, global_arr: The local slice and the full example array.
        For any other value of example, (None, None) is returned and slices is not read.
    '''
    # Unknown example: bail out before touching rank/slices.
    if example not in (1, 2, 3):
        return None, None

    if example == 1:
        global_arr = np.array([2.2, 3.3, 40., 51., -24., 45., -4.])
    elif example == 2:
        global_arr = np.array([[2.2, 1.3, 4.],
                               [3.3, 5.0, 33.],
                               [40., -2., -4.],
                               [51., 4., 6.],
                               [-24., 8., 9.],
                               [45., -3., -4.],
                               [-4., 8., 9.]])
    else:
        global_arr = np.array([[[2., 3.], [3., 4.], [4., 2.], [5., 8.], [-2., 2.], [4., 1.], [-4., 2.]],
                               [[1., 6.], [5., -1.], [-2., -2.], [4., -1.], [8., 0.], [-3., -6.], [8., 0.]],
                               [[4., -7.], [3., 5.], [-4., 5.], [6., 0.], [9., 3.], [-4., 1.], [9., 3.]]])

    # The rank-3 example is distributed along its second axis; the others along the first.
    split_axis = 1 if example == 3 else 0
    bounds = slices[rank]
    local_range = slice(bounds[0], bounds[1])
    # Full slices for the axes before the split axis; trailing axes are taken whole implicitly.
    index = (slice(None),) * split_axis + (local_range,)
    local_arr = global_arr[index]
    return local_arr, global_arr