Heatmaps with imshow not displaying the x axis properly


I am trying to create heatmaps from data in an Excel file. The file has multiple sheets, and each sheet contains a 'Z Score Times' column (the times I want on the X axis) plus four channel columns, CH1, CH2, CH3 and CH4, holding the data to display in each heatmap.

Each trial is saved in a separate sheet, so each Excel file should produce four heatmaps, one per channel, with one row per trial and the X axis following the 'Z Score Times' column. However, the data always ends up squeezed in the X direction. The code I tried and its output are below.
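To reproduce the problem without the original workbooks, here is a minimal sketch that builds a hypothetical stand-in for `pd.read_excel(excel_file, sheet_name=None)`, which returns a dict mapping each sheet name to a DataFrame. The sheet names, sample count and time range are made up; only the column names come from the question. It also shows that `pd.concat(..., axis=1)`, which the code below uses to combine sheets, aligns on the row index, so a sheet with fewer rows is padded with NaN.

```python
import numpy as np
import pandas as pd


def make_fake_workbook(n_trials=5, n_samples=200, t_start=-2.0, t_stop=8.0, seed=0):
    """Hypothetical stand-in for pd.read_excel(path, sheet_name=None):
    a dict of sheet name -> DataFrame with the question's column layout."""
    rng = np.random.default_rng(seed)
    times = np.linspace(t_start, t_stop, n_samples)
    sheets = {}
    for k in range(n_trials):
        columns = {'Z Score Times': times}
        for ch in range(1, 5):
            columns[f'CH{ch}'] = rng.normal(size=n_samples)
        sheets[f'Trial{k + 1}'] = pd.DataFrame(columns)
    return sheets


def stack_channel(sheets, channel):
    """Combine one channel from every sheet into a (n_trials, n_samples) array:
    one row per trial, the orientation imshow needs for time on the X axis."""
    combined = pd.concat([df[channel] for df in sheets.values()], axis=1)
    # concat(axis=1) aligns on the row index, so unequal sheet lengths become NaN
    if combined.isna().any().any():
        print(f'{channel}: sheets have different lengths; short ones are NaN-padded')
    return combined.to_numpy().T


sheets = make_fake_workbook()
ch1 = stack_channel(sheets, 'CH1')
print(ch1.shape)  # (5, 200)

# Shorten one sheet to see the NaN padding appear
sheets['Trial5'] = sheets['Trial5'].iloc[:150]
ch1_uneven = stack_channel(sheets, 'CH1')
print(ch1_uneven.shape, np.isnan(ch1_uneven[-1]).sum())  # (5, 200) 50
```

With the real files, `sheets = pd.read_excel('Input2_TTL.xlsx', sheet_name=None)` would take the place of `make_fake_workbook()`.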

```python
import matplotlib.pyplot as plt
import pandas as pd


def TTL_correction_plot(trigger, excel_file):
    # Open the workbook once; every sheet holds one trial
    xls = pd.ExcelFile(excel_file)

    # Collect each channel's column and the time column from every sheet
    ch_data = {f'CH{i+1}': [] for i in range(4)}
    time_data = []

    for sheet_name in xls.sheet_names:
        # Read the sheet from the already-open workbook
        df = pd.read_excel(xls, sheet_name=sheet_name)
        time_data.append(df['Z Score Times'])
        for i in range(4):
            ch_data[f'CH{i+1}'].append(df[f'CH{i+1}'])

    # One DataFrame per channel: rows are time points, columns are sheets (trials)
    all_ch_data = {key: pd.concat(value, axis=1) for key, value in ch_data.items()}
    # Times from all sheets stacked end to end
    all_times = pd.concat(time_data)

    # Plot each channel as a heatmap in a 2x2 grid
    plt.figure(figsize=(12, 8))

    for i, (ch, data) in enumerate(all_ch_data.items(), start=1):
        plt.subplot(2, 2, i)
        # Transpose so each trial is a row and time runs along the X axis
        im = plt.imshow(data.T, aspect='auto', cmap='viridis',
                        extent=[all_times.min(), all_times.max(), 0, len(xls.sheet_names)])
        plt.colorbar(im, label='Value')
        plt.axvline(x=0, color='black', linestyle='--', linewidth=2, label='Trigger Onset')
        plt.title(f"{trigger} Channel {i}")
        plt.xlabel('Z Score Times')
        plt.ylabel('Sheet Index')

    plt.tight_layout()
    plt.show()


# Call the function for each file and specify the trigger
TTL_correction_plot('Top', 'Input2_TTL.xlsx')
TTL_correction_plot('Bottom', 'Input3_TTL.xlsx')
```

![Output heatmaps](https://i.stack.imgur.com/TzG9A.png)

What I have so far
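One possible cause worth ruling out (an assumption, since the question does not show the actual time values): `axvline` autoscales the x-axis when its x position lies outside the current limits. If the 'Z Score Times' do not span 0, `plt.axvline(x=0, ...)` stretches the axis out to 0, and the image, pinned to its `extent`, only fills part of the panel. The sketch below reproduces this with made-up times from 100 to 110. The helper `plot_channel_heatmap` is hypothetical: it draws the onset line only when 0 is in range, pins the x-limits to the extent, and takes the extent from a single sheet's time column, assuming all sheets share the same time base, instead of the concatenated `all_times`.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_channel_heatmap(ax, trials, times):
    """Hypothetical helper: draw a (n_trials, n_samples) array with time on X.

    `times` is one sheet's 'Z Score Times' column, assuming every sheet
    shares the same time base."""
    t0, t1 = float(times[0]), float(times[-1])
    im = ax.imshow(trials, aspect='auto', cmap='viridis',
                   extent=[t0, t1, 0, trials.shape[0]])
    if t0 <= 0 <= t1:
        # Only mark the trigger onset when it lies inside the plotted range
        ax.axvline(x=0, color='black', linestyle='--', linewidth=2)
    ax.set_xlim(t0, t1)  # keep the view locked to the image extent
    return im


times = np.linspace(100.0, 110.0, 200)  # made-up times that never reach 0
trials = np.random.default_rng(0).normal(size=(5, 200))

fig, (ax_bad, ax_good) = plt.subplots(1, 2, figsize=(10, 4))

# Same pattern as the question: imshow with an extent, then axvline at 0
ax_bad.imshow(trials, aspect='auto', cmap='viridis', extent=[times[0], times[-1], 0, 5])
ax_bad.axvline(x=0, color='black', linestyle='--', linewidth=2)

plot_channel_heatmap(ax_good, trials, times)

print([float(v) for v in ax_bad.get_xlim()])   # lower limit is now at or below 0
print([float(v) for v in ax_good.get_xlim()])  # [100.0, 110.0]
```

In the question's loop, this would mean passing one sheet's `df['Z Score Times'].to_numpy()` instead of `all_times` and drawing into the axes returned by `plt.subplots(2, 2)`.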
