I'm developing an ML training server and client using the ultralytics package, YOLOv8 models, and gRPC. The server is written in Python and the client in C#. Everything works so far, and I can start the training process from the client side without problems. Now I want to report training progress back to the client, to visually update the user on how the training is going. For this I created this small proto file:
syntax = "proto3";
package training_client;
service TrainingService {
  rpc StartTraining(StartTrainingRequest) returns (stream InTrainingResponse) {}
}

message StartTrainingRequest {}

message Metric
{
  string name = 1;
  float value = 2;
}

message InTrainingResponse
{
  int32 epoch = 1;
  repeated Metric metrics = 2;
}
My basic server implementation looks like this:
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process, Queue
from queue import Empty

import grpc
from ultralytics import YOLO
from ultralytics.engine.trainer import BaseTrainer

import training_pb2_grpc
from training_pb2 import InTrainingResponse, Metric


def on_train_epoch_end(trainer: BaseTrainer):
    print("Putting values into queue")
    TrainingServicer.progress_queue.put((trainer.epoch, trainer.metrics))


def on_train_end(trainer: BaseTrainer):
    print("Training Finished")
    TrainingServicer.progress_queue.put(None)  # finished training, break infinite loop


class TrainingServicer(training_pb2_grpc.TrainingServiceServicer):
    progress_queue = Queue()

    def __init__(self):
        super().__init__()
        self.model = YOLO("yolov8m.pt")
        self.model.add_callback("on_train_epoch_end", on_train_epoch_end)
        self.model.add_callback("on_train_end", on_train_end)

    def run_training(self):
        self.model.train(data="dataset/data.yaml", epochs=15, imgsz=512, batch=2, device=0)

    def StartTraining(self, request, context):
        training_process = Process(target=self.run_training)
        training_process.start()
        while True:
            try:
                item = TrainingServicer.progress_queue.get(timeout=1)
                if item is None:
                    break
                epoch, metrics = item
                resp = InTrainingResponse(epoch=epoch)
                for k, v in metrics.items():
                    resp.metrics.append(Metric(name=k, value=v))
                print("Yielding training update")
                yield resp
            except Empty:
                print("Queue is empty or no new data available")
                continue


def serve():
    server = grpc.server(ThreadPoolExecutor(max_workers=10))
    training_pb2_grpc.add_TrainingServiceServicer_to_server(TrainingServicer(), server)
    server.add_secure_port('[::]:30008', grpc.local_server_credentials())
    print("Start Server")
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()
And my basic client like this:
using Grpc.Core;
using Grpc.Net.Client;
using TrainingClient;

using var channel = GrpcChannel.ForAddress("http://localhost:30008");
var client = new TrainingService.TrainingServiceClient(channel);

using var call = client.StartTraining(new StartTrainingRequest());
await foreach (var response in call.ResponseStream.ReadAllAsync())
{
    Console.WriteLine($"Received Epoch {response.Epoch} with {response.Metrics}");
}
My problem now is that the callbacks successfully put items into the queue, but the consumer side of the queue never receives them. This suggests to me that the two queues are different instances, which I quickly confirmed by checking their memory addresses. The call to model.train() is blocking, which is why I tried to run it in a separate process, so that I could yield each result back to the RPC. My understanding of multiprocessing.Queue is that it provides a shared, process-safe channel that lets different processes exchange data, but I don't seem to be using it correctly.
As explained above: I tried to use multiprocessing in a producer/consumer pattern to report training metrics back to the client mid-training.
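For reference, here is a minimal, standalone sketch (deliberately unrelated to the training code above) of how I understand multiprocessing.Queue is supposed to be shared between processes: the queue is created in the parent and passed to the child explicitly as an argument, so both ends refer to the same underlying channel. This pattern works for me in isolation; it is my servicer setup above where it breaks down.

```python
from multiprocessing import Process, Queue


def producer(q: Queue):
    # Runs in the child process; puts items into the queue it was given.
    for i in range(3):
        q.put(i)
    q.put(None)  # sentinel so the consumer knows we are done


def consume(q: Queue):
    # Runs in the parent; drains the queue until the sentinel arrives.
    items = []
    while True:
        item = q.get()
        if item is None:
            break
        items.append(item)
    return items


if __name__ == "__main__":
    q = Queue()                              # created in the parent
    p = Process(target=producer, args=(q,))  # passed explicitly to the child
    p.start()
    print(consume(q))                        # -> [0, 1, 2]
    p.join()
```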