How do I load data and pass it to clients in Flower federated learning?

374 Views Asked by At

I am trying to use Flower for federated learning with data that has both inputs and outputs. However, I am having trouble figuring out how to load this data into Flower so that it can be used in the training process. I am also a little confused about splitting this data into training, test, and validation sets. I am reading the data from a CSV file using pandas, so the DataFrame contains both the input features and the output classes.

I tried creating a custom dataset class to load the data, as shown below:

class CustomDataset():
    """Map-style dataset over a pandas DataFrame whose trailing columns are targets.

    Each index yields exactly ONE ``(features, targets)`` sample. Train/test/
    validation splitting must not happen inside ``__getitem__``: DataLoader,
    samplers and ``random_split`` all index the dataset one sample at a time.
    Split by index outside the dataset instead (e.g. with ``Subset`` or
    ``SubsetRandomSampler``).

    Args:
        data: DataFrame whose leading columns are input features and whose
            last ``num_targets`` columns are the targets. Values must be
            convertible to float32 by ``torch.tensor``.
        num_targets: number of trailing columns treated as targets. Defaults
            to 5, matching the original hard-coded ``-5`` slice.

    Raises:
        ValueError: if ``num_targets`` is < 1 or leaves no feature columns.
    """

    def __init__(self, data, num_targets=5):
        n_cols = data.shape[1]
        # num_targets == 0 would make ``:-0`` empty and ``-0:`` select every
        # column, so reject it along with values that leave no features.
        if not 1 <= num_targets < n_cols:
            raise ValueError(
                f"num_targets must be between 1 and {n_cols - 1} for a frame "
                f"with {n_cols} columns, got {num_targets}"
            )
        self.features = data.iloc[:, :-num_targets].values
        self.targets = data.iloc[:, -num_targets:].values

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return the ``(features, targets)`` pair for row ``index`` as float32 tensors."""
        x = torch.tensor(self.features[index], dtype=torch.float32)
        y = torch.tensor(self.targets[index], dtype=torch.float32)
        return x, y


# Interleaved split by row position: even positions train, odd positions test.
# NOTE(review): `dataset` is not defined in the visible code; presumably it is a
# CustomDataset built from the CSV DataFrame - confirm it exists before this point.
train_indices = list(range(len(dataset)))[::2]
test_indices = list(range(len(dataset)))[1::2]

# Each sampler draws only from its own index subset, in random order.
train_sampler, test_sampler = (
    SubsetRandomSampler(train_indices),
    SubsetRandomSampler(test_indices),
)

train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=32)
test_loader = DataLoader(dataset, sampler=test_sampler, batch_size=32)

# Number of simulated Flower clients the training data is partitioned across.
NUM_CLIENTS = 10


def _as_dataset(source):
    """Return an indexable dataset for ``source``, unwrapping a DataLoader if needed.

    A DataLoader is not indexable, so it cannot be handed to ``random_split``
    or used as another DataLoader's dataset. When the loader draws from part
    of its dataset through a sampler exposing ``indices`` (as
    ``SubsetRandomSampler`` does), only those indices are kept, so a
    train/test split expressed through samplers stays disjoint.
    """
    from torch.utils.data import DataLoader, Subset

    if not isinstance(source, DataLoader):
        return source
    indices = getattr(source.sampler, "indices", None)
    if indices is None:
        return source.dataset
    return Subset(source.dataset, list(indices))


def load_datasets(num_clients: int, train_loader, test_loader):
    """Partition training data across simulated Flower clients.

    The training data is split into ``num_clients`` near-equal partitions;
    each partition is split 90/10 into a local train set and a validation set.
    Both splits use a fixed seed (42), so the partitioning is reproducible.

    Args:
        num_clients: number of simulated clients; must be >= 1.
        train_loader: training data, either an indexable dataset or a
            DataLoader (unwrapped by ``_as_dataset``).
        test_loader: held-out test data, dataset or DataLoader.

    Returns:
        ``(trainloaders, valloaders, testloader)``: one shuffled train
        DataLoader and one validation DataLoader per client (batch size 32),
        plus a single centralized test DataLoader (batch size 32).

    Raises:
        ValueError: if ``num_clients`` < 1 or exceeds the number of training
            samples.
    """
    import torch
    from torch.utils.data import DataLoader, random_split

    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    train_ds = _as_dataset(train_loader)
    test_ds = _as_dataset(test_loader)

    # Size partitions from the dataset: len() of a DataLoader counts batches.
    total = len(train_ds)
    if total < num_clients:
        raise ValueError(
            f"cannot split {total} training samples across {num_clients} clients"
        )

    # random_split requires the lengths to sum exactly to len(train_ds), so
    # the remainder is spread one sample at a time over the first clients.
    base, extra = divmod(total, num_clients)
    lengths = [base + 1 if i < extra else base for i in range(num_clients)]
    partitions = random_split(
        train_ds, lengths, generator=torch.Generator().manual_seed(42)
    )

    # Split each partition into train/val and create DataLoaders.
    trainloaders = []
    valloaders = []
    for part in partitions:
        len_val = len(part) // 10  # 10 % validation set
        len_train = len(part) - len_val
        ds_train, ds_val = random_split(
            part, [len_train, len_val], generator=torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(test_ds, batch_size=32)
    return trainloaders, valloaders, testloader


# Build per-client train/validation loaders plus the shared test loader.
trainloaders, valloaders, testloader = load_datasets(
    num_clients=NUM_CLIENTS,
    train_loader=train_loader,
    test_loader=test_loader,
)
0

There are 0 best solutions below