Xarray most efficient way to select variable and calculate its mean


I have a 3 GB datacube opened with xarray that has 3 variables I'm interested in (v, vx, vy). Its description is below, after the code.

I am interested only in one specific time window spanning between 2009 and 2013, while the entire dataset spans from 1984 to 2018.

What I want to do is:

  • Grab the v, vx, vy values between 2009 and 2013
  • Calculate their mean along the time axis and save them as three 334x333 arrays

The issue is that it takes so much time: after 1 hour, the few lines of code I wrote were still running. What I don't understand is that if I save my "v" values as an array, load them as such and calculate their mean, it takes far less time than doing what I wrote below (see code). I don't know if there is a memory leak, or if it is just a terrible way of doing it. My PC has 16 GB of RAM, of which 60% is free before loading the datacube, so theoretically it should have enough RAM to compute everything.

What would be an efficient way to truncate my datacube to the desired time window and then calculate the temporal mean (over axis 0) of the 3 variables "v", "vx", "vy"?

I tried doing it like that:

datacube = xr.open_dataset('datacube.nc')  # Load the datacube
datacube = datacube.reindex(mid_date = sorted(datacube.mid_date.values))  # Sort the datacube by ascending time, where "mid_date" is the time dimension
    
sdate = '2009-01'   # Start date
edate = '2013-12'   # End date
    
ds = datacube.sel(mid_date = slice(sdate, edate))   # Create a new datacube gathering only the values between the start and end dates
    
vvtot = np.nanmean(ds.v.values, axis=0)   # Calculate the mean of the values of the "v" variable of the new datacube
vxtot = np.nanmean(ds.vx.values, axis=0)
vytot = np.nanmean(ds.vy.values, axis=0)
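A lighter-weight variant of the attempt above (a sketch, not tested on the original file, using a small synthetic datacube in place of the real 3 GB one) is to subset to just the three variables before computing anything, so xarray never touches the other ~40 data variables, and to use `sortby` and xarray's own `mean` instead of `reindex` and `np.nanmean`:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Small synthetic stand-in for the real datacube (dimension names
# match the question: mid_date, y, x; mid_date is unsorted, as in the file).
times = pd.date_range("1984-06-10", "2018-12-31", periods=200)
rng = np.random.default_rng(0)
datacube = xr.Dataset(
    {name: (("mid_date", "y", "x"), rng.random((200, 5, 4)).astype("float32"))
     for name in ("v", "vx", "vy")},
    coords={"mid_date": times[rng.permutation(200)]},
)

# Keep only the variables of interest, sort by time, then slice the window.
sub = (datacube[["v", "vx", "vy"]]
       .sortby("mid_date")
       .sel(mid_date=slice("2009-01", "2013-12")))

# skipna=True (the default for floats) reproduces np.nanmean behaviour.
vvtot = sub["v"].mean(dim="mid_date").values
vxtot = sub["vx"].mean(dim="mid_date").values
vytot = sub["vy"].mean(dim="mid_date").values
```

On the real file each result would be a 334x333 array, one mean per pixel.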

Dimensions:                    (mid_date: 18206, y: 334, x: 333)
Coordinates:
  * mid_date                   (mid_date) datetime64[ns] 1984-06-10T00:00:00....
  * x                          (x) float64 4.868e+05 4.871e+05 ... 5.665e+05
  * y                          (y) float64 6.696e+06 6.696e+06 ... 6.616e+06
Data variables: (12/43)
    UTM_Projection             object ...
    acquisition_img1           (mid_date) datetime64[ns] ...
    acquisition_img2           (mid_date) datetime64[ns] ...
    autoRIFT_software_version  (mid_date) float64 ...
    chip_size_height           (mid_date, y, x) float32 ...
    chip_size_width            (mid_date, y, x) float32 ...
                        ...
    vy                         (mid_date, y, x) float32 ...
    vy_error                   (mid_date) float32 ...
    vy_stable_shift            (mid_date) float64 ...
    vyp                        (mid_date, y, x) float64 ...
    vyp_error                  (mid_date) float64 ...
    vyp_stable_shift           (mid_date) float64 ...
Attributes:
    GDAL_AREA_OR_POINT:         Area
    datacube_software_version:  1.0
    date_created:               30-01-2021 20:49:16
    date_updated:               30-01-2021 20:49:16
    projection:                 32607

1 Answer

Accepted answer by Bert Coerver:

Try to avoid calling ".values" mid-pipeline: it converts the lazy xr.DataArray into an in-memory np.array, forcing all of the data to be loaded at once.

import xarray as xr
from dask.diagnostics import ProgressBar

sdate = '2009-01'   # Start date
edate = '2013-12'   # End date

# Open the dataset lazily, in chunks (requires dask).
ds = xr.open_dataset(r"/path/to/your/data/test.nc", chunks = "auto")

# Sort by time first, as in your own code, so the slice below works.
ds = ds.sortby("mid_date")

# Select the period you want to have the mean for
# ("mid_date" is the time dimension in your datacube).
ds = ds.sel(mid_date = slice(sdate, edate))

# Calculate the mean for all the variables in your ds.
ds = ds.mean(dim = "mid_date")

# The above code takes less than a second, because no actual
# calculations have been done yet (and no data has been loaded into your RAM).
# The work only happens once you call ".values", ".compute()", or
# ".to_netcdf()". We can see progress like this:
with ProgressBar():
    ds = ds.compute()
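After `compute()`, every variable in `ds` is an in-memory 2-D mean over time, so the three 334x333 arrays the question asks for can be pulled out directly. A minimal sketch, using a stand-in dataset shaped like the computed result (the original file isn't available here):

```python
import numpy as np
import xarray as xr

# Stand-in for the computed result: a dataset already reduced over
# mid_date, so each variable is a (y, x) = (334, 333) array.
ds = xr.Dataset(
    {name: (("y", "x"), np.zeros((334, 333), dtype="float32"))
     for name in ("v", "vx", "vy")}
)

# Each mean is now a plain 2-D numpy array.
vvtot = ds["v"].values
vxtot = ds["vx"].values
vytot = ds["vy"].values
```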