Filter dictionary with values of different types


I am trying to filter a dictionary by value (the dictionary holds values of different data types), but I get this error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I want to get all records corresponding to the 'YALE' value in my dictionary.

This is my code:

dataset = {
    'timeseires': array([[
        [ -5.653222,   7.39066 ,  20.651941, 4.07861 ,-11.752331, -34.611312],
        [ -5.653222,   7.39066 ,  20.651941, 4.07861 ,-11.752331, -34.611312]
    ]]),
    'site': array(['YALE', 'KKI'], dtype='<U8')
}

dataset = data.tolist()
 
def filter(pairs):
    key, value = pairs
    filter_key = 'site'
    if key == filter_key and value == 'YALE':
        return True
    else:
        return False
         
final_dic = dict(filter(filter, dataset.items()))
print(final_dic)
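For reference, here is a minimal, standalone reproduction of the error (an illustrative sketch, not part of the code above). For the 'site' key, `value == 'YALE'` compares the NumPy array element by element and returns a boolean array, and `if` cannot reduce a multi-element array to a single True/False:

```python
import numpy as np

site = np.array(['YALE', 'KKI'], dtype='<U8')

# Comparing an array with a string is elementwise: one bool per site.
mask = site == 'YALE'
print(repr(mask))  # array([ True, False])

raised = False
try:
    if mask:  # two truth values, so Python's `if` cannot decide
        pass
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Reduce the mask explicitly, or use it to index the array instead.
print(mask.any())  # True: at least one site is 'YALE'
print(mask.all())  # False: not every site is 'YALE'
print(site[mask])  # ['YALE']
```

Filtering a dict with `filter` only decides which whole keys survive; selecting rows inside each array needs array indexing like the last line.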

The expected output:

dataset = {
    'timeseires': array([[
        [ -5.653222,   7.39066 ,  20.651941, 4.07861 ,-11.752331, -34.611312]
    ]]),
    'site': array(['KKI'], dtype='<U8')
}


Answer by Federicofkt (marked as best answer)

Judging from the expected output, it seems your task is to extract the 'timeseries' entries at the same index as 'YALE'.

This code should do the trick:

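import numpy as np

def filter_dictionary_by_value(dataset, value):
    # Indices of the sites equal to `value`
    index_value = np.where(dataset['site'] == value)[0]
    # Sites run along axis 1 of the (1, n_sites, n_samples) array;
    # note the question's dict spells this key 'timeseires'
    filtered_timeseries = dataset['timeseries'][:, index_value]

    filtered_dict = {
        'timeseries': filtered_timeseries,
        'site': dataset['site'][index_value]
    }
    return filtered_dict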

If this isn't what you were expecting, please be more specific in the question.
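The question's text asks for the 'YALE' records, while the expected output keeps only 'KKI'. A sketch that covers both readings with a single boolean mask is below; `filter_by_site` and its `site_axis`/`keep` parameters are hypothetical names, and the axis per key is an assumption read off the sample shapes (NumPy does not infer it):

```python
import numpy as np

def filter_by_site(dataset, value, site_axis, keep=True):
    """Hypothetical helper: keep (or, with keep=False, drop) entries whose site equals `value`.

    `site_axis` maps each key to the axis that is indexed by site.
    """
    mask = dataset['site'] == value
    if not keep:
        mask = ~mask
    # np.compress selects the positions where mask is True along the given axis
    return {k: np.compress(mask, v, axis=site_axis[k]) for k, v in dataset.items()}

dataset = {
    'timeseires': np.array([[
        [-5.653222, 7.39066, 20.651941, 4.07861, -11.752331, -34.611312],
        [-5.653222, 7.39066, 20.651941, 4.07861, -11.752331, -34.611312],
    ]]),
    'site': np.array(['YALE', 'KKI'], dtype='<U8'),
}

# Assumed layout: sites along axis 1 of the time series, axis 0 of 'site'.
site_axis = {'timeseires': 1, 'site': 0}

yale_only = filter_by_site(dataset, 'YALE', site_axis)
without_yale = filter_by_site(dataset, 'YALE', site_axis, keep=False)
print(yale_only['site'])                 # ['YALE']
print(without_yale['site'])              # ['KKI']
print(without_yale['timeseires'].shape)  # (1, 1, 6)
```

Because the mask is applied along an explicit axis, each result keeps the nested (1, n, 6) shape instead of being flattened.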

Answer by e-motta

Considering your expected output, you could do something like this using a dict comprehension:

import numpy as np

filtered_indices = np.where(dataset["site"] == "YALE")[0].tolist()
dataset = {k: np.delete(v, filtered_indices) for k, v in dataset.items()}
print(dataset)

Output:

{'timeseires': array([  7.39066 ,  20.651941,   4.07861 , -11.752331, -34.611312,
        -5.653222,   7.39066 ,  20.651941,   4.07861 , -11.752331,
       -34.611312]), 'site': array(['KKI'], dtype='<U8')}
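Without an `axis` argument, `np.delete` flattens the array before deleting, which is why the 'timeseires' result above is a 1-D array of 11 values rather than the nested shape in the expected output. A sketch that passes an axis per key follows; the `site_axis` mapping is an assumption taken from the sample data (sites along axis 1 of 'timeseires', axis 0 of 'site'):

```python
import numpy as np

dataset = {
    'timeseires': np.array([[
        [-5.653222, 7.39066, 20.651941, 4.07861, -11.752331, -34.611312],
        [-5.653222, 7.39066, 20.651941, 4.07861, -11.752331, -34.611312],
    ]]),
    'site': np.array(['YALE', 'KKI'], dtype='<U8'),
}

# Axis along which each value is indexed by site (assumed from the sample shapes).
site_axis = {'timeseires': 1, 'site': 0}

# Indices of the 'YALE' entries, deleted along the matching axis of every array.
drop = np.where(dataset['site'] == 'YALE')[0]
result = {k: np.delete(v, drop, axis=site_axis[k]) for k, v in dataset.items()}

print(result['timeseires'].shape)  # (1, 1, 6)
print(result['site'])              # ['KKI']
```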