# Source code for moaap.trackers.waves

import numpy as np
from scipy import fftpack, ndimage, signal
from moaap.utils.constants import g, a, beta, pi, NA
from moaap.utils.data_proc import tukey_latitude_mask, temporal_tukey_window, interpolate_temporal
from moaap.utils.segmentation import watershed_3d_overlap_parallel, analyze_watershed_history
from moaap.utils.object_props import clean_up_objects, BreakupObjects, ConnectLon_on_timestep
import gc
import sys
from pdb import set_trace as stop

def _tb_anomalies_for_filtering(tb, Lat):
    """Turn brightness temperatures into tapered, time-padded anomalies.

    Parameters
    ----------
    tb : np.ndarray
        Brightness temperature, time on axis 0 (float32).
    Lat : np.ndarray
        Latitude grid passed to ``tukey_latitude_mask``.

    Returns
    -------
    tb_eq : np.ndarray
        Anomaly field, reflect-padded by ``pad_size`` steps at both ends of
        the time axis and passed through ``interpolate_temporal``.
    pad_size : int
        Number of padding steps added at each end of the time axis.
    """
    # Meridional Tukey (half-cosine) taper of the analysis region.
    lat_mask = tukey_latitude_mask(Lat, lat_start=17.0, lat_stop=25.0)

    # Treat values outside 150-350 as missing before computing anomalies.
    tb_eq = tb.copy()
    tb_eq[(tb_eq > 350) | (tb_eq < 150)] = np.nan

    # Anomalies relative to the spatial mean of each time step.
    tb_eq = tb_eq - np.nanmean(tb_eq, axis=(1, 2), keepdims=True)
    tb_eq = tb_eq * lat_mask[None, :]
    tb_eq[np.isnan(tb_eq)] = 0

    # Temporal Tukey taper, then reflect-pad the time axis to reduce
    # boundary effects in the FFT-based filtering.
    win = temporal_tukey_window(tb_eq.shape[0], alpha=0.2)
    tb_eq = tb_eq * win[:, None, None]
    pad_size = int(tb_eq.shape[0] * 0.2)
    tb_eq = np.pad(tb_eq, ((pad_size, pad_size), (0, 0), (0, 0)), mode='reflect')
    tb_eq = interpolate_temporal(tb_eq)
    return tb_eq, pad_size


def track_tropwaves_tb(tb,
                       Lat,
                       connectLon,
                       dT,
                       Gridspacing,
                       er_th=0.05,
                       mrg_th=0.05,
                       igw_th=0.2,
                       kel_th=0.1,
                       eig0_th=0.1,
                       breakup='watershed',
                       analyze_twave_history=False):
    """
    Identify and track tropical waves in brightness-temperature data.

    Tb anomalies are filtered in wavenumber-frequency space (Wheeler &
    Kiladis method, see ``KFfilter``) and the filtered fields are segmented
    into 3-D (time, lat, lon) objects.

    Parameters
    ----------
    tb : np.ndarray
        Brightness temperature, time on axis 0.
    Lat : np.ndarray
        Latitude grid.
    connectLon : int
        If 1, objects are connected across the date line.
    dT : int
        Time step (hours).
    Gridspacing : float
        Grid spacing used to convert 1000 km into grid cells
        (presumably metres -- confirm against callers).
    er_th, mrg_th, igw_th, kel_th, eig0_th : float
        Amplitude thresholds for Equatorial Rossby, Mixed Rossby Gravity,
        Inertia Gravity, Kelvin, and n=0 Eastward Inertia Gravity waves.
        Objects are the negative-anomaly phase, i.e. filtered amplitude
        below ``-abs(threshold)``.
    breakup : {'watershed', 'breakup'}
        Segmentation method.
    analyze_twave_history : bool, optional
        If True, compute watershed merge/split history for each wave type.

    Returns
    -------
    mrg_objects, igw_objects, kelvin_objects, eig0_objects, er_objects : np.ndarray
        Labeled object arrays for each wave type.
    mrg_history, igw_history, kelvin_history, eig0_history, er_history
        History objects from ``analyze_watershed_history``, or None when
        ``analyze_twave_history`` is False.

    Raises
    ------
    ValueError
        If ``breakup`` is not 'watershed' or 'breakup'.
    """
    if breakup not in ('watershed', 'breakup'):
        raise ValueError("breakup must be 'watershed' or 'breakup', got %r" % (breakup,))

    tb = np.asarray(tb, dtype=np.float32)
    ew_mintime = 32  # minimum object lifetime in hours (converted to steps via dT)
    # 1000 km expressed in grid cells (minimum distance between maxima).
    min_dist = int((1000 * 10**3) / Gridspacing)

    tb_eq, pad_size = _tb_anomalies_for_filtering(tb, Lat)
    tropical_waves = KFfilter(tb_eq, int(24 / dT))

    # (name, filter call, threshold) in processing order.
    wave_specs = [
        # hmin lowered from 8 to 0 for the ER band
        ('ER', lambda: tropical_waves.erfilter(fmin=None, fmax=None, kmin=-10, kmax=-1,
                                               hmin=0, hmax=90, n=1), er_th),
        ('MRG', tropical_waves.mrgfilter, mrg_th),
        ('IGW', tropical_waves.igfilter, igw_th),
        ('Kelvin', tropical_waves.kelvinfilter, kel_th),
        ('Eig0', tropical_waves.eig0filter, eig0_th),
    ]
    rgiObj_Struct = np.ones((3, 3, 3))  # full 26-neighbour connectivity for labelling
    objects = {}
    histories = {}

    print(' track tropical waves')
    for name, wave_filter, threshold in wave_specs:
        print(' work on ' + name)
        amplitude = wave_filter()
        # Drop the reflected padding. The explicit stop index keeps every
        # step when pad_size == 0 (a [0:-0] slice would be empty).
        amplitude = amplitude[pad_size:amplitude.shape[0] - pad_size]

        if breakup == 'breakup':
            print(' break up long tropical waves that have many elements')
            # Same sign convention as the watershed branch: negative phase.
            wave = amplitude < -abs(threshold)
            wave_objects, _ = ndimage.label(wave, structure=rgiObj_Struct)
            del wave
            wave_objects, object_split = BreakupObjects(wave_objects, int(ew_mintime / dT), dT)
        else:  # 'watershed'
            wave_objects = watershed_3d_overlap_parallel(
                amplitude * -1,
                np.abs(threshold),
                np.abs(threshold),
                min_dist,
                dT,
                mintime=ew_mintime,
            )
        del amplitude

        if connectLon == 1:
            print(' connect waves objects over date line')
            wave_objects = ConnectLon_on_timestep(wave_objects)
        wave_objects, _ = clean_up_objects(wave_objects, dT, min_tsteps=int(ew_mintime / dT))
        objects[name] = wave_objects

        if analyze_twave_history:
            print(f" Minimum distance between {name} maxima for watershed analysis: {min_dist} grid cells")
            _, _, _, history_obj = analyze_watershed_history(
                wave_objects, min_dist, name.lower()
            )
            histories[name] = history_obj.copy()
        else:
            histories[name] = None

        del wave_objects
        gc.collect()

    return objects['MRG'], objects['IGW'], objects['Kelvin'], objects['Eig0'], objects['ER'], \
        histories['MRG'], histories['IGW'], histories['Kelvin'], histories['Eig0'], histories['ER']
class KFfilter:
    """Wavenumber-frequency filtering of (time, lat, lon) fields (WK99, WKH00).

    The constructor detrends and tapers the input in time and stores its
    2-D FFT over the time and longitude axes. Each ``*filter`` method zeroes
    the Fourier coefficients outside a wavenumber / frequency / equivalent-depth
    region and returns the real part of the inverse transform.

    Attributes
    ----------
    fftdata : np.ndarray
        Complex FFT of the prepared data, shape (nfreq, nlat, nwavenumber).
    knum, freq : np.ndarray
        (nfreq, nwavenumber) grids of zonal wavenumber and frequency
        [cycles per day]. Negative frequencies are folded onto positive ones
        (with the wavenumber sign flipped), so ``freq >= 0`` everywhere.
    wavenumber, frequency : np.ndarray
        1-D wavenumber and frequency axes before folding.
    """

    def __init__(self, datain, spd, tim_taper=0.1):
        """Detrend, taper and Fourier-transform ``datain``.

        Parameters
        ----------
        datain : np.ndarray
            Data to be filtered, shape (time, lat, lon).
        spd : int or float
            Samples per day.
        tim_taper : float or 'hann', optional
            A float > 0 tapers the first and last ``int(ntim * tim_taper)``
            samples with a half-cosine. The default therefore tapers 10% at
            each end. ``'hann'`` applies a Hann window to the whole record,
            and 0 disables tapering.
        """
        ntim, nlat, nlon = datain.shape

        # Remove the linear trend along time.
        data = signal.detrend(datain, axis=0)

        if tim_taper == 'hann':
            # scipy.signal.hann was removed in SciPy 1.13; the windows
            # submodule provides the identical function.
            window = signal.windows.hann(ntim)
            data = data * window[:, NA, NA]
        elif tim_taper > 0:
            tp = int(ntim * tim_taper)
            # Short records give tp == 0. The taper would then be empty, and
            # window[-0:] would address the whole array, so skip it.
            if tp > 0:
                window = np.ones(ntim, dtype=datain.dtype)
                x = np.arange(tp)
                window[:tp] = 0.5 * (1.0 - np.cos(x * pi / tp))
                window[-tp:] = 0.5 * (1.0 - np.cos(x[::-1] * pi / tp))
                data = data * window[:, NA, NA]

        # FFT over time (axis 0) and longitude (axis 2).
        self.fftdata = fftpack.fft2(data, axes=(0, 2))

        # The FFT kernel is exp(-ikx) while waves are written exp(i(kx - omega t)),
        # hence the minus sign on the wavenumber axis.
        wavenumber = -fftpack.fftfreq(nlon) * nlon
        frequency = fftpack.fftfreq(ntim, d=1. / float(spd))
        knum, freq = np.meshgrid(wavenumber, frequency)

        # Fold f < 0 onto f > 0 by flipping the sign of the matching wavenumbers.
        knum[freq < 0] = -knum[freq < 0]
        freq = np.abs(freq)

        self.knum = knum
        self.freq = freq
        self.wavenumber = wavenumber
        self.frequency = frequency

    def decompose_antisymm(self):
        """Split ``fftdata`` into antisymmetric and symmetric latitude parts.

        Latitude row ``nlat // 2`` is the symmetry axis, so ``nlat`` must be
        odd. ``self.fftdata`` is replaced by the antisymmetric part
        (``nlat // 2`` rows) followed by the symmetric part
        (``nlat // 2 + 1`` rows), concatenated along the latitude axis. The
        array shape is unchanged.

        Raises
        ------
        ValueError
            If the latitude dimension has even length.
        """
        fftdata = self.fftdata
        nlat = fftdata.shape[1]
        if nlat % 2 == 0:
            raise ValueError("decompose_antisymm needs an odd number of latitudes, got %d" % nlat)
        half = nlat // 2
        mirrored = fftdata[:, ::-1, :]  # row j holds latitude row nlat-1-j
        symm = 0.5 * (fftdata[:, :half + 1, :] + mirrored[:, :half + 1, :])
        anti = 0.5 * (fftdata[:, :half, :] - mirrored[:, :half, :])
        self.fftdata = np.concatenate([anti, symm], axis=1)

    def kfmask(self, fmin=None, fmax=None, kmin=None, kmax=None):
        """Return the wavenumber/frequency cut-off mask for :meth:`wavefilter`.

        Parameters
        ----------
        fmin, fmax : float or None
            Frequency limits [cycles per day]; None disables a limit.
        kmin, kmax : float or None
            Zonal wavenumber limits; None disables a limit.

        Returns
        -------
        mask : np.ndarray of bool, shape (nfreq, nwavenumber)
            True where a coefficient lies outside the limits (to be zeroed).
        """
        nf, _, nk = self.fftdata.shape
        knum = self.knum
        freq = self.freq

        mask = np.zeros((nf, nk), dtype=bool)
        # wavenumber cut-off
        if kmin is not None:
            mask = mask | (knum < kmin)
        if kmax is not None:
            mask = mask | (kmax < knum)
        # frequency cut-off
        if fmin is not None:
            mask = mask | (freq < fmin)
        if fmax is not None:
            mask = mask | (fmax < freq)
        return mask

    def _nondimensional(self, h):
        """Return non-dimensional (omega, k) grids for equivalent depth ``h``.

        With c = sqrt(g*h), omega is the angular frequency (freq converted
        from cycles/day to rad/s) divided by sqrt(beta*c). k is knum/a
        multiplied by sqrt(c/beta).
        """
        c = np.sqrt(g * h)
        omega = 2. * pi * self.freq / 24. / 3600. / np.sqrt(beta * c)
        k = self.knum / a * np.sqrt(c / beta)
        return omega, k

    def _add_depth_band(self, mask, hmin, hmax, residual):
        """Extend ``mask`` with an equivalent-depth band cut.

        ``residual(omega, k)`` is a dispersion-relation residual. Coefficients
        where it is negative at ``hmin`` or positive at ``hmax`` are masked.
        A None bound is skipped.
        """
        if hmin is not None:
            mask = mask | (residual(*self._nondimensional(hmin)) < 0)
        if hmax is not None:
            mask = mask | (residual(*self._nondimensional(hmax)) > 0)
        return mask

    def _apply_mask(self, mask):
        """Zero coefficients where the (nfreq, nwavenumber) ``mask`` is True at
        every latitude and return the real part of the inverse FFT."""
        fftdata = self.fftdata.copy()
        fftdata[np.broadcast_to(mask[:, NA, :], fftdata.shape)] = 0.0
        return fftpack.ifft2(fftdata, axes=(0, 2)).real

    def wavefilter(self, mask):
        """Apply wavenumber-frequency filtering with a user-supplied mask.

        Parameters
        ----------
        mask : np.ndarray of bool, shape (nfreq, nwavenumber)
            True marks coefficients to be zeroed; False marks those kept.

        Returns
        -------
        np.ndarray
            Filtered data in the original (time, lat, lon) space.

        Raises
        ------
        ValueError
            If ``mask`` does not have shape (nfreq, nwavenumber).
        """
        nf, _, nk = self.fftdata.shape
        if (nf, nk) != mask.shape:
            raise ValueError("mask array size is incorrect: expected %r, got %r"
                             % ((nf, nk), mask.shape))
        return self._apply_mask(mask)

    # ------------------------------------------------------------------
    # filters
    # ------------------------------------------------------------------
    def kelvinfilter(self, fmin=0.05, fmax=0.4, kmin=None, kmax=14, hmin=8, hmax=90):
        """Kelvin wave filter.

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits
        hmin/hmax : equivalent depth limits; dispersion residual ``omega - k``

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.
        """
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = self._add_depth_band(mask, hmin, hmax, lambda w, k: w - k)
        return self._apply_mask(mask)

    def erfilter(self, fmin=None, fmax=None, kmin=-10, kmax=-1, hmin=8, hmax=90, n=1):
        """Equatorial Rossby wave filter.

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits
        hmin/hmax : equivalent depth limits; dispersion residual
                    ``omega*(k**2 + 2n + 1) + k``
        n : meridional mode number (integer >= 1)

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.

        Raises
        ------
        ValueError
            If ``n`` is not an integer >= 1.
        """
        if n <= 0 or n % 1 != 0:
            raise ValueError("n must be n>=1 integer")
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = self._add_depth_band(mask, hmin, hmax,
                                    lambda w, k: w * (k**2 + (2 * n + 1)) + k)
        return self._apply_mask(mask)

    def igfilter(self, fmin=None, fmax=None, kmin=-15, kmax=-1, hmin=12, hmax=90, n=1):
        """n>=1 inertia-gravity wave filter (default: n=1 westward IG).

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits; negative is westward, positive eastward
        hmin/hmax : equivalent depth limits; dispersion residual
                    ``omega**2 - k**2 - (2n + 1)``
        n : meridional mode number (integer >= 1)

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.

        Raises
        ------
        ValueError
            If ``n`` is not an integer >= 1 (use :meth:`eig0filter` for n=0).
        """
        if n <= 0 or n % 1 != 0:
            raise ValueError("n must be n>=1 integer. for n=0 EIG you must use eig0filter method.")
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = self._add_depth_band(mask, hmin, hmax,
                                    lambda w, k: w**2 - k**2 - (2 * n + 1))
        return self._apply_mask(mask)

    def eig0filter(self, fmin=None, fmax=0.55, kmin=0, kmax=15, hmin=12, hmax=50):
        """n=0 eastward inertia-gravity wave filter.

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits; kmin must be >= 0 (k < 0 is MRG)
        hmin/hmax : equivalent depth limits; dispersion residual
                    ``omega**2 - k*omega - 1``

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.

        Raises
        ------
        ValueError
            If ``kmin`` is None or negative.
        """
        if kmin is None or kmin < 0:
            raise ValueError("kmin must be positive. if k < 0, this mode is MRG")
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = self._add_depth_band(mask, hmin, hmax, lambda w, k: w**2 - k * w - 1)
        return self._apply_mask(mask)

    def mrgfilter(self, fmin=None, fmax=None, kmin=-10, kmax=-1, hmin=8, hmax=90):
        """Mixed Rossby-gravity wave filter.

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits; kmax must be <= 0 (k > 0 is n=0 EIG)
        hmin/hmax : equivalent depth limits; dispersion residual
                    ``omega**2 - k*omega - 1``

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.

        Raises
        ------
        ValueError
            If ``kmax`` is None or positive.
        """
        if kmax is None or kmax > 0:
            raise ValueError("kmax must be negative. if k > 0, this mode is the same as n=0 EIG")
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = self._add_depth_band(mask, hmin, hmax, lambda w, k: w**2 - k * w - 1)
        return self._apply_mask(mask)

    def tdfilter(self, fmin=None, fmax=None, kmin=-20, kmax=-6):
        """KTH05 TD-type filter.

        Arguments
        ---------
        fmin/fmax : frequency limits in cycles per day
        kmin/kmax : zonal wavenumber limits; negative is westward, positive eastward

        Besides the limits, coefficients with ``84*freq + knum - 22 > 0`` or
        ``210*freq + 2.5*knum - 13 < 0`` are removed.

        Returns
        -------
        np.ndarray
            Filtered data in the original data space.
        """
        mask = self.kfmask(fmin, fmax, kmin, kmax)
        mask = mask | (84 * self.freq + self.knum - 22 > 0) | (210 * self.freq + 2.5 * self.knum - 13 < 0)
        return self._apply_mask(mask)