Source code for src.gridmind.utils.logtools.async_tensorboard_logger

import logging
import threading
from queue import Queue

from torch.utils.tensorboard import SummaryWriter


class AsyncTensorboardLogger:
    """Log TensorBoard scalars from the caller's thread without blocking on I/O.

    ``add_scalar`` only enqueues the record. A single background daemon thread
    dequeues records in order and forwards them to a ``SummaryWriter``. Call
    ``close()`` (or use the instance as a context manager) to drain the queue,
    stop the worker and close the writer. Because the worker is a daemon thread,
    records still queued when the interpreter exits without ``close()`` may be
    lost.
    """

    def __init__(self, log_dir=None, flush_secs=10):
        """Create the writer and start the background worker thread.

        Args:
            log_dir: Directory passed through to ``SummaryWriter``.
            flush_secs: Flush interval passed through to ``SummaryWriter``.
        """
        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)
        self.log_queue = Queue()
        # Unique sentinel object: identity comparison cannot collide with a real record.
        self._stop_signal = object()
        self.thread = threading.Thread(target=self._log_worker, daemon=True)
        self.thread.start()

    def _log_worker(self):
        """Consume queued records and write them until the stop sentinel arrives."""
        while True:
            item = self.log_queue.get()
            try:
                if item is self._stop_signal:
                    return
                tag, value, step, walltime = item
                try:
                    self.writer.add_scalar(tag, value, step, walltime=walltime)
                except Exception:
                    # Keep the worker alive: if this thread died, every later
                    # scalar would silently accumulate in the queue unwritten.
                    logging.getLogger(__name__).exception(
                        "Failed to write scalar %r at step %r", tag, step
                    )
            finally:
                # Mark every item done, including the sentinel, so that
                # log_queue.join() (used by flush) never blocks forever.
                self.log_queue.task_done()

    def add_scalar(self, tag, value, global_step=None, walltime=None):
        """Queue one scalar for asynchronous writing.

        Args:
            tag: Scalar tag name.
            value: Scalar value.
            global_step: Step associated with the value (optional).
            walltime: Optional event timestamp, forwarded to the writer.
        """
        self.log_queue.put((tag, value, global_step, walltime))

    def flush(self):
        """Block until every queued scalar has been handed to the writer, then flush it."""
        self.log_queue.join()
        self.writer.flush()

    def close(self):
        """Drain pending scalars, stop the worker thread and close the writer."""
        # The sentinel is queued after all pending records, so they are written first.
        self.log_queue.put(self._stop_signal)
        self.thread.join()
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning None lets any in-flight exception propagate.
        self.close()