#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Blender Studio Tools Authors
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""
Blender script to extract render statistics from video strips.
This script is meant to be executed within Blender using the --python flag,
or called from render_stats.py.
"""

import bpy
import csv
import json
import sys
from pathlib import Path
from statistics import median
import OpenImageIO as oiio


# Sequencer channel whose IMAGE strips main() collects statistics for.
CHANNEL_TO_PROCESS = 22  # The "Comp 2" channel
# TODO: Ask Blender devs to support channel lookup by name


def time_str_to_time_seconds_int(time_str):
    """Convert a colon-separated duration string to seconds.

    Accepts "SS.ss", "MM:SS.ss" and "HH:MM:SS.ss": each field is multiplied by
    60 relative to the field that follows it, so durations that carry an hours
    field are no longer misread as minutes. Despite the name, the result is a
    float so that fractional seconds are preserved.

    Raises:
        ValueError: if a field is not a valid number.
    """
    seconds = 0.0
    for field in time_str.split(':'):
        seconds = seconds * 60 + float(field)
    return seconds


def get_frame_metadata(frame_path):
    """Extract render metadata from a frame using OpenImageIO.

    Args:
        frame_path: path (str or Path) of the rendered frame.

    Returns:
        A dict holding whichever of the keys "render_time", "samples" and
        "scene" were found in the image header (it may be empty), or None when
        the file does not exist, cannot be opened, or reading it fails.
    """
    frame_path_str = str(frame_path)

    if not Path(frame_path_str).exists():
        return None

    # Image header attribute name -> key in the returned dict.
    # NOTE(review): only attributes for the layer literally named "ViewLayer"
    # are read; other layer names are ignored.
    wanted_attribs = {
        "cycles.ViewLayer.total_time": "render_time",
        "cycles.ViewLayer.samples": "samples",
        "Scene": "scene",
    }

    img_input = None
    try:
        img_input = oiio.ImageInput.open(frame_path_str)
        if img_input is None:
            print(f"Could not open {frame_path_str}: {oiio.geterror()}", file=sys.stderr)
            return None

        spec = img_input.spec()
        metadata = {}
        for attr in spec.extra_attribs:
            key = wanted_attribs.get(attr.name)
            if key is not None:
                metadata[key] = attr.value
        return metadata
    except Exception as e:
        # Best effort: a single unreadable frame must not abort the whole run.
        print(f"Error reading metadata from {frame_path_str}: {e}", file=sys.stderr)
        return None
    finally:
        # Release the file handle on every path, including the error path.
        if img_input is not None:
            img_input.close()


def get_strip_frame_paths(strip):
    """Return the on-disk paths of an image strip's frames.

    The strip's ``directory`` is made absolute with ``bpy.path.abspath`` and
    joined with each element's ``filename``. Frames whose file does not exist
    are left out. Objects lacking ``directory`` or ``elements`` yield an empty
    list.
    """
    has_image_attrs = hasattr(strip, 'directory') and hasattr(strip, 'elements')
    if not has_image_attrs:
        return []

    base_dir = Path(bpy.path.abspath(strip.directory))
    candidates = (base_dir / element.filename for element in strip.elements)
    return [path for path in candidates if path.exists()]


def write_stat_files(results):
    """Save render statistics next to the current blend file, as JSON and CSV.

    The output files are ``<blend name>-render_stats.json`` and
    ``<blend name>-render_stats.csv``, written in the blend file's directory.

    The JSON file is a verbatim dump of ``results``. The CSV file has one row
    per shot and applies an arbitrary normalization to ease analysis: shots
    rendered at 10 samples are reported as if rendered at 100 samples, with
    their timings multiplied by 10 (a linear estimate).

    Args:
        results: list of per-shot dicts containing at least the keys
            ``shot_name``, ``frames_count``, ``samples``,
            ``frame_median_time_sec`` and ``total_time_sec`` (the timings may
            be None). The dicts are not modified.
    """
    preview_samples = 10
    estimated_samples = 100
    scale = estimated_samples // preview_samples

    # Get the current blend file path and derive the output names from it
    blend_file_path = Path(bpy.data.filepath)
    blend_filename = blend_file_path.stem
    output_json = blend_file_path.parent / f"{blend_filename}-render_stats.json"
    output_csv = blend_file_path.parent / f"{blend_filename}-render_stats.csv"

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    with open(output_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['shot_name', 'frames_count', 'samples', 'frame_median_time_sec', 'total_time_sec', 'total_time_hours'])
        for result in results:
            # Work on local values so the caller's result dicts stay untouched.
            samples = result['samples']
            median_time = result['frame_median_time_sec']
            total_time = result['total_time_sec']

            # Estimate a 100-sample render from a 10-sample one.
            if samples == preview_samples:
                samples = estimated_samples
                # Timings are None when no frame carried a render time.
                if median_time is not None:
                    median_time *= scale
                if total_time is not None:
                    total_time *= scale

            writer.writerow([
                result['shot_name'],
                result['frames_count'],
                samples,
                median_time,
                total_time,
                total_time / 3600 if total_time is not None else None,
            ])

    print(f"Results saved to: {output_json}")
    print(f"Results saved to: {output_csv}")
    print(f"Processed {len(results)} strips")


def _read_frames_metadata(frame_paths):
    """Read each frame's metadata; return (frames_data, shot_name, samples_value).

    ``shot_name`` and ``samples_value`` come from the first frame that carries
    them, or are None when no frame does. Frames without metadata are skipped.
    """
    frames_data = []
    shot_name = None
    samples_value = None

    for frame_path in frame_paths:
        metadata = get_frame_metadata(frame_path)
        if not metadata:
            continue

        frames_data.append({
            "frame_path": str(frame_path),
            "render_time": metadata.get("render_time"),
            "samples": metadata.get("samples"),
        })

        if shot_name is None and "scene" in metadata:
            shot_name = metadata["scene"]
        if samples_value is None and "samples" in metadata:
            samples_value = metadata["samples"]

    return frames_data, shot_name, samples_value


def _parse_render_times(frames_data):
    """Convert the frames' render-time strings to seconds, skipping missing or malformed ones."""
    render_times_sec = []
    for frame in frames_data:
        time_str = frame["render_time"]
        if time_str is None:
            continue
        try:
            render_times_sec.append(time_str_to_time_seconds_int(time_str))
        except (ValueError, IndexError, AttributeError):
            # One bad stamp should not abort the statistics for the whole edit.
            print(f"Warning: Unparseable render time {time_str!r} in {frame['frame_path']}", file=sys.stderr)
    return render_times_sec


def _build_strip_result(strip):
    """Return the statistics dict for one strip, or None when it yields no usable data."""
    # The render metadata lives in the lighting frames, so look there instead of
    # in the -comp directory. Restore the original directory afterwards so the
    # blend data is left unchanged.
    original_directory = strip.directory
    strip.directory = original_directory.replace('-comp/', '-lighting/')
    try:
        frame_paths = get_strip_frame_paths(strip)
    finally:
        strip.directory = original_directory

    if not frame_paths:
        print(f"No frames found for strip: {strip.name}", file=sys.stderr)
        return None

    frames_data, shot_name, samples_value = _read_frames_metadata(frame_paths)
    if not frames_data:
        print(f"No metadata found for strip: {strip.name}", file=sys.stderr)
        return None

    # If no shot_name was found in metadata, fall back to strip name
    if shot_name is None:
        shot_name = strip.name
        print(f"Warning: No Scene metadata found for strip {strip.name}, using strip name", file=sys.stderr)

    render_times_sec = _parse_render_times(frames_data)
    median_render_time = int(median(render_times_sec)) if render_times_sec else None

    return {
        "shot_name": shot_name,
        "frames": frames_data,
        "frames_count": len(frames_data),
        "samples": None if not samples_value else int(samples_value),
        "frame_median_time_sec": median_render_time,
        "total_time_sec": None if median_render_time is None else median_render_time * len(frames_data),
    }


def main():
    """Collect per-shot render statistics from the IMAGE strips in CHANNEL_TO_PROCESS.

    Exits with status 1 when the current scene has no sequence editor. Strips
    without frames or readable metadata are reported on stderr and skipped.
    The collected results are handed to ``write_stat_files``.
    """
    sequence_editor = bpy.context.scene.sequence_editor
    if not sequence_editor:
        print("No sequence editor found in the blend file", file=sys.stderr)
        sys.exit(1)

    # Find all image sequence strips in the desired channel
    filtered_strips = [
        s for s in sequence_editor.strips_all
        if s.channel == CHANNEL_TO_PROCESS and s.type == 'IMAGE'
    ]
    print(f"Processing {len(filtered_strips)} strips in Comp 2 channel")

    results = []
    for strip in filtered_strips:
        print(f"Processing strip: {strip.name}")
        result = _build_strip_result(strip)
        if result is not None:
            results.append(result)

    write_stat_files(results)


# Entry point when the file is run as a script (e.g. via Blender's --python flag).
if __name__ == "__main__":
    main()
