import os
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, DefaultDict
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.segmentation import watershed, find_boundaries
from tqdm import tqdm # pyright: ignore[reportMissingModuleSource]
from skimage.morphology import erosion, disk, dilation
import psutil
from collections import defaultdict, deque
import matplotlib.patches as mpatches
import multiprocessing as mp
from multiprocessing import shared_memory
from scipy.spatial import cKDTree
from pdb import set_trace as stop
from moaap.utils.object_props import clean_up_objects, ConnectLon_on_timestep
[docs]
class UnionFind:
    """
    A Union-Find (Disjoint Set) data structure.

    Assumes each 'label' (int) is a unique object ID. Unknown items are
    registered lazily: ``find`` (and therefore ``union``) adds any item it
    has not seen before as its own singleton set.
    """

    def __init__(self):
        # Maps every known item to its parent; a root is its own parent.
        self.parent: Dict[int, int] = {}

    def add(self, item: int):
        """Register ``item`` as a singleton set if it is not yet known."""
        if item not in self.parent:
            self.parent[item] = item

    def find(self, item: int) -> int:
        """
        Return the root (representative) of the set containing ``item``.

        The lookup is iterative so that arbitrarily long parent chains
        cannot exhaust the interpreter's recursion limit. After the call,
        every node visited on the way points directly at the root
        (path compression).
        """
        self.add(item)
        parent = self.parent

        # Pass 1: walk up to the root.
        root = item
        while parent[root] != root:
            root = parent[root]

        # Pass 2: re-point every node on the path directly at the root.
        while parent[item] != root:
            next_item = parent[item]
            parent[item] = root
            item = next_item
        return root

    def union(self, item1: int, item2: int):
        """
        Merge the sets containing ``item1`` and ``item2``.

        The root of ``item1``'s set is attached below the root of
        ``item2``'s set, so ``item2``'s root represents the merged set.
        """
        root1 = self.find(item1)
        root2 = self.find(item2)
        if root1 != root2:
            self.parent[root1] = root2
[docs]
def connect_3d_objects(labels, min_tsteps, dT):
    """
    Links 2D labeled slices into 3D objects based on maximum spatial overlap
    between consecutive time steps.

    For every time step the overlap (in grid cells) between each object of
    the previous, already linked slice and each label of the current slice
    is counted. Pairs are then matched greedily, largest overlap first, so
    that every previous object continues into at most one current label and
    vice versa. Current labels left unmatched start a new object ID.

    Parameters
    ----------
    labels : np.ndarray
        3D array where each [t, :, :] slice contains independent 2D labels.
        Labels must be non-negative integers; 0 is background.
    min_tsteps : int
        Minimum duration to keep an object (forwarded to
        ``clean_up_objects``).
    dT : int
        Time step (forwarded to ``clean_up_objects``).

    Returns
    -------
    objects_watershed : np.ndarray
        3D array with consistent object IDs tracked over time.
    """
    T, H, W = labels.shape
    objects_watershed = np.zeros_like(labels, dtype=int)
    if T == 0:
        # Nothing to link; avoids indexing / max() on an empty array.
        return objects_watershed

    objects_watershed[0] = labels[0]
    # Python int so the counter cannot overflow a narrow label dtype.
    next_id = int(labels.max()) + 1

    for t in tqdm(range(1, T)):
        prev = objects_watershed[t - 1]
        curr = labels[t]

        # Encode every (previous ID, current label) pair under a previous
        # object as one integer and count occurrences. np.unique only
        # touches pairs that actually occur, so memory does not scale with
        # (max previous ID) * (max current label).
        M = int(curr.max()) + 1
        mask = prev > 0
        p = prev[mask].astype(np.int64, copy=False)
        c = curr[mask].astype(np.int64, copy=False)
        pair_idx = p * M + c
        nz, overlaps = np.unique(pair_idx, return_counts=True)
        p_lbls = nz // M
        c_lbls = nz % M

        # Greedy best-overlap assignment; the stable sort makes ties
        # resolve deterministically in favour of the smaller pair index.
        order = np.argsort(-overlaps, kind="stable")
        mapping = {}
        used_curr = set()
        for p_lbl, c_lbl in zip(p_lbls[order].tolist(), c_lbls[order].tolist()):
            if p_lbl == 0 or c_lbl == 0:
                continue
            if p_lbl not in mapping and c_lbl not in used_curr:
                mapping[p_lbl] = c_lbl
                used_curr.add(c_lbl)

        # Lookup table: current 2D label -> tracked object ID (0 stays 0).
        lookup = np.zeros(M, dtype=int)
        # 1) continuing objects keep the ID of their predecessor
        for p_lbl, c_lbl in mapping.items():
            lookup[c_lbl] = p_lbl
        # 2) brand-new objects get fresh IDs, in ascending label order
        for c_lbl in np.unique(curr).tolist():
            if c_lbl == 0 or c_lbl in used_curr:
                continue
            lookup[c_lbl] = next_id
            next_id += 1

        objects_watershed[t] = lookup[curr]

    # Final clean-up is delegated to the project helper, which receives the
    # minimum duration and the time step.
    objects_watershed, _ = clean_up_objects(objects_watershed,
                                            min_tsteps=min_tsteps,
                                            dT=dT)
    return objects_watershed
def _get_all_centers_by_time(
labeled_data: np.ndarray
) -> Tuple[DefaultDict[int, Dict[int, Tuple[float, float]]],
DefaultDict[int, List[int]],
Set[int]]:
"""
Calculates the 2D center for every label at every time slice it appears.
Parameters
----------
labeled_data : np.ndarray
3D array of labeled data, shape (T, H, W).
Returns
-------
Tuple :
- centers_by_label : DefaultDict[int, Dict[int, Tuple[float, float]]]
Mapping of label -> time -> (y_center, x_center).
- labels_by_time : DefaultDict[int, List[int]]
Mapping of time -> list of labels present at that time.
"""
print("Pre-calculating all label 2D centers at each time slice...")
centers_by_label: DefaultDict[int, Dict[int, Tuple[float, float]]] = defaultdict(dict)
labels_by_time: DefaultDict[int, List[int]] = defaultdict(list)
num_times = labeled_data.shape[0]
# Iterate over time slices to compute centers via center of mass
for t in range(num_times):
label_slice = labeled_data[t, :, :]
labels_in_slice = np.unique(label_slice)
labels_in_slice = labels_in_slice[labels_in_slice != 0]
if labels_in_slice.size == 0:
continue
centers = center_of_mass(label_slice, labels=label_slice, index=labels_in_slice)
if labels_in_slice.size == 1:
centers = [centers] # Handle single label case
for label, center in zip(labels_in_slice, centers):
centers_by_label[label][t] = center
labels_by_time[t].append(label)
return centers_by_label, labels_by_time
def _find_nearest_neighbor(
center: np.ndarray,
time: int,
labels_by_time: List[int],
centers_by_label: DefaultDict[int, Dict[int, Tuple[float, float]]]
) -> Tuple[int, float]:
"""
Find the nearest neighbor label at a given time slice to the provided center.
Parameters
----------
center : np.ndarray
2D center point (y, x).
time : int
Time slice to search for neighbors.
labels_by_time : List[int]
List of labels present at the given time slice.
centers_by_label : DefaultDict[int, Dict[int, Tuple[float, float]]]
Precomputed centers for each label at each time slice.
Returns
-------
Tuple : [int, float]
A tuple of (nearest_label, distance). If no labels exist at that time, returns (None, inf).
"""
nearest_label = -1
min_distance = float('inf')
if not labels_by_time:
return None, min_distance
# Calculate distances to all labels at the given time and find the nearest
for label in labels_by_time:
actual_center = np.array(centers_by_label[label][time])
dist = np.linalg.norm(center - actual_center)
if dist < min_distance:
min_distance = dist
nearest_label = label
return nearest_label, min_distance
from collections import defaultdict
import os
import pickle
import json
import numpy as np
[docs]
def build_object_history_dict(labels, centers, events, uf, histories, object_type=None, out_dir="outputs", save=True):
    """
    Assemble one history record (lifetime and interactions) per object label.

    Parameters
    ----------
    labels : array-like of int
        Object labels to build records for.
    centers : dict
        centers[label][t] -> center of the label at time t. Only the time
        keys are used, to derive first and last appearance.
    events : list of dict
        Merge/split events. Each dict is read for the keys 'type', 'time',
        'from_label', 'to_label' and 'distance'.
    uf : UnionFind
        Object providing ``find(label)``, used to look up each label's root.
    histories : dict[int, set[int]]
        Root -> labels connected to that root.
    object_type : str, optional
        If given, appended to the output file names.
    out_dir : str, optional
        Directory the files are written to when ``save`` is true.
    save : bool, optional
        Write the records as a pickle file and as a JSON file.

    Returns
    -------
    object_data : dict[int, dict]
        object_data[label] is the record for that object label.
    """
    label_ids = [int(l) for l in labels]

    # First / last time step of every label, taken from the center table.
    lifetimes = {}
    for lab in label_ids:
        seen_at = sorted(centers.get(lab, {}).keys())
        if seen_at:
            lifetimes[lab] = (int(seen_at[0]), int(seen_at[-1]))
        else:
            lifetimes[lab] = (None, None)

    # Every event is filed under both participants, once per role.
    per_label_events = defaultdict(list)
    for ev in events:
        when = int(ev["time"])
        kind = ev["type"]
        source = int(ev["from_label"])
        target = int(ev["to_label"])
        gap = float(ev["distance"])
        per_label_events[source].append(
            {"time": when, "type": kind, "role": "from", "other_label": target, "distance": gap}
        )
        per_label_events[target].append(
            {"time": when, "type": kind, "role": "to", "other_label": source, "distance": gap}
        )

    # Union-find group membership.
    root_of = {}
    for lab in label_ids:
        root_of[lab] = int(uf.find(lab))
    members_of = {}
    for group_root, group in histories.items():
        members_of[int(group_root)] = sorted(int(x) for x in group)

    # One record per label, in input order.
    object_data = {}
    for lab in label_ids:
        first_t, last_t = lifetimes[lab]
        timeline = sorted(per_label_events.get(lab, []), key=lambda d: d["time"])
        if first_t is None or last_t is None:
            span = None
        else:
            span = last_t - first_t + 1
        others = sorted({entry["other_label"] for entry in timeline})
        group_root = root_of[lab]
        object_data[lab] = {
            "label": lab,
            "lifetime": {"t_start": first_t, "t_end": last_t, "duration": span},
            "root": group_root,
            "group_labels": members_of.get(group_root, [lab]),
            "n_interactions": len(timeline),
            "unique_partners": others,
            "n_unique_partners": len(others),
            "interactions": timeline,
        }

    if not save:
        return object_data

    os.makedirs(out_dir, exist_ok=True)

    # Pickle keeps the records exactly as built (integer keys included).
    pkl_name = f"object_history_{object_type}.pkl" if object_type else "object_history.pkl"
    with open(os.path.join(out_dir, pkl_name), "wb") as f:
        pickle.dump(object_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    def _jsonify(x):
        # Convert numpy scalars / arrays to plain Python equivalents.
        if isinstance(x, (np.integer,)):
            return int(x)
        if isinstance(x, (np.floating,)):
            return float(x)
        if isinstance(x, np.ndarray):
            return x.tolist()
        return x

    # Human-readable copy; JSON object keys become strings.
    json_name = f"object_history_{object_type}.json" if object_type else "object_history.json"
    serialisable = {}
    for lab, record in object_data.items():
        serialisable[str(lab)] = {field: _jsonify(value) for field, value in record.items()}
    with open(os.path.join(out_dir, json_name), "w") as f:
        json.dump(serialisable, f, indent=2)

    return object_data
[docs]
def analyze_watershed_history(watershed_results, min_dist, object_type: str, histplot: bool = False):
    """
    Analyze the history of watershed objects over time.

    For every label the center at its first and last time step is linearly
    extrapolated one step backwards / forwards (Euler step using the
    label's own displacement). If another label's center at that
    neighbouring time step lies closer than ``min_dist`` to the predicted
    position, the two labels are joined in a union-find structure and a
    'split' (at genesis) or 'merge' (at termination) event is recorded.
    Labels present for a single time step only are skipped.

    The per-object history records are always written to disk through
    ``build_object_history_dict`` (``save=True``). Optionally a plot of the
    lifetimes and events of at most 40 event-involved labels is saved to
    ``outputs/watershed_history_<object_type>.pdf``.

    Parameters
    ----------
    watershed_results : np.ndarray
        3D array of watershed labels over time, shape (T, H, W).
    min_dist : float
        Distance threshold below which a predicted center and a neighbour's
        center are considered related (for merges/splits).
    object_type : str
        Type of object being analyzed (e.g., "mcs", "cloud"); used in
        output file names.
    histplot : bool
        Switch to turn on plotting of object history.

    Returns
    -------
    union_array : Dict[int, int]
        The union-find parent mapping (label -> parent label).
    events : List[Dict[str, Any]]
        List of merge and split events with details.
    histories : Dict[int, Set[int]]
        Dictionary mapping root labels to sets of all connected labels.
    history_data : Dict[int, dict]
        Per-label records as returned by ``build_object_history_dict``.
    """
    T = watershed_results.shape[0]
    labels = np.unique(watershed_results)
    labels = labels[labels != 0]
    centers, labels_t = _get_all_centers_by_time(watershed_results)

    # Every label starts out as its own group.
    uf = UnionFind()
    for label in labels:
        uf.add(label)

    events: List[Dict[str, Any]] = []
    for label in labels:
        track = centers[label]
        times_present = sorted(track.keys())
        if not times_present:
            continue
        t_start = times_present[0]
        t_end = times_present[-1]
        if t_end - t_start < 1:
            print("Skipping label", label, "with insufficient time span")
            continue

        # --- split genesis: did this label emerge next to an older one? ---
        if t_start > 0:
            center_start = np.array(track[t_start])
            # If the label is absent at t_start + 1 (gap in its track),
            # fall back to zero displacement.
            center_next = np.array(track.get(t_start + 1, track[t_start]))
            # previous center prediction, c_-1 = c_0 - v * dt, v = (c_1 - c_0) / dt
            # hence, c_-1 = 2 * c_0 - c_1
            pred_center = 2 * center_start - center_next
            nearest_label, dist = _find_nearest_neighbor(
                pred_center,
                t_start - 1,
                labels_t[t_start - 1],
                centers
            )
            # If a nearby label is found within min_dist, consider it a split
            if nearest_label is not None and dist < min_dist:
                uf.union(label, nearest_label)
                events.append({
                    'type': 'split',
                    'time': t_start,
                    'from_label': nearest_label,
                    'to_label': label,
                    'distance': dist
                })

        # --- merge termination: did this label vanish into another one? ---
        if t_end < T - 1:
            center_end = np.array(track[t_end])
            # Same zero-displacement fallback for a gap at t_end - 1.
            center_prev = np.array(track.get(t_end - 1, track[t_end]))
            # next center prediction, c_+1 = c_0 + v * dt, v = (c_0 - c_-1) / dt
            # hence, c_+1 = 2 * c_0 - c_-1
            pred_center = 2 * center_end - center_prev
            nearest_label, dist = _find_nearest_neighbor(
                pred_center,
                t_end + 1,
                labels_t[t_end + 1],
                centers
            )
            # If a nearby label is found within min_dist, consider it a merge
            if nearest_label is not None and dist < min_dist:
                uf.union(label, nearest_label)
                events.append({
                    'type': 'merge',
                    'time': t_end,
                    'from_label': label,
                    'to_label': nearest_label,
                    'distance': dist
                })

    # Group labels by their union-find root.
    histories: Dict[int, Set[int]] = defaultdict(set)
    for label in labels:
        histories[uf.find(label)].add(label)
    union_array = uf.parent

    # Per-object records; also written to the default output directory.
    history_data = build_object_history_dict(labels, centers, events, uf, histories, object_type=object_type, save=True)

    if histplot:
        _plot_watershed_history(events, histories, centers, object_type)

    return union_array, events, histories, history_data


def _plot_watershed_history(events, histories, centers, object_type, max_labels=40):
    """Save a lifetime/event chart of the labels involved in merges or splits."""
    # Lifetimes of all labels that appear in any history group.
    all_labels = set()
    for members in histories.values():
        all_labels.update(members)
    label_times = {}
    for label in all_labels:
        if label in centers:
            times = sorted(centers[label].keys())
            label_times[label] = (min(times), max(times))

    # Keep only labels involved in events (merges or splits).
    event_labels = set()
    for event in events:
        event_labels.add(event['from_label'])
        event_labels.add(event['to_label'])
    filtered_labels = [label for label in all_labels if label in event_labels]
    filtered_set = set(filtered_labels)  # O(1) membership tests below
    filtered_label_times = {label: label_times[label] for label in filtered_labels if label in label_times}

    # Group the filtered labels by their history root.
    label_to_root = {}
    for root, members in histories.items():
        for label in members:
            if label in filtered_set:
                label_to_root[label] = root
    root_groups = {}
    for label in filtered_labels:
        root_groups.setdefault(label_to_root[label], []).append(label)

    # Count events per label
    event_count = defaultdict(int)
    for event in events:
        event_count[event['from_label']] += 1
        event_count[event['to_label']] += 1

    # Groups are ordered by their smallest label; inside a group the label
    # with the most events is placed in the middle and the others alternate
    # outwards, which keeps event connectors short.
    ordered_labels = []
    for root in sorted(root_groups.keys(), key=lambda r: min(root_groups[r])):
        by_activity = sorted(root_groups[root], key=lambda l: event_count[l], reverse=True)
        left = deque()
        right = deque()
        for label in by_activity:
            if len(right) <= len(left):
                right.append(label)
            else:
                left.appendleft(label)
        ordered_labels.extend(list(left) + list(right))

    fig, ax = plt.subplots(figsize=(12, 8))
    # Limit the number of rows to keep the plot readable
    if len(ordered_labels) > max_labels:
        ordered_labels = ordered_labels[:max_labels]
    y_positions = {label: i for i, label in enumerate(ordered_labels)}
    ax.set_yticks(list(y_positions.values()))
    ax.set_yticklabels(list(y_positions.keys()), fontsize=12)
    ax.set_xlabel('Time Step', fontsize=14)
    ax.set_title('Watershed Object History: Merges and Splits (Filtered to Event-Involved Labels)', fontsize=16)
    ax.tick_params(axis='x', labelsize=14)  # increase x-axis tick fontsize

    # Label lifetimes as horizontal lines
    for label, (t_start, t_end) in filtered_label_times.items():
        if label in y_positions:
            y = y_positions[label]
            ax.plot([t_start, t_end], [y, y], 'b-', linewidth=2)

    # Events as vertical dashed connectors; drawn only if both partners
    # made it into the plotted subset.
    for event in events:
        t = event['time']
        from_label = event['from_label']
        to_label = event['to_label']
        if from_label in y_positions and to_label in y_positions:
            y_from = y_positions[from_label]
            y_to = y_positions[to_label]
            color = 'red' if event['type'] == 'merge' else 'green'
            ax.plot([t, t], [y_from, y_to], color=color, linestyle='--', linewidth=1)
            ax.scatter([t, t], [y_from, y_to], color=color, s=50)

    # Legend with explicit colour patches
    lifetime_patch = mpatches.Patch(color='blue', label='Lifetime')
    merge_patch = mpatches.Patch(color='red', label='Merge')
    split_patch = mpatches.Patch(color='green', label='Split')
    ax.legend(handles=[lifetime_patch, merge_patch, split_patch], loc='upper left', fontsize=14)

    # save the plot in a pdf
    plt.tight_layout()
    os.makedirs('outputs', exist_ok=True)
    fig.savefig('outputs/watershed_history_' + object_type + '.pdf')
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
#from memory_profiler import profile
# @profile_
[docs]
def watershed_2d_overlap(data, # 3D matrix with data for watershedding [np.array]
                         object_threshold, # float to create binary object mast [float]
                         max_treshold, # value for identifying max. points for spreading [float]
                         min_dist, # minimum distance (in grid cells) between maximum points [int]
                         dT, # time interval in hours [int]
                         mintime = 24, # minimum time an object has to exist in dT [int]
                         connectLon = 0, # do we have to track features over the date line?
                         extend_size_ratio = 0.25, # if connectLon = 1 this key is setting the ratio of the zonal domain added to the watershedding. This has to be big for large objects (e.g., ARs) and can be smaller for e.g., MCSs
                         erosion_disk = 3.5):
    """
    This function performs watershedding on 2D anomaly fields over time and connects
    the resulting 2D features into 3D objects based on maximum overlap.

    This function uses spatially reduced watersheds from the previous time step as seed for the
    current time step, which improves temporal consistency of features.

    Parameters
    ----------
    data : np.ndarray
        3D array of data for watershedding [time, lat, lon].
    object_threshold : float
        Threshold to create binary object mask.
    max_treshold : float
        Threshold for identifying maximum points for spreading.
    min_dist : int
        Minimum distance (in grid cells) between maximum points.
    dT : int
        Time interval in hours.
    mintime : int, optional
        Minimum time an object has to exist in dT. Default is 24.
    connectLon : int, optional
        Whether to track features over the date line (1 for yes, 0 for no). Default is 0.
    extend_size_ratio : float, optional
        If connectLon = 1, this sets the ratio of the zonal domain added to the watershedding.
        This has to be big for large objects (e.g., ARs) and can be smaller for e.g., MCSs. Default is 0.25.
    erosion_disk : float, optional
        Disk size for erosion of previous timestep mask to improve temporal connection of features. Default is 3.5.

    Returns
    -------
    np.ndarray
        3D array of object IDs as returned by ``connect_3d_objects``.
    """
    extension_size = 0
    if connectLon == 1:
        # The padding is added along the longitude axis (axis 2), so its
        # width must be derived from the longitude dimension as well.
        extension_size = int(data.shape[2] * extend_size_ratio)
        # A zero-width slice [-0:] would select the whole array and
        # duplicate the domain, so only pad for a positive width.
        if extension_size > 0:
            data = np.concatenate(
                [data[:, :, -extension_size:], data, data[:, :, :extension_size]], axis=2
            )

    # Per-time-step labels, NaN until a slice has been processed. A float
    # dtype is required to hold NaN, also when the input is integer-typed.
    label_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    data_2d_watershed = np.full(data.shape, np.nan, dtype=label_dtype)

    for tt in tqdm(range(data.shape[0])):
        data_t0 = data[tt, :, :]
        image = data_t0 >= object_threshold

        # seed points: local maxima above max_treshold inside the object mask
        coords = peak_local_max(data_t0,
                                min_distance = min_dist,
                                threshold_abs = max_treshold,
                                labels = image
                               )
        mask = np.zeros(data_t0.shape, dtype=bool)
        mask[tuple(coords.T)] = True
        markers, _ = ndi.label(mask)

        if tt != 0:
            # allow markers to change a bit from time to time and
            # introduce new markers if they have strong enough max/min and
            # are far enough away from existing objects
            prev_labels = data_2d_watershed[tt-1, :, :].astype("int")
            boundaries = find_boundaries(prev_labels, mode='outer')
            # Set boundaries to zero in the markers
            separated_markers = np.copy(prev_labels)
            separated_markers[boundaries] = 0
            separated_markers = erosion(separated_markers, disk(erosion_disk)) #3.5
            # NOTE(review): slice tt has not been filled yet at this point
            # (it still holds NaN), so this comparison selects nothing.
            # Presumably the intent was to drop seeds outside the current
            # object mask - confirm before changing.
            separated_markers[data_2d_watershed[tt, :, :] == 0] = 0
            # add unique new markers if they are not too close to old objects
            dilated_matrix = dilation(prev_labels, disk(2.5))
            markers_updated = (markers + np.max(separated_markers)).astype("int")
            markers_updated[markers_updated == np.max(separated_markers)] = 0
            markers_add = (markers_updated != 0) & (dilated_matrix == 0)
            separated_markers[markers_add] = markers_updated[markers_add]
            markers = separated_markers
            # break up elements that are no longer connected
            markers, _ = ndi.label(markers)

        data_2d_watershed[tt, :, :] = watershed(image = np.array(data_t0)*-1, # watershedding field with maxima transformed to minima
                                                markers = markers, # seed labels
                                                connectivity = np.ones((3, 3)), # connectivity
                                                offset = (np.ones((2)) * 1).astype('int'), #4000/dx_m[dx]).astype('int'),
                                                mask = image, # binary mask for areas to watershed on
                                                compactness = 0) # high values --> more regular shaped watersheds

    if connectLon == 1:
        # Crop back to the original zonal extent, then let the project
        # helper reconcile labels across the date line.
        if extension_size > 0:
            data_2d_watershed = np.array(data_2d_watershed[:, :, extension_size:-extension_size])
        data_2d_watershed = ConnectLon_on_timestep(data_2d_watershed.astype("int"))

    ### CONNECT OBJECTS IN 3D BASED ON MAX OVERLAP
    labels = np.array(data_2d_watershed).astype('int')
    objects = connect_3d_objects(labels,
                                 int(mintime/dT),
                                 dT)
    return objects
# from memory_profiler import profile
# # @profile__sections
# @profile_
[docs]
def watershed_3d_overlap(
    data: np.ndarray,
    object_threshold: float,
    max_treshold: float,
    min_dist: int,
    dT: int,
    mintime: int = 24,
    connectLon: int = 0,
    extend_size_ratio: float = 0.25
) -> np.ndarray:
    """
    Perform 3D watershedding on the input data with temporal consistency.

    Local maxima are detected per time slice, linked over time by
    ``label_peaks_over_time_3d`` and then used as seeds of a single
    watershed over the full (time, lat, lon) volume.

    Parameters
    ----------
    data : np.ndarray
        3D matrix with data for watershedding
    object_threshold : float
        Threshold to create the binary object mask. An array that
        broadcasts against ``data`` is accepted as well; if it has two or
        more dimensions and spans the full longitude axis it is padded
        together with ``data`` when ``connectLon`` is 1.
    max_treshold : float
        Value for identifying max. points for spreading
    min_dist : int
        Minimum distance (in grid cells) between maximum points
    dT : int
        Time interval in hours. Not used by this function.
    mintime : int, optional
        Minimum time an object has to exist in dT, by default 24.
        Not used by this function.
    connectLon : int, optional
        Do we have to track features over the date line?, by default 0
    extend_size_ratio : float, optional
        If connectLon = 1 this key is setting the ratio of the zonal domain added to the watershedding.
        This has to be big for large objects (e.g., ARs) and can be smaller for e.g., MCSs, by default 0.25

    Returns
    -------
    np.ndarray
        3D matrix with watershed labels
    """
    extension_size = 0
    if connectLon == 1:
        n_lon = data.shape[2]
        extension_size = int(n_lon * extend_size_ratio)
        # A zero-width slice [-0:] would select the whole array and
        # duplicate the domain, so only pad for a positive width.
        if extension_size > 0:
            data = np.concatenate(
                [data[:, :, -extension_size:], data, data[:, :, :extension_size]], axis=2
            )
            # Pad a spatially varying threshold the same way. The ellipsis
            # keeps this valid for (lat, lon) and (time, lat, lon) fields;
            # thresholds with a broadcastable (size-1) longitude axis need
            # no padding.
            if np.ndim(object_threshold) >= 2 and np.shape(object_threshold)[-1] == n_lon:
                object_threshold = np.concatenate(
                    [object_threshold[..., -extension_size:],
                     object_threshold,
                     object_threshold[..., :extension_size]],
                    axis=-1
                )

    # Create a binary mask for watershedding, all data that needs to be segmented is True
    image = data >= object_threshold

    # find peaks in each time slice and add time as an additional coordinate
    coords_list = []
    for t in range(data.shape[0]):
        coords_t = peak_local_max(data[t],
                                  min_distance = min_dist,
                                  threshold_abs = max_treshold,
                                  labels = image[t],
                                  exclude_border=True
                                 )
        coords_with_time = np.column_stack((np.full(coords_t.shape[0], t), coords_t))
        coords_list.append(coords_with_time)

    # Combine all coordinates into a single (n_peaks, 3) array
    if len(coords_list) > 0:
        coords = np.vstack(coords_list)
    else:
        coords = np.empty((0, 3), dtype=int)

    # label peaks over time to ensure temporal consistency
    labels = label_peaks_over_time_3d(coords, max_dist=min_dist)
    markers = np.zeros(data.shape, dtype=int)
    markers[tuple(coords.T)] = labels

    # define connectivity for 3D watershedding and perform watershedding
    conection = np.ones((3, 3, 3))
    watershed_results = watershed(image = np.array(data)*-1, # watershedding field with maxima transformed to minima
                                  markers = markers, # maximum points in 3D matrix
                                  connectivity = conection, # connectivity
                                  offset = (np.ones((3)) * 1).astype('int'), #4000/dx_m[dx]).astype('int'),
                                  mask = image, # binary mask for areas to watershed on
                                  compactness = 0) # high values --> more regular shaped watersheds

    # correct objects on date line if needed
    if connectLon == 1:
        if extension_size > 0:
            watershed_results = np.array(watershed_results[:, :, extension_size:-extension_size])
        watershed_results = ConnectLon_on_timestep(watershed_results.astype("int"))

    return watershed_results
[docs]
def watershed_3d_overlap_parallel(
    data,
    object_threshold,
    max_treshold,
    min_dist,
    dT,
    mintime=24,
    connectLon=0,
    extend_size_ratio=0.25,
    n_chunks_lat=1,#None,
    n_chunks_lon=1,#None,
    overlap_cells=None,
    mp_method='auto'
):
    """
    Parallel version of watershed_3d_overlap using domain decomposition.

    The (optionally zonally padded) domain is split into overlapping
    lat/lon chunks. Input, output and the chunk halos live in shared
    memory; one worker call per chunk writes its result there, after which
    the chunks are merged and the labels renumbered consecutively.

    Parameters
    ----------
    data : np.ndarray
        3D matrix with data for watershedding. Converted to float32.
    object_threshold : float
        Float to create binary object mask. It is handed to every worker
        unchanged and compared against that worker's chunk; unlike in the
        serial version it is not padded along longitude here.
    max_treshold : float
        Value for identifying max. points for spreading
    min_dist : int
        Minimum distance (in grid cells) between maximum points
    dT : int
        Time interval in hours
    mintime : int, optional
        Minimum time an object has to exist in dT, by default 24
    connectLon : int, optional
        Do we have to track features over the date line?, by default 0
    extend_size_ratio : float, optional
        If connectLon = 1 this key is setting the ratio of the zonal domain added to the watershedding.
        This has to be big for large objects (e.g., ARs) and can be smaller for e.g., MCSs, by default 0.25
    n_chunks_lat : int or None, default=1
        Number of chunks to split latitude dimension. If both chunk counts
        are None they are derived from the CPU count (at most 12 processes)
        and the domain's aspect ratio. With both equal to 1 the serial
        ``watershed_3d_overlap`` is called instead.
    n_chunks_lon : int or None, default=1
        Number of chunks to split longitude dimension
    overlap_cells : int, optional
        Number of overlapping cells between chunks. If None, uses min_dist * 4
    mp_method : str, optional
        Multiprocessing method: 'fork', 'spawn', or 'auto' (default).
        'auto' picks 'fork' below 600 million grid cells and 'spawn' above.
        Any value other than 'spawn' selects 'fork'; if the platform does
        not offer 'fork', 'spawn' is used instead.

    Returns
    -------
    np.ndarray
        3D matrix with watershed labels
    """
    data = np.asarray(data, dtype=np.float32)

    if n_chunks_lat == 1 and n_chunks_lon == 1:
        print("Only one chunk specified, running serial version.")
        return watershed_3d_overlap(
            data,
            object_threshold,
            max_treshold,
            min_dist,
            dT,
            mintime,
            connectLon,
            extend_size_ratio
        )

    if n_chunks_lat is None and n_chunks_lon is None:
        # Leave one core for system processes, but never drop to zero
        # (single-core hosts would otherwise divide by zero below).
        num_proc = max(1, mp.cpu_count() - 1)
        print(f"Auto-detecting number of processes: {num_proc}")
        num_proc = min(12, num_proc)  # limit to 12 processes max to avoid oversubscription
        lat = data.shape[1]
        lon = data.shape[2]
        print(f"Shape of the data to watershed: {data.shape}")
        r = lon / lat
        # Distribute the processes so that chunks are roughly square.
        # Clamping keeps both counts >= 1 and their product <= num_proc
        # even for extreme aspect ratios.
        n_chunks_lon = min(num_proc, max(1, int(np.floor(np.sqrt(num_proc * r)))))
        n_chunks_lat = max(1, num_proc // n_chunks_lon)
        print(f"Auto-configured to {n_chunks_lat} latitude chunks and {n_chunks_lon} longitude chunks for parallel processing.")

    # Set default overlap
    if overlap_cells is None:
        overlap_cells = min_dist * 4

    # Handle dateline extension
    extension_size = 0
    if connectLon == 1:
        extension_size = int(data.shape[2] * extend_size_ratio)
        # A zero-width slice [-0:] would select the whole array and
        # duplicate the domain, so only pad for a positive width.
        if extension_size > 0:
            data = np.concatenate(
                [data[:, :, -extension_size:], data, data[:, :, :extension_size]], axis=2
            )

    nt, nlat, nlon = data.shape
    out_dtype = np.int32

    # --- PRE-CALCULATE HALO BUFFER LAYOUT ---
    # Every chunk stores its "upper" latitude and longitude halo (the part
    # of the chunk beyond the end of its core) in one flat shared buffer;
    # offsets into that buffer are fixed here.
    lat_chunks = _calculate_chunk_boundaries(nlat, n_chunks_lat, overlap_cells)
    lon_chunks = _calculate_chunk_boundaries(nlon, n_chunks_lon, overlap_cells)
    halo_metadata = []  # Stores size and offset info for each chunk
    total_halo_elements = 0
    for i, (lat_s, lat_e, lat_cs, lat_ce) in enumerate(lat_chunks):
        for j, (lon_s, lon_e, lon_cs, lon_ce) in enumerate(lon_chunks):
            # Lat halo: rows between core end and chunk end, restricted to
            # the core's longitude range.
            h_lat_h = lat_e - lat_ce
            h_lat_w = lon_ce - lon_cs  # Core width only
            size_lat = nt * h_lat_h * h_lat_w
            # Lon halo: columns between core end and chunk end, restricted
            # to the core's latitude range.
            h_lon_h = lat_ce - lat_cs
            h_lon_w = lon_e - lon_ce
            size_lon = nt * h_lon_h * h_lon_w
            halo_metadata.append({
                'chunk_i': i, 'chunk_j': j,
                'lat_bounds': (lat_s, lat_e, lat_cs, lat_ce),
                'lon_bounds': (lon_s, lon_e, lon_cs, lon_ce),
                'lat_halo_shape': (nt, h_lat_h, h_lat_w),
                'lon_halo_shape': (nt, h_lon_h, h_lon_w),
                'lat_halo_offset': total_halo_elements,
                'lon_halo_offset': total_halo_elements + size_lat
            })
            total_halo_elements += (size_lat + size_lon)
    out_itemsize = np.dtype(out_dtype).itemsize
    halo_bytes = total_halo_elements * out_itemsize

    # --- DECISION LOGIC (The "Smart Switch") ---
    if mp_method == 'auto':
        # Example threshold (to be calibrated): below 600M cells -> fork,
        # above -> spawn.
        FORK_LIMIT = 600000000
        mp_method = 'fork' if data.size < FORK_LIMIT else 'spawn'
    if mp_method != 'spawn' and 'fork' not in mp.get_all_start_methods():
        # 'fork' does not exist on every platform.
        mp_method = 'spawn'

    # All segments are created inside the try block so that a failure at
    # any point still releases the ones already allocated.
    shm_blocks = []
    try:
        # --- SETUP SHARED MEMORY FOR INPUT & MAIN OUTPUT ---
        shm_input = shared_memory.SharedMemory(create=True, size=data.nbytes)
        shm_blocks.append(shm_input)
        shared_input_arr = np.ndarray(data.shape, dtype=data.dtype, buffer=shm_input.buf)
        shared_input_arr[:] = data[:]

        out_size = int(np.prod(data.shape) * out_itemsize)
        shm_output = shared_memory.SharedMemory(create=True, size=out_size)
        shm_blocks.append(shm_output)
        shared_output_arr = np.ndarray(data.shape, dtype=out_dtype, buffer=shm_output.buf)
        shared_output_arr.fill(0)

        # --- SETUP SHARED MEMORY FOR HALOS ---
        # A segment must have a positive size, also when no halo exists.
        shm_halos = shared_memory.SharedMemory(create=True, size=max(halo_bytes, out_itemsize))
        shm_blocks.append(shm_halos)
        # No single ndarray is created here: the buffer is a flat
        # concatenation of many differently shaped halo arrays.

        print(f" Processing {len(halo_metadata)} chunks with {halo_bytes / 1e9:.2f} GB halo buffer...")
        chunk_args = [
            (
                meta,
                shm_input.name,
                shm_output.name,
                shm_halos.name,
                data.shape,
                data.dtype,
                out_dtype,
                object_threshold,
                max_treshold,
                min_dist
            )
            for meta in halo_metadata
        ]

        # --- RUN PARALLEL ---
        # WARNING: 'fork' can deadlock with C-libs, but is faster for medium data
        ctx = mp.get_context('spawn' if mp_method == 'spawn' else 'fork')
        # No point in starting more workers than there are chunks.
        n_workers = max(1, min(len(chunk_args), mp.cpu_count()))
        with ctx.Pool(processes=n_workers) as pool:
            worker_results = pool.starmap(_process_watershed_chunk_no_return, chunk_args)

        print(" Merging chunk results...")
        # The merge step needs both what the worker reported and where the
        # chunk's halos live in shared memory.
        full_results = [
            {**w_res, **h_meta}
            for w_res, h_meta in zip(worker_results, halo_metadata)
        ]
        _merge_watershed_chunks(
            full_results,
            shared_output_arr,
            shm_halos,
            lat_chunks,
            lon_chunks
        )
        # Copy out of shared memory before the segments are released.
        final_result = _relabel_consecutive(shared_output_arr.copy())
    finally:
        # CLEANUP
        for shm in shm_blocks:
            try:
                shm.close()
            finally:
                shm.unlink()

    if connectLon == 1:
        if extension_size > 0:
            final_result = final_result[:, :, extension_size:-extension_size]
        final_result = ConnectLon_on_timestep(final_result.astype("int"))

    return final_result
def _relabel_consecutive(labeled_array):
"""
Relabel array to have consecutive integer labels starting from 1.
Parameters
----------
labeled_array : np.ndarray
3D array of labeled data with not necessarily consecutive integers.
Returns
-------
np.ndarray
Relabeled array with consecutive integers.
"""
# Get unique non-zero labels
unique_labels = np.unique(labeled_array[labeled_array > 0])
if len(unique_labels) == 0:
return labeled_array
# Create a lookup array: old_label -> new_label
# The maximum old label determines the size we need
max_label = unique_labels[-1] # unique_labels is sorted
lookup = np.zeros(max_label + 1, dtype=labeled_array.dtype)
lookup[unique_labels] = np.arange(1, len(unique_labels) + 1, dtype=labeled_array.dtype)
# Apply the mapping using fancy indexing
# This is MUCH faster than looping
result = lookup[labeled_array]
return result
def _process_watershed_chunk_no_return(
    meta,
    shm_input_name,
    shm_output_name,
    shm_halos_name,
    shape,
    dtype_in,
    dtype_out,
    object_threshold,
    max_treshold,
    min_dist
):
    """
    Run the watershed on one chunk and store the result in shared memory.

    The chunk (core plus overlap) is read from the shared input array. Its
    core part is written to the shared output array, and the parts beyond
    the upper lat / lon core edge are written to the shared halo buffer.
    Only a small dictionary is returned.

    Parameters
    ----------
    meta : Dict
        Chunk metadata. Reads 'lat_bounds' and 'lon_bounds' (each
        ``(start, end, core_start, core_end)``), 'lat_halo_shape',
        'lon_halo_shape', 'lat_halo_offset' and 'lon_halo_offset'.
    shm_input_name : str
        Name of the shared memory for input data.
    shm_output_name : str
        Name of the shared memory for output data.
    shm_halos_name : str
        Name of the shared memory for halo data.
    shape : Tuple[int]
        Shape of the full data array (used for both input and output).
    dtype_in : np.dtype
        Data type of the input data.
    dtype_out : np.dtype
        Data type of the output and halo data.
    object_threshold : float
        Cells with values >= this threshold form the object mask.
    max_treshold : float
        Passed to ``peak_local_max`` as ``threshold_abs``.
    min_dist : int
        Passed to ``peak_local_max`` as ``min_distance`` and to
        ``label_peaks_over_time_3d`` as ``max_dist``.

    Returns
    -------
    Dict
        ``{'max_label': ...}`` with the largest label of the chunk's
        watershed result (0 if the result is empty).
    """
    # Handles are collected as they are opened so that a failure while
    # attaching to a later segment still closes the earlier ones.
    attached = []
    try:
        for shm_name in (shm_input_name, shm_output_name, shm_halos_name):
            attached.append(shared_memory.SharedMemory(name=shm_name))
        shm_in, shm_out, shm_halos = attached

        full_data_in = np.ndarray(shape, dtype=dtype_in, buffer=shm_in.buf)
        full_data_out = np.ndarray(shape, dtype=dtype_out, buffer=shm_out.buf)
        # Flat view on the halo buffer; individual halos are addressed by
        # the element offsets stored in ``meta``.
        flat_halos = np.ndarray(
            (shm_halos.size // np.dtype(dtype_out).itemsize,),
            dtype=dtype_out,
            buffer=shm_halos.buf
        )

        lat_s, lat_e, lat_cs, lat_ce = meta['lat_bounds']
        lon_s, lon_e, lon_cs, lon_ce = meta['lon_bounds']
        chunk_data = full_data_in[:, lat_s:lat_e, lon_s:lon_e]

        # --- Watershed on the chunk (core + overlap) ---
        image = chunk_data >= object_threshold

        # Collect the peaks of every timestep as [t, y, x] rows.
        coords_list = []
        for t in range(chunk_data.shape[0]):
            coords_t = peak_local_max(
                chunk_data[t],
                min_distance=min_dist,
                threshold_abs=max_treshold,
                labels=image[t],
                exclude_border=True
            )
            if coords_t.size > 0:
                time_column = np.full(coords_t.shape[0], t)
                coords_list.append(np.column_stack((time_column, coords_t)))

        if coords_list:
            coords = np.vstack(coords_list)
        else:
            coords = np.empty((0, 3), dtype=int)

        # Peaks that are linked over time share one marker label.
        labels = label_peaks_over_time_3d(coords, max_dist=min_dist)
        markers = np.zeros(chunk_data.shape, dtype=int)
        if coords.size > 0:
            markers[tuple(coords.T)] = labels

        watershed_result = watershed(
            image=chunk_data * -1,
            markers=markers,
            connectivity=np.ones((3, 3, 3)),
            offset=np.ones(3, dtype=int),
            mask=image,
            compactness=0
        )

        # Core bounds relative to the chunk's own origin.
        rel_lat_cs = lat_cs - lat_s
        rel_lat_ce = lat_ce - lat_s
        rel_lon_cs = lon_cs - lon_s
        rel_lon_ce = lon_ce - lon_s

        core_result = watershed_result[:, rel_lat_cs:rel_lat_ce, rel_lon_cs:rel_lon_ce]
        full_data_out[:, lat_cs:lat_ce, lon_cs:lon_ce] = core_result.astype(dtype_out)

        # Lat halo (beyond the upper core edge), cropped to the core width
        # so that it lines up with the neighbor's core columns.
        if meta['lat_halo_shape'][1] > 0:
            h_lat = watershed_result[:, rel_lat_ce:, rel_lon_cs:rel_lon_ce]
            start = meta['lat_halo_offset']
            end = start + h_lat.size
            flat_halos[start:end] = h_lat.ravel().astype(dtype_out)

        # Lon halo (beyond the upper core edge), cropped to the core height.
        if meta['lon_halo_shape'][2] > 0:
            h_lon = watershed_result[:, rel_lat_cs:rel_lat_ce, rel_lon_ce:]
            start = meta['lon_halo_offset']
            end = start + h_lon.size
            flat_halos[start:end] = h_lon.ravel().astype(dtype_out)

        # Return only tiny metadata
        return {
            'max_label': watershed_result.max() if watershed_result.size > 0 else 0
        }
    finally:
        for handle in attached:
            handle.close()
def _merge_watershed_chunks(chunk_results, merged_array, shm_halos, lat_chunks, lon_chunks):
    """
    Merge the per-chunk watershed labels of ``merged_array`` in place.

    Parameters
    ----------
    chunk_results : list of dict
        Per-chunk metadata with at least 'chunk_i', 'chunk_j' and
        'max_label'. The list is sorted in place by (chunk_i, chunk_j).
    merged_array : np.ndarray
        Array holding the chunk results; it is relabeled in place.
    shm_halos : multiprocessing.shared_memory.SharedMemory
        Shared memory block holding the halo data, interpreted with the
        dtype of ``merged_array``.
    lat_chunks, lon_chunks :
        Not used by this function.

    Returns
    -------
    np.ndarray
        ``merged_array`` (the same object).
    """
    # View the halo buffer as a flat array of label values.
    label_dtype = merged_array.dtype
    n_halo_items = shm_halos.size // np.dtype(label_dtype).itemsize
    flat_halos = np.ndarray((n_halo_items,), dtype=label_dtype, buffer=shm_halos.buf)

    chunk_results.sort(key=lambda res: (res['chunk_i'], res['chunk_j']))

    # Each chunk gets its own id range: its offset is the sum of the
    # 'max_label' values of all chunks sorted before it.
    chunk_offsets = {}
    running_total = 0
    for res in chunk_results:
        chunk_offsets[(res['chunk_i'], res['chunk_j'])] = running_total
        running_total += res['max_label']

    global_map = _build_merge_map_shm(
        merged_array,
        flat_halos,
        chunk_results,
        chunk_offsets,
        running_total
    )

    _apply_map_inplace(merged_array, chunk_results, chunk_offsets, global_map)
    return merged_array
def _build_merge_map_shm(merged_array, flat_halos, chunk_results, chunk_offsets, total_max_label, overlap_match_threshold=0.5):
    """
    Build a merge map for watershed labels across chunk boundaries.

    For every chunk that has a neighbor at (chunk_i + 1, chunk_j) or
    (chunk_i, chunk_j + 1), the chunk's stored halo is compared with the
    cells at the start of the neighbor's core in ``merged_array``. Values
    in both are treated as chunk-local ids and are shifted by the owning
    chunk's offset before being merged with union-find.

    Parameters
    ----------
    merged_array : np.ndarray
        3D array holding the core results of all chunks.
    flat_halos : np.ndarray
        Flat array containing all halo data from chunks.
    chunk_results : list of dict
        Metadata for each chunk: 'chunk_i', 'chunk_j', 'lat_bounds',
        'lon_bounds', and the halo shapes and offsets.
    chunk_offsets : dict
        Maps (chunk_i, chunk_j) to the chunk's offset in the global
        label space.
    total_max_label : int
        Size of the global label space (largest global id).
    overlap_match_threshold : float, optional
        A halo object is merged with a core object when the share of the
        halo object's cells (within the compared window) that lie on the
        core object is strictly greater than this value. Default 0.5.

    Returns
    -------
    np.ndarray
        int32 array of length ``total_max_label + 1``. Entry ``g`` is the
        merged, consecutively numbered label of global id ``g``; entry 0
        is 0.
    """
    parent = list(range(total_max_label + 1))

    def find(i):
        # Iterative find with full path compression.
        root = i
        while parent[root] != root:
            root = parent[root]
        while parent[i] != root:
            following = parent[i]
            parent[i] = root
            i = following
        return root

    def union(i, j):
        root_i = find(i)
        root_j = find(j)
        if root_i == root_j:
            return
        # The smaller id always becomes the representative.
        if root_i < root_j:
            parent[root_j] = root_i
        else:
            parent[root_i] = root_j

    grid_map = {(r['chunk_i'], r['chunk_j']): r for r in chunk_results}

    def check_overlap(halo_flat_slice, halo_shape, core_slice_raw, offset_halo, offset_core):
        # Reshape the flat halo slice back to 3D
        halo_data = halo_flat_slice.reshape(halo_shape)

        # Compare only the window that both arrays cover.
        d0 = min(halo_data.shape[0], core_slice_raw.shape[0])
        d1 = min(halo_data.shape[1], core_slice_raw.shape[1])
        d2 = min(halo_data.shape[2], core_slice_raw.shape[2])
        if d0 == 0 or d1 == 0 or d2 == 0:
            return

        h_cut = halo_data[:d0, :d1, :d2]
        c_cut = core_slice_raw[:d0, :d1, :d2]
        if (d0, d1, d2) != halo_data.shape:
            print(f"Warning: Clipping Halo overlap from {halo_data.shape} to {(d0, d1, d2)}")

        halo_mask = h_cut > 0
        if not np.any(halo_mask):
            return

        # Cell count of every halo object inside the window. Counting on
        # the local ids keeps this table as small as the chunk's own label
        # range instead of the whole global label space.
        halo_pixel_counts = np.bincount(h_cut[halo_mask])

        both = halo_mask & (c_cut > 0)
        if not np.any(both):
            return

        pairs = np.column_stack((h_cut[both], c_cut[both]))
        unique_pairs, counts = np.unique(pairs, axis=0, return_counts=True)

        # Offsets are added as Python ints so the global ids do not depend
        # on the dtype of the label arrays.
        halo_base = int(offset_halo)
        core_base = int(offset_core)
        for (h_local, c_local), n_shared in zip(unique_pairs, counts):
            # Ratio = (intersection area) / (halo object area in the window)
            ratio = n_shared / halo_pixel_counts[h_local]
            if ratio > overlap_match_threshold:
                union(int(h_local) + halo_base, int(c_local) + core_base)

    # --- Process Boundaries ---
    for res in chunk_results:
        i, j = res['chunk_i'], res['chunk_j']

        # Neighbor at (i + 1, j): lat halo vs. start of the neighbor's core
        if (i + 1, j) in grid_map:
            shape = res['lat_halo_shape']
            if shape[1] > 0:
                start = res['lat_halo_offset']
                end = start + np.prod(shape)
                halo_view = flat_halos[start:end]

                neighbor_res = grid_map[(i + 1, j)]
                neighbor_lat_start = neighbor_res['lat_bounds'][2]  # lat_core_start
                neighbor_lat_end = neighbor_res['lat_bounds'][3]    # lat_core_end

                # The halo height is shape[1], but the window cannot be
                # taller than the neighbor's core.
                max_h = min(shape[1], neighbor_lat_end - neighbor_lat_start)
                core_slice = merged_array[
                    :,
                    neighbor_lat_start:neighbor_lat_start + max_h,
                    res['lon_bounds'][2]:res['lon_bounds'][3]  # this chunk's core width
                ]
                check_overlap(
                    halo_view, shape, core_slice,
                    chunk_offsets[(i, j)], chunk_offsets[(i + 1, j)]
                )

        # Neighbor at (i, j + 1): lon halo vs. start of the neighbor's core
        if (i, j + 1) in grid_map:
            shape = res['lon_halo_shape']
            if shape[2] > 0:
                start = res['lon_halo_offset']
                end = start + np.prod(shape)
                halo_view = flat_halos[start:end]

                neighbor_res = grid_map[(i, j + 1)]
                neighbor_lon_start = neighbor_res['lon_bounds'][2]  # lon_core_start
                neighbor_lon_end = neighbor_res['lon_bounds'][3]    # lon_core_end

                max_w = min(shape[2], neighbor_lon_end - neighbor_lon_start)
                core_slice = merged_array[
                    :,
                    res['lat_bounds'][2]:res['lat_bounds'][3],  # this chunk's core height
                    neighbor_lon_start:neighbor_lon_start + max_w
                ]
                check_overlap(
                    halo_view, shape, core_slice,
                    chunk_offsets[(i, j)], chunk_offsets[(i, j + 1)]
                )

    # --- Build Final Consecutive Map ---
    n_ids = len(parent)
    final_mapping = np.fromiter((find(k) for k in range(n_ids)), dtype=np.int32, count=n_ids)
    final_mapping[0] = 0

    unique_roots = np.unique(final_mapping)
    if unique_roots[0] == 0:
        unique_roots = unique_roots[1:]

    # Renumber the surviving roots to 1..K.
    compress_lut = np.zeros(final_mapping.max() + 1, dtype=np.int32)
    compress_lut[unique_roots] = np.arange(1, len(unique_roots) + 1)
    return compress_lut[final_mapping]
def _calculate_chunk_boundaries(total_size, n_chunks, overlap):
    """
    Split one dimension into ``n_chunks`` cores and pad each with overlap.

    Parameters
    ----------
    total_size : int
        Length of the dimension to split.
    n_chunks : int
        Number of chunks to produce.
    overlap : int
        Number of cells added on both sides of each core, clipped to the
        range [0, total_size].

    Returns
    -------
    list of tuples
        One ``(start_with_overlap, end_with_overlap, core_start, core_end)``
        tuple per chunk. Cores have length ``total_size // n_chunks``; the
        last core extends to ``total_size``.
    """
    step = total_size // n_chunks
    last = n_chunks - 1

    def span(idx):
        # Core range of this chunk; the final chunk absorbs the remainder.
        inner_lo = idx * step
        inner_hi = total_size if idx >= last else inner_lo + step
        # Padded range, kept inside the dimension.
        outer_lo = max(0, inner_lo - overlap)
        outer_hi = min(total_size, inner_hi + overlap)
        return (outer_lo, outer_hi, inner_lo, inner_hi)

    return [span(idx) for idx in range(n_chunks)]
def _apply_map_inplace(merged_array, chunk_results, chunk_offsets, global_map):
    """
    Translate the labels of every chunk core to their global labels, in place.

    Parameters
    ----------
    merged_array : np.ndarray
        Array holding the chunk cores; it is overwritten core by core.
    chunk_results : list of dict
        Per-chunk metadata with 'chunk_i', 'chunk_j', 'max_label',
        'lat_bounds' and 'lon_bounds'
        (bounds are ``(start, end, core_start, core_end)``).
    chunk_offsets : dict
        Maps (chunk_i, chunk_j) to the chunk's offset into ``global_map``.
    global_map : np.ndarray
        Lookup from global id (offset + local label) to final label.

    Returns
    -------
    None
    """
    print(" Applying labels in-place...")
    for meta in chunk_results:
        base = chunk_offsets[(meta['chunk_i'], meta['chunk_j'])]
        n_local = meta['max_label']
        if n_local == 0:
            # Chunk without any object: nothing to translate.
            continue

        # Per-chunk table: index = local label, value = final label.
        # Entry 0 stays 0 so the background is preserved.
        lut = np.zeros(n_local + 1, dtype=np.int32)
        lut[1:] = global_map[base + 1:base + n_local + 1]

        lat_lo, lat_hi = meta['lat_bounds'][2:4]
        lon_lo, lon_hi = meta['lon_bounds'][2:4]

        # Basic slicing yields a view, so the assignment writes straight
        # into merged_array.
        core_view = merged_array[:, lat_lo:lat_hi, lon_lo:lon_hi]
        core_view[:] = lut[core_view]
# @profile_
[docs]
def label_peaks_over_time_3d(coords, max_dist=5):
    """
    Labels peaks in 3D coordinates over time based on spatial proximity.

    Every peak is linked to the nearest peak of the previously visited
    timestep when that peak is strictly closer than ``max_dist``; otherwise
    it starts a new label.

    Parameters
    ----------
    coords : np.ndarray
        Array of shape (N_peaks, 3), each row is [t, y, x].
    max_dist : float, optional
        Maximum allowed distance (in grid units) to consider two peaks
        the same object. Default 5.

    Returns
    -------
    labels : np.ndarray
        int32 array of shape (N_peaks,) with the label of each peak,
        in the row order of ``coords``. Labels start at 1.

    Notes
    -----
    Only time values that occur in ``coords`` are visited. "Previous
    timestep" therefore means the previous time value that has at least one
    peak, so peaks can be linked across timesteps without any peak.
    Several peaks may link to the same previous peak and then share its
    label.
    """
    times = coords[:, 0]
    labels = np.zeros(coords.shape[0], dtype=np.int32)
    next_label = 1
    prev_coords = None
    prev_labels = None

    for t in np.unique(times):
        idx_t = np.flatnonzero(times == t)
        coords_t = coords[idx_t, 1:3]  # [y, x] only
        n_peaks = coords_t.shape[0]

        if prev_coords is None:
            # First visited timestep: every peak starts its own label.
            labels_t = np.arange(next_label, next_label + n_peaks, dtype=np.int32)
            next_label += n_peaks
        else:
            n_prev = prev_coords.shape[0]
            # One batched nearest-neighbor query for all peaks of this
            # timestep. A peak without a neighbor inside the bound gets an
            # infinite distance and the index n_prev.
            dist, nearest = cKDTree(prev_coords).query(
                coords_t, distance_upper_bound=max_dist
            )
            linked = (dist < max_dist) & (nearest < n_prev)
            n_new = n_peaks - int(np.count_nonzero(linked))

            labels_t = np.empty(n_peaks, dtype=np.int32)
            labels_t[linked] = prev_labels[nearest[linked]]
            # Unlinked peaks receive fresh labels in row order.
            labels_t[~linked] = np.arange(next_label, next_label + n_new)
            next_label += n_new

        labels[idx_t] = labels_t
        prev_coords = coords_t
        prev_labels = labels_t

    return labels