Source code for src.gridmind.utils.vis_util

from typing import Dict, Hashable, Optional
from matplotlib import pyplot as plt
from tabulate import tabulate
import numpy as np
import os
import glob
import logging





def plot_state_values(
    states,
    true_values,
    estimated_values,
    *,
    save_path: Optional[str] = "True vs Estimated Values of States.png",
    show: bool = True,
):
    """
    Plot the true value of each state against one or more estimates.

    The true values are drawn once as a dashed black line. Every entry of
    ``estimated_values`` is drawn as its own solid line labelled
    ``"Iteration {i}"`` (1-based), so successive estimates can be compared
    against the ground truth.

    :param states: List of state names (e.g., ['A', 'B', 'C', ...]) used as
        the x-axis positions.
    :param true_values: List of true values, one per entry of ``states``.
    :param estimated_values: Iterable of estimate sequences (e.g. a list of
        lists, one per iteration); each must have one value per state.
    :param save_path: File the figure is written to. Defaults to
        ``"True vs Estimated Values of States.png"`` in the current working
        directory (the historical behavior); pass ``None`` to skip saving.
    :param show: If True (default), display the figure with ``plt.show()``.
        If False, the figure is closed after saving so repeated calls in
        headless code do not accumulate open pyplot figures.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # The true values are the same for every iteration, so plot them once.
    ax.plot(
        states,
        true_values,
        label="True Values",
        color="black",
        linestyle="--",
        marker="o",
    )

    # One line per iteration of estimates; labels are 1-based.
    for iteration, estimate in enumerate(estimated_values, start=1):
        ax.plot(
            states,
            estimate,
            label=f"Iteration {iteration}",
            linestyle="-",
            marker="x",
        )

    ax.set_xlabel("States")
    ax.set_ylabel("Values")
    ax.set_title("True vs Estimated Values of States")
    ax.legend()
    ax.grid(True)

    # Lay out before saving so axis labels are not clipped in the file.
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)

    if show:
        plt.show()
    else:
        # Nothing will display this figure; release it from pyplot's registry.
        plt.close(fig)
class VideoUtil:
    """Utility class for video loading and processing operations."""

    @staticmethod
    def _find_latest_video(video_save_path: str) -> Optional[str]:
        """
        Return the newest ``.mp4`` file whose name starts with a given prefix.

        Video writers such as RecordVideo create files named
        ``{name_prefix}-episode-{id}.mp4``, so this matches
        ``{video_save_path}*.mp4``. The directory and prefix are passed
        through ``glob.escape`` so that glob metacharacters in them
        (``[``, ``]``, ``*``, ``?``) are matched literally.

        Args:
            video_save_path: Base path for the video files (directory plus
                file-name prefix, without extension).

        Returns:
            Path of the matching file with the largest ``os.path.getctime``
            (creation time on Windows, last metadata change on POSIX), or
            None if no file matches.
        """
        video_dir = os.path.dirname(video_save_path)
        video_prefix = os.path.basename(video_save_path)
        pattern = os.path.join(
            glob.escape(video_dir), f"{glob.escape(video_prefix)}*.mp4"
        )
        video_files = glob.glob(pattern)
        if not video_files:
            return None
        return max(video_files, key=os.path.getctime)

    @staticmethod
    def load_video_as_tensor(
        video_save_path: str, logger: Optional[logging.Logger] = None
    ):
        """
        Load the most recent video file and convert it to TensorBoard format.

        Args:
            video_save_path: Base path for the video files (without extension);
                every ``{video_save_path}*.mp4`` file is a candidate.
            logger: Optional logger for logging messages. Defaults to this
                module's logger.

        Returns:
            torch.Tensor: Video tensor in format (N, T, C, H, W) where:
                N = batch size (1)
                T = number of frames
                C = channels
                H = height
                W = width
            Returns None if torchvision is unavailable, no file matches, the
            file contains no frames, or reading it fails.
        """
        import gc

        if logger is None:
            logger = logging.getLogger(__name__)

        try:
            # torchvision.io also requires torch, so one import covers both.
            from torchvision.io import read_video
        except ImportError:
            logger.warning(
                "torch or torchvision not installed. Cannot load video for TensorBoard. "
                "Install with: pip install torch torchvision"
            )
            return None

        latest_video = VideoUtil._find_latest_video(video_save_path)
        if latest_video is None:
            logger.warning(
                f"No video files found matching pattern: {video_save_path}*.mp4"
            )
            return None

        logger.info(f"Loading video from: {latest_video}")

        try:
            # read_video returns (video_tensor, audio_tensor, info);
            # video_tensor shape: (T, H, W, C). Audio and info are unused.
            video_tensor, audio_tensor, _info = read_video(
                latest_video, pts_unit="sec"
            )
            del audio_tensor

            # read_video can return an empty frame tensor instead of raising
            # when decoding fails (e.g. a truncated file); such a tensor is
            # not a usable video.
            if video_tensor.shape[0] == 0:
                logger.warning(f"Video {latest_video} contains no frames")
                return None

            # (T, H, W, C) -> (T, C, H, W) -> (1, T, C, H, W). permute()
            # returns a view, so no frame data is copied here.
            video_frames = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)
            logger.info(f"Loaded video with shape: {video_frames.shape}")
            return video_frames
        except Exception as e:
            # Best-effort loader: callers treat None as "no video to log".
            logger.error(f"Error loading video {latest_video}: {str(e)}")
            return None
        finally:
            # Decoding can create large temporaries; collect them eagerly.
            gc.collect()
if __name__ == "__main__":
    # Example usage: call print_value_table with two index-aligned feature
    # columns and one state value per (feature1, feature2) row.
    feature1 = [0, 0, 1, 1]
    feature2 = [0, 1, 0, 1]
    state_values = [1.0, 0.5, 0.8, 0.2]
    axis_names = {"feature1_name": "X-axis", "feature2_name": "Y-axis"}
    print_value_table(feature1, feature2, state_values, **axis_names)