"use client";

import React, { forwardRef, useImperativeHandle, useRef, useState, useEffect } from 'react';

import {
  Box,
  Typography,
  Button,
  Slider,
  IconButton,
  Tooltip,
  useTheme,
  Select,
  MenuItem,
  FormControl,
} from '@mui/material';
import {
  PlayArrow,
  Stop,
  Pause,
  ZoomIn,
  ZoomOut,
  Save,
  CloudDownload,
  VerticalAlignBottom,
  VerticalAlignTop,
  CenterFocusStrong,
} from '@mui/icons-material';
import * as Tone from 'tone';
import { Midi } from '@tonejs/midi';
import { getTransport } from 'tone';

// Pitch-class names in chromatic order, indexed by (MIDI note number % 12); index 0 is C.
const NOTES = 'C C# D D# E F F# G G# A A# B'.split(' ');

// Scale degrees as semitone offsets from the tonic, accumulated from each scale's
// whole (2) / half (1) step pattern: major = W W H W W W, natural minor = W H W W H W.
const MAJOR_SCALE = [2, 2, 1, 2, 2, 2].reduce((offsets, step) => [...offsets, offsets[offsets.length - 1] + step], [0]);
const MINOR_SCALE = [2, 1, 2, 2, 1, 2].reduce((offsets, step) => [...offsets, offsets[offsets.length - 1] + step], [0]);

const MidiPlayer = forwardRef(({ downloadUrl, initialBpm }, ref) => {
  const theme = useTheme();

  const [midiNotes, setMidiNotes] = useState([]);
  const [isPlaying, setIsPlaying] = useState(false);
  const [isPaused, setIsPaused] = useState(false);
  const [selectedNote, setSelectedNote] = useState(null);
  const [tempo, setTempo] = useState(initialBpm || 120);
  const [lowestNote, setLowestNote] = useState(21); // A0
  const [highestNote, setHighestNote] = useState(108); // C8
  const [playbackPosition, setPlaybackPosition] = useState(0);
  const [timeSignature, setTimeSignature] = useState({ numerator: 4, denominator: 4 });

  const synth = useRef(null);
  const canvasRef = useRef(null);
  const containerRef = useRef(null);
  const scheduledPart = useRef(null);
  const pausedTime = useRef(0);
  const animationFrameId = useRef(null);

  useImperativeHandle(ref, () => ({
    stop: stopMidi,
  }));

  useEffect(() => {
    synth.current = new Tone.PolySynth(Tone.Synth).toDestination();
    synth.current.volume.value = -14;
    return () => {
      if (synth.current) {
        synth.current.dispose();
      }
      if (scheduledPart.current) {
        scheduledPart.current.dispose();
      }
    };
  }, []);

  const parseMidi = async (url) => {
    try {
      const response = await fetch(url);
      const arrayBuffer = await response.arrayBuffer();
      const midi = new Midi(arrayBuffer);

      // Extract the time signature
      const timeSignatureRaw = midi.header.timeSignatures[0];
      let timeSignatureValidated = {
        numerator: 4,
        denominator: 4,
      };
      if (timeSignatureRaw) {
        timeSignatureValidated = {
          numerator: timeSignatureRaw.timeSignature[0],
          denominator: timeSignatureRaw.timeSignature[1],
        };
      }
      setTimeSignature(timeSignatureValidated);

      const rawNotes = midi.tracks.flatMap((track) =>
        track.notes.map((note) => ({
          midi: note.midi,
          time: note.time,
          duration: note.duration,
          disabled: false,
        }))
      );

      const midiNumbers = rawNotes.map(note => note.midi);
      const newLowestNote = Math.max(21, Math.min(...midiNumbers));
      const newHighestNote = Math.min(108, Math.max(...midiNumbers));

      setLowestNote(newLowestNote);
      setHighestNote(newHighestNote);
      setMidiNotes(rawNotes);
    } catch (error) {
      console.error("Failed to parse MIDI:", error);
    }
  };

  useEffect(() => {
    if (downloadUrl) {
      parseMidi(downloadUrl);
    }
  }, [downloadUrl, initialBpm]);

  useEffect(() => {
    drawMidiNotes();
  }, [midiNotes, tempo, lowestNote, highestNote, playbackPosition]);

  useEffect(() => {
    if (isPlaying) {
      const updatePlaybackPosition = () => {
      const transport = getTransport();
      const endTime = Math.max(...midiNotes.map(note => note.time + note.duration));
      if (transport.seconds >= endTime) {
        stopMidi();
      } else {
        setPlaybackPosition(transport.seconds);
        animationFrameId.current = requestAnimationFrame(updatePlaybackPosition);
      }
    };
      animationFrameId.current = requestAnimationFrame(updatePlaybackPosition);
    } else {
      if (isPaused) {
        setPlaybackPosition(pausedTime.current);
      } else {
        setPlaybackPosition(0);
      }
      cancelAnimationFrame(animationFrameId.current);
    }

    return () => cancelAnimationFrame(animationFrameId.current);
  }, [isPlaying, isPaused]);

  const drawMidiNotes = () => {
    const canvas = canvasRef.current;
    if (!canvas) return;

    const ctx = canvas.getContext('2d');
    if (!ctx) return;

    const containerWidth = containerRef.current.clientWidth;
    const beatsPerBar = 4;
    const beatDuration = 60 / tempo;
    const barDuration = beatsPerBar * beatDuration;
    const totalBars = Math.ceil(Math.max(...midiNotes.map(note => note.time + note.duration)) / barDuration);
    const canvasWidth = containerWidth;
    const noteRange = highestNote - lowestNote + 1;

    canvas.width = canvasWidth;
    canvas.height = Math.min(containerRef.current.clientHeight, noteRange * 10);

    ctx.clearRect(0, 0, canvas.width, canvas.height);

    for (let i = 0; i <= totalBars; i++) {
      const x = (i * barDuration * canvasWidth) / (totalBars * barDuration);
      ctx.strokeStyle = '#444';
      ctx.lineWidth = 2;
      ctx.beginPath();
      ctx.moveTo(x, 0);
      ctx.lineTo(x, canvas.height);
      ctx.stroke();
    }

    for (let i = 0; i <= totalBars * 8; i++) {
      const x = (i * (barDuration / 8) * canvasWidth) / (totalBars * barDuration);
      ctx.strokeStyle = 'rgba(102, 102, 102, 0.2)';
      ctx.lineWidth = 1;
      ctx.beginPath();
      ctx.moveTo(x, 0);
      ctx.lineTo(x, canvas.height);
      ctx.stroke();
    }

    for (let i = 0; i < noteRange; i++) {
      const y = (i * canvas.height) / noteRange;
      const midiNumber = highestNote - i;
      const isBlackKey = ['C#', 'D#', 'F#', 'G#', 'A#'].includes(Tone.Frequency(midiNumber, 'midi').toNote().replace(/[0-9]/g, ''));
      ctx.fillStyle = isBlackKey ? 'rgba(51, 51, 51, 0.1)' : 'rgba(255, 255, 255, 0.05)';
      ctx.fillRect(0, y, canvas.width, canvas.height / noteRange);

      ctx.strokeStyle = 'rgba(102, 102, 102, 0.2)';
      ctx.lineWidth = 0.5;
      ctx.beginPath();
      ctx.moveTo(0, y);
      ctx.lineTo(canvas.width, y);
      ctx.stroke();
    }

    midiNotes.forEach((note, index) => {
      const x = (note.time * canvasWidth) / (totalBars * barDuration);
      const y = ((highestNote - note.midi) / noteRange) * canvas.height;
      const width = (note.duration * canvasWidth) / (totalBars * barDuration);
      const height = canvas.height / noteRange;

      ctx.fillStyle = note.disabled ? 'rgba(128, 128, 128, 0.5)' : 'rgba(63, 81, 181, 0.8)';
      ctx.fillRect(x, y, width, height);

      ctx.strokeStyle = 'black';
      ctx.lineWidth = 1;
      ctx.strokeRect(x, y, width, height);
    });

    const playbackX = (playbackPosition * canvasWidth) / (totalBars * barDuration);
    ctx.strokeStyle = 'red';
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.moveTo(playbackX, 0);
    ctx.lineTo(playbackX, canvas.height);
    ctx.stroke();
  };

  const handleCanvasClick = (event) => {
    const canvas = canvasRef.current;
    if (!canvas) return;

    const rect = canvas.getBoundingClientRect();
    const containerWidth = containerRef.current.clientWidth;
    const beatsPerBar = 4;
    const beatDuration = 60 / tempo;
    const barDuration = beatsPerBar * beatDuration;
    const totalBars = Math.ceil(Math.max(...midiNotes.map(note => note.time + note.duration)) / barDuration);
    const x = event.clientX - rect.left;
    const y = event.clientY - rect.top;

    const noteRange = highestNote - lowestNote + 1;
    const clickedNoteIndex = midiNotes.findIndex((note) => {
      const noteX = (note.time * containerWidth) / (totalBars * barDuration);
      const noteY = ((highestNote - note.midi) / noteRange) * canvas.height;
      const noteWidth = (note.duration * containerWidth) / (totalBars * barDuration);
      const noteHeight = canvas.height / noteRange;

      return (
        x >= noteX &&
        x <= noteX + noteWidth &&
        y >= noteY &&
        y <= noteY + noteHeight
      );
    });

    if (clickedNoteIndex !== -1) {
      setSelectedNote(clickedNoteIndex);
    } else {
      setSelectedNote(null);
    }
  };

  const playMidi = async () => {
    if (isPlaying) return;
    setIsPlaying(true);
    setIsPaused(false);
    const transport = getTransport();
    transport.bpm.value = tempo;
    transport.seconds = pausedTime.current;  // Ensure transport starts at correct position


    await Tone.start();

    
    transport.seconds = pausedTime.current;
    scheduledPart.current = new Tone.Part((time, note) => {
      const safeDuration = Math.max(0.01, note.duration);
      synth.current.triggerAttackRelease(
        Tone.Frequency(note.midi, 'midi').toFrequency(),
        safeDuration,
        time
      );
    }, midiNotes.filter(note => !note.disabled).map(note => ({
      time: note.time,
      midi: note.midi,
      duration: note.duration,
    })));

    scheduledPart.current.start(0);
    transport.bpm.value = tempo;
    transport.start();
  };

  const pauseMidi = () => {
    setIsPlaying(false);
    setIsPaused(true);
    pausedTime.current = getTransport().seconds;

    if (scheduledPart.current) {
      scheduledPart.current.stop();
      //console.log("Scheduled playback paused.");
    }

    getTransport().pause();
    //console.log(`Playback paused at ${pausedTime.current.toFixed(2)}s`);
  };

  const stopMidi = () => {
    setIsPlaying(false);
    setIsPaused(false);
    setPlaybackPosition(0);
    pausedTime.current = 0;

    if (synth.current) {
      synth.current.releaseAll(); // Release all notes immediately
    }

    if (scheduledPart.current) {
      scheduledPart.current.stop();
      scheduledPart.current.dispose();
      scheduledPart.current = null;
      //console.log("Scheduled playback stopped and disposed.");
    }

    const transport = getTransport();
    transport.stop();
    transport.seconds = 0;
    //console.log("Playback stopped and transport reset.");
  };

  const saveMidi = () => {
    try {
      const midi = new Midi();
      midi.header.setTempo(tempo);
      midi.header.timeSignatures.push({
        ticks: 0,
        timeSignature: [timeSignature.numerator, timeSignature.denominator],
        measures: 0,
      });

      const track = midi.addTrack();
      midiNotes.forEach(note => {
        if (!note.disabled) {
          track.addNote({
            midi: note.midi,
            time: note.time,
            duration: note.duration,
          });
        }
      });

      const midiArray = midi.toArray();
      const blob = new Blob([midiArray], { type: 'audio/midi' });
      const url = URL.createObjectURL(blob);
      const a = document.createElement('a');
      a.href = url;
      a.download = 'exported_midi.mid';
      a.click();
      URL.revokeObjectURL(url);
    } catch (error) {
      console.error("Failed to save MIDI:", error);
    }
  };

  

  return (
    <Box sx={{ mt: 0, bgcolor: 'background.paper', color: 'text.primary', width: '100%' }}>
      
      <Box
        ref={containerRef}
        sx={{
          width: '100%',
          height: '250px',
          position: 'relative',
          backgroundColor: '#1e1e1e',
        }}
      >
        <canvas
          ref={canvasRef}
          onClick={handleCanvasClick}
          style={{ display: 'block' }}
        />
      </Box>

      <Box sx={{ display: 'flex', alignItems: 'center', gap: 0, mb: 1 ,mt: 1}}>
        <IconButton onClick={playMidi} color="primary" disabled={isPlaying && !isPaused}>
          <PlayArrow />
        </IconButton>
        <IconButton onClick={pauseMidi} color="primary" disabled={!isPlaying || isPaused}>
          <Pause />
        </IconButton>
        <IconButton onClick={stopMidi} sx={{ color: 'darkgrey' }}>
          <Stop />
        </IconButton>
        <Tooltip title="Save MIDI">
          <IconButton onClick={saveMidi} color="primary">
            <Save />
          </IconButton>
        </Tooltip>
        <Typography variant="body1" sx={{ ml: 2 }}>
          Playback Position: {playbackPosition.toFixed(2)}s
        </Typography>
      </Box>
    </Box>
  );
});

export default MidiPlayer;