Does XArray have a function like argmax for quantile statistics?


Is there a way to keep the coordinates when using xarray.quantile?

I am taking a (90th percentile) quantile of a dataset with coordinates lat, lon, and time. I want to have the time when the data values are in this quantile.

I run the command Data.quantile([.90],dim='time'), which removes the 'time' coordinate and replaces it with 'quantile'. Is there a way to retain the coordinate information AND perform the quantile operation?

There's the argument "keep_attrs", but I have yet to find anything on retaining coordinates. I want something like xarray.DataArray.argmax.

There are 3 best solutions below

4
jspaeth On

You can use .where() to filter according to .quantile():

Given an xarray DataArray da:

>>> da
<xarray.DataArray (time: 100)>
array([-1.11006507e+00, -4.41380179e-01,  1.10087254e+00,  2.18218427e-01,
       ...-5.51287030e-01])
Coordinates:
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-04-09

>>> da.where(da > da.quantile(0.9, "time"), drop=True)
<xarray.DataArray (time: 10)>
array([1.84009741, 2.25750906, 1.62780955, 1.55448247, 2.11139034,
       2.17723193, 3.11637597, 1.26926648, 1.49876131, 1.55716718])
Coordinates:
  * time      (time) datetime64[ns] 2000-01-12 2000-01-17 ... 2000-04-03
    quantile  float64 0.9
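
The same pattern extends to data with lat, lon, and time dimensions, because the per-point threshold broadcasts against the full array. The following is a minimal sketch on synthetic data (the array shape, coordinate values, and variable names are assumptions made for illustration, not taken from the question):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic stand-in for the real data: 100 daily steps on a 3 x 4 grid
rng = np.random.default_rng(0)
time = pd.date_range("2000-01-01", periods=100)
da = xr.DataArray(
    rng.standard_normal((100, 3, 4)),
    dims=("time", "lat", "lon"),
    coords={
        "time": time,
        "lat": [10.0, 20.0, 30.0],
        "lon": [0.0, 90.0, 180.0, 270.0],
    },
)

# 90th percentile over time at every grid point: dims (lat, lon)
threshold = da.quantile(0.9, dim="time")

# Boolean mask with the full (time, lat, lon) shape; the time coordinate is kept
mask = da > threshold

# Values above the threshold, NaN elsewhere
above = da.where(mask)

# Times at which one particular grid point exceeds its own threshold
point = above.isel(lat=0, lon=0).dropna("time")
times = point["time"].values
```

Without drop=True the result keeps the original shape, which is usually what you want on a grid, since each point exceeds its threshold at different times.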
0
atteggiani On

You cannot "keep" the time coordinate because the quantiles are calculated over that coordinate.

If you want to return the indices of the computed quantiles along an axis (time in your case), there is no xarray built-in function such as argmax.

However, this answer on a similar question suggests using np.argpartition to achieve the task.

The following function, which I wrote, works for xarray.DataArray objects.

def argquantile(quantiles, darray, dim=None):
    if not isinstance(quantiles, list):
        quantiles = [quantiles]
    if dim is None:
        dim = darray.dims[0]
    # Nearest-rank position of each quantile along `dim`
    idx = [int(np.round(q * (len(darray[dim]) - 1))) for q in quantiles]
    axis = darray.dims.index(dim)  # axis of the function argument, not of a global array
    parts = [np.argpartition(darray, [i], axis=axis).isel({dim: i}).drop(dim).assign_coords({'quantile': q})
             for i, q in zip(idx, quantiles)]
    return xr.concat(parts, 'quantile')

It takes similar inputs to the xarray.DataArray.quantile built-in function but returns the indices of the quantiles along the selected dimension.

Below is an example script to test it:

import numpy as np
import xarray as xr

# The argquantile function
def argquantile(quantiles, darray, dim=None):
    if not isinstance(quantiles, list):
        quantiles = [quantiles]
    if dim is None:
        dim = darray.dims[0]
    # Nearest-rank position of each quantile along `dim`
    idx = [int(np.round(q * (len(darray[dim]) - 1))) for q in quantiles]
    axis = darray.dims.index(dim)  # axis of the function argument, not of a global array
    parts = [np.argpartition(darray, [i], axis=axis).isel({dim: i}).drop(dim).assign_coords({'quantile': q})
             for i, q in zip(idx, quantiles)]
    return xr.concat(parts, 'quantile')

# Let's create an example dataarray
time = np.arange(21)
lat = np.linspace(-90,90,30)
lon = np.linspace(0,360,51)[:-1]
quantiles = [0.5,0.8]
data = np.random.rand(len(time),len(lat),len(lon))
dims = ['time','lat','lon']
coords = [time,lat,lon]
darr = xr.DataArray(data=data, dims = dims, coords={d:coord for d,coord in zip(dims,coords)})

# Calculate quantile with xarray
# We use interpolation='nearest' so the quantiles are actual data values and we can retrieve the exact indices.
# (Recent xarray releases call this argument `method`; `interpolation` is the older name.)
q = darr.quantile(quantiles,dim='time',interpolation='nearest')

# Calculate argquantile
aq = argquantile(quantiles,darr,dim='time')

# Verify that aq actually contains the quantile indices (for our case)
def verify():
    return np.all([darr[aq[iq,ilat,ilon],ilat,ilon].values == q[iq,ilat,ilon].values for iq,_ in enumerate(quantiles) for ilat,_ in enumerate(lat) for ilon,_ in enumerate(lon)])

print(verify())
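
Once such indices are available, they can be mapped back to coordinate values with xarray's vectorized indexing, which is what the question asks for. The sketch below is self-contained and makes a few assumptions: `argquantile_nearest` is a hypothetical helper (not an xarray function), it uses np.argsort instead of np.argpartition for simplicity, it handles a single quantile, and it does not treat NaNs specially.

```python
import numpy as np
import xarray as xr


def argquantile_nearest(darray, q, dim):
    """Index along `dim` of the nearest-rank q-quantile (hypothetical helper, no NaN handling)."""
    axis = darray.get_axis_num(dim)
    rank = int(np.round(q * (darray.sizes[dim] - 1)))
    # Full sort of the indices along `dim`, then pick the entry at the wanted rank
    order = np.argsort(darray.values, axis=axis)
    idx = np.take(order, rank, axis=axis)
    other_dims = [d for d in darray.dims if d != dim]
    return xr.DataArray(idx, dims=other_dims)


# Synthetic example data
rng = np.random.default_rng(0)
darr = xr.DataArray(
    rng.random((21, 5, 6)),
    dims=("time", "lat", "lon"),
    coords={
        "time": np.arange(21),
        "lat": np.linspace(-60, 60, 5),
        "lon": np.arange(0, 360, 60),
    },
)

idx = argquantile_nearest(darr, 0.9, "time")

# Vectorized indexing: one time step per (lat, lon) point
time_at_q = darr["time"].isel(time=idx)   # the time coordinate of the quantile
value_at_q = darr.isel(time=idx)          # the data value at that time
```

Here time_at_q has dimensions (lat, lon) and holds, for every grid point, the time at which the data take their nearest-rank 90th percentile value.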

Hope that helps!

Cheers Davide

1
Maxim Couillard On

Here is my (terribly inefficient) code:

Cape90_by_hour=[]
Cape_by_hour=[]
hours_list=[f'{h:02d}' for h in range(24)]  # '00', '01', ..., '23'
for z in hours_list:
    zhour=CAPE[(CAPE['hour']==z)]
    Cape_by_hour.append(zhour)
    z90=zhour.quantile([.90],dim='hour')
    Cape90_by_hour.append(z90)

cape_above_percentile = []
datetime_of_cape_above_percentile = []

for hr in np.arange(0,24,1):
    # The quantile dimension has a single entry (0.90), so take index 0
    percentile_cape = Cape90_by_hour[hr].isel(quantile=0)
    cape90_avg_at_hr=[]
    date_of_cape_above_percentile = []
    for lat_idx in range(len(CAPE1.latitude)):
        for lon_idx in range(len(CAPE1.longitude)):
            percentile_cape90=percentile_cape.isel(latitude=lat_idx,longitude=lon_idx).values
            cape_values = Cape_by_hour[hr].isel(latitude=lat_idx, longitude=lon_idx).values
            time_values = Cape_by_hour[hr]['time']
            cape90_at_each_pt=[]
            dates_at_each_pt=[]
            for w in range(len(cape_values)):
                if cape_values[w] >= percentile_cape90:
                    cape90_at_each_pt.append(cape_values[w])
                    dates_at_each_pt.append(time_values[w])
            cape90_avg=np.sum(cape90_at_each_pt)/len(cape90_at_each_pt)
            cape90_avg_at_hr.append(cape90_avg)
            date_of_cape_above_percentile.append(dates_at_each_pt)
    # Store the per-hour results once all grid points have been processed
    cape_above_percentile.append(np.array(cape90_avg_at_hr))
    datetime_of_cape_above_percentile.append(np.array(date_of_cape_above_percentile,dtype=object))

numpts=len(lon)*len(lat)
Cin_at_cape90=[]
for hour in np.arange(0,24,1):
    cin_avg_at_all_pts=[]
    for points in np.arange(0,numpts,1):
        cinValues=[]
        for value in np.arange(0,138,1):
            # Note: as written, CIN is always read at grid point (0, 0)
            cinValue=Cin1.isel(latitude=0,longitude=0)[Cin1['time']==datetime_of_cape_above_percentile[hour][points][value]].values
            cinValues.append(cinValue)
        cin_avg_at_pt=np.nanmean(cinValues)
        cin_avg_at_all_pts.append(cin_avg_at_pt)
    cin_avg_at_all_pts2=np.reshape(cin_avg_at_all_pts,(4,4))
    Cin_at_cape90.append(cin_avg_at_all_pts2)

Let me know if anyone can find a more efficient way.
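
One possibly faster route is to let xarray do the per-hour grouping and the masking, so that no Python loop over grid points is needed. This is only a sketch under assumptions: the arrays `cape` and `cin` are synthetic stand-ins for the real CAPE and CIN data, they share the dimensions time, latitude, and longitude, and the hour of day is derived from the time coordinate rather than from a separate 'hour' variable.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic hourly data on a 4 x 4 grid (assumed layout, for illustration only)
rng = np.random.default_rng(1)
time = pd.date_range("2000-01-01", periods=24 * 50, freq="60min")
shape = (time.size, 4, 4)
dims = ("time", "latitude", "longitude")
coords = {"time": time, "latitude": np.arange(4.0), "longitude": np.arange(4.0)}
cape = xr.DataArray(rng.gamma(2.0, 500.0, shape), dims=dims, coords=coords)
cin = xr.DataArray(-rng.gamma(2.0, 50.0, shape), dims=dims, coords=coords)

# Hour of day for every time step
hour = cape["time"].dt.hour

# 90th percentile of CAPE per hour of day and grid point
threshold = cape.groupby(hour).quantile(0.9, dim="time")

# Map the per-hour threshold back onto every time step
threshold_at_time = threshold.sel(hour=hour)

# True where CAPE reaches its hour-specific 90th percentile
mask = cape >= threshold_at_time

# Mean CAPE and mean CIN over those events, per hour of day and grid point
cape90_avg = cape.where(mask).groupby("hour").mean("time")
cin_at_cape90 = cin.where(mask).groupby("hour").mean("time")
```

The mask keeps the time coordinate, so the event times at a given point can still be read off with something like cape.where(mask).isel(latitude=0, longitude=0).dropna("time")["time"].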