import * as posedetection from '@tensorflow-models/pose-detection';
import { Keypoint, Pose, PoseDetector } from '@tensorflow-models/pose-detection';
import { COLOR_PALETTE } from '../utilities';
import { MutableRefObject } from "react";
import Webcam from "react-webcam";
import { merge, Observable, Subject } from 'rxjs';
import { ExerciseResult } from '../models';
import { MAX_VIDEO_WIDTH } from './Constant'

/**
 * Coarse position classification for one detected pose, as built by
 * `PosePredictor.predictPose`.
 */
interface PoseInfo {
  /** Position label (e.g. 'standing', 'head left'); '' when no rule matched. */
  name?: string;
  /** True when face, torso and legs are all registered. */
  fullBodyRegistered?: boolean;
  /** True when every facial keypoint met the score threshold. */
  faceRegistered?: boolean;
  /** True when every torso keypoint met the score threshold. */
  torsoRegistered?: boolean;
  /** True when every leg keypoint met the score threshold. */
  legsRegistered?: boolean;
}

/**
 * Named keypoints of a single pose plus vertical landmarks derived from them,
 * as returned by `PosePredictor.data`.
 *
 * Every derived number below is computed from keypoint `y` coordinates only,
 * in the same coordinate space the detector reported.
 */
export interface PoseData {
  // --- Head ---
  nosePoint: Keypoint;
  leftEye: Keypoint;
  rightEye: Keypoint;
  leftEar: Keypoint;
  rightEar: Keypoint;

  // --- Upper body ---
  leftShoulder: Keypoint;
  rightShoulder: Keypoint;
  leftElbow: Keypoint;
  rightElbow: Keypoint;
  leftWrist: Keypoint;
  rightWrist: Keypoint;

  // --- Lower body ---
  leftHip: Keypoint;
  rightHip: Keypoint;
  leftKnee: Keypoint;
  rightKnee: Keypoint;
  leftAnkle: Keypoint;
  rightAnkle: Keypoint;

  /** Mean y of the left shoulder and left hip. */
  midTorsoL: number;
  /** Mean y of the right shoulder and right hip. */
  midTorsoR: number;
  /** Mean of `midTorsoL` and `midTorsoR`. */
  midTorsoAvg: number;

  /** Mean y of the left hip and left knee. */
  midFemurL: number;
  /** Mean y of the right hip and right knee. */
  midFemurR: number;
  /** Mean of `midFemurL` and `midFemurR`. */
  midFemurAvg: number;
  /** Vertical extent `leftKnee.y - leftHip.y`. */
  femurLengthL: number;
  /** Vertical extent `rightKnee.y - rightHip.y`. */
  femurLengthR: number;

  /** Mean y of the left knee and left ankle. */
  midShinL: number;
  /** Mean y of the right knee and right ankle. */
  midShinR: number;
  /** Mean of `midShinL` and `midShinR`. */
  midShinAvg: number;
}

/**
 * Predicts pose snapshots from a webcam feed.
 *
 * Pulls frames from a `react-webcam` element, runs them through a
 * `@tensorflow-models/pose-detection` detector, draws keypoints, skeletons and
 * guidelines onto a canvas, classifies coarse body positions, and can record
 * the webcam stream with `MediaRecorder`.
 */
export class PosePredictor {

  // props
  webcam: MutableRefObject<Webcam>;
  canvas: MutableRefObject<HTMLCanvasElement>;
  /** Video height divided by video width; stays 0 until the video reports its size. */
  reciprocal = 0;
  recorder!: MediaRecorder;
  /** Most recent finished recording (set by `sendRecording`). */
  recording!: Blob;
  /** Data chunks collected from the recorder while recording. */
  chunks: Blob[];
  /** Emits `true` each time a new `recording` is ready. */
  registerMedia: Subject<boolean>;
  registerResult: Subject<boolean>;
  /** Emits `true` once `reciprocal` has been computed. */
  ratioReady: Subject<boolean>;
  /** `registerMedia` and `registerResult` merged into a single stream. */
  sendResult: Observable<any>;
  result!: any;
  ctx: CanvasRenderingContext2D;
  detector: posedetection.PoseDetector;
  /** Reference pose captured by `displace()`; guidelines are drawn from it. */
  displacement: posedetection.Pose | null = null;
  /** Position classification of `displacement`, captured by `displace()`. */
  pose: PoseInfo = {};
  config: {
    model: any;
    maxPoses: number;
    type: 'lightning' | 'thunder';
    scoreThreshold: number;
    customModel: any;
    enableTracking: boolean;
    lineWidth: number;
    keypointRadius: number;
  }

  /**
   * Config of the most recently constructed predictor. The static `data()`
   * helper reads `cStatic.model` to choose keypoint names, so this is shared
   * by all instances.
   */
  static cStatic: any;

  /** Arm keypoint names per model: [leftElbow, rightElbow, leftWrist, rightWrist]. */
  static armKeypoints: any = {
    "MoveNet" : ['left_elbow', 'right_elbow', 'left_wrist', 'right_wrist'],
    "PoseNet" : ['leftElbow', 'rightElbow', 'leftWrist', 'rightWrist'],
  }

  /** Facial keypoint names per model: [nose, leftEye, rightEye, leftEar, rightEar]. */
  static facialKeypoints: any = {
    "MoveNet" : ['nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear'],
    "PoseNet" : ['nose', 'leftEye', 'rightEye', 'leftEar', 'rightEar'],
  }

  /** Torso keypoint names per model: [leftShoulder, rightShoulder, leftHip, rightHip]. */
  static torsoKeypoints: any = {
    "MoveNet" : ['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip'],
    "PoseNet" : ['leftShoulder', 'rightShoulder', 'leftHip', 'rightHip'],
  }

  /** Leg keypoint names per model: [leftKnee, rightKnee, leftAnkle, rightAnkle]. */
  static legKeypoints: any = {
    "MoveNet" : ['left_knee', 'right_knee', 'left_ankle', 'right_ankle'],
    "PoseNet" : ['leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle'],
  }

  /**
   * @param webcam Ref to the mounted `react-webcam` component.
   * @param canvas Ref to the canvas that results are drawn on; must be mounted.
   * @param detector Pose detector used by `produceSnapshot`.
   * @param config Detector/rendering options; also stored in `PosePredictor.cStatic`.
   */
  constructor(
    webcam: MutableRefObject<Webcam | null>,
    canvas: MutableRefObject<HTMLCanvasElement | null>,
    detector: PoseDetector,
    config: any
  ) {
    this.registerMedia = new Subject<boolean>();
    this.registerResult = new Subject<boolean>();
    this.ratioReady = new Subject<boolean>();
    this.sendResult = merge(this.registerMedia, this.registerResult);
    this.webcam = webcam as MutableRefObject<Webcam>;
    this.canvas = canvas as MutableRefObject<HTMLCanvasElement>;
    this.ctx = (canvas.current as HTMLCanvasElement).getContext('2d') as CanvasRenderingContext2D;
    this.detector = detector;
    this.config = config;
    this.chunks = [];
    PosePredictor.cStatic = config;
    // NOTE(review): `ratioReady` is a plain Subject; if the ratio is already
    // available here it is emitted before any caller has had a chance to subscribe.
    this.getAspectReciprocal();
  }

  /**
   * Extract named keypoints and derived vertical landmarks from a pose.
   *
   * Keypoint names are chosen from the static tables using `cStatic.model`.
   *
   * @param pose A detector pose.
   * @returns The pose data, or `null` when the pose has no keypoints, the model
   *   is not in the name tables, or any required keypoint is missing.
   */
  static data(pose: Pose): PoseData | null {
    if (!(pose && pose.keypoints)) return null;

    const model = PosePredictor.cStatic?.model;
    // Look up the keypoint whose name sits at `index` in the model's name table.
    const pick = (table: Record<string, string[] | undefined>, index: number): Keypoint | undefined => {
      const name = table[model]?.[index];
      return name === undefined ? undefined : pose.keypoints.find(kp => kp.name === name);
    };

    const face = PosePredictor.facialKeypoints;
    const arm = PosePredictor.armKeypoints;
    const torso = PosePredictor.torsoKeypoints;
    const legs = PosePredictor.legKeypoints;

    const nosePoint = pick(face, 0);
    const leftEye = pick(face, 1);
    const rightEye = pick(face, 2);
    const leftEar = pick(face, 3);
    const rightEar = pick(face, 4);
    const leftElbow = pick(arm, 0);
    const rightElbow = pick(arm, 1);
    const leftWrist = pick(arm, 2);
    const rightWrist = pick(arm, 3);
    const leftShoulder = pick(torso, 0);
    const rightShoulder = pick(torso, 1);
    const leftHip = pick(torso, 2);
    const rightHip = pick(torso, 3);
    const leftKnee = pick(legs, 0);
    const rightKnee = pick(legs, 1);
    const leftAnkle = pick(legs, 2);
    const rightAnkle = pick(legs, 3);

    // The return type promises every keypoint, so refuse partial data
    // instead of crashing on `.y` of a missing keypoint below.
    if (
      !nosePoint || !leftEye || !rightEye || !leftEar || !rightEar ||
      !leftElbow || !rightElbow || !leftWrist || !rightWrist ||
      !leftShoulder || !rightShoulder || !leftHip || !rightHip ||
      !leftKnee || !rightKnee || !leftAnkle || !rightAnkle
    ) return null;

    const midTorsoL = (leftHip.y + leftShoulder.y) / 2;
    const midTorsoR = (rightHip.y + rightShoulder.y) / 2;
    const midFemurL = (leftKnee.y + leftHip.y) / 2;
    const midFemurR = (rightKnee.y + rightHip.y) / 2;
    const femurLengthL = leftKnee.y - leftHip.y;
    const femurLengthR = rightKnee.y - rightHip.y;
    const midShinL = (leftAnkle.y + leftKnee.y) / 2;
    const midShinR = (rightAnkle.y + rightKnee.y) / 2;

    const midTorsoAvg = (midTorsoL + midTorsoR) / 2;
    const midFemurAvg = (midFemurL + midFemurR) / 2;
    const midShinAvg = (midShinL + midShinR) / 2;

    return { nosePoint,
      leftEar, leftEye, leftShoulder, leftElbow, leftHip, leftWrist, leftKnee, leftAnkle,
      rightEar, rightEye, rightShoulder, rightElbow, rightHip, rightWrist, rightKnee, rightAnkle,
      midFemurAvg, midFemurL, midFemurR, femurLengthL, femurLengthR,
      midShinL, midShinR, midShinAvg,
      midTorsoL, midTorsoR, midTorsoAvg
    };
  }

  /**
   * `dataavailable` handler: append a non-empty recorded chunk to `chunks`.
   */
  captureChunks({ data }: BlobEvent) {
    // Skip empty blobs so they don't end up in the final recording.
    if (data.size) this.chunks.push(data);
  }

  /**
   * Compute `reciprocal` (video height / width) and emit on `ratioReady`.
   * Retries on the next animation frame until the webcam is mounted and the
   * video reports non-zero dimensions.
   */
  getAspectReciprocal() {
    const video = this.webcam.current?.video;
    if (video && video.videoWidth > 0 && video.videoHeight > 0) {
      this.reciprocal = video.videoHeight / video.videoWidth;
      this.ratioReady.next(true);
    } else {
      requestAnimationFrame(() => this.getAspectReciprocal());
    }
  }

  /**
   * Start recording the webcam stream as `video/webm`. When the recorder
   * stops, `sendRecording` is invoked.
   * @throws Error if the webcam has no media stream yet.
   */
  startRecording() {

    console.log("this.webcam", this.webcam.current);

    const stream = this.webcam.current?.stream;
    if (!stream) throw new Error('PosePredictor.startRecording: webcam stream is not available');

    this.recorder = new MediaRecorder(stream, { mimeType: "video/webm" });
    this.recorder.addEventListener("dataavailable", e => {
      this.captureChunks(e);
      console.log("dataavailable", e);
    });
    this.recorder.onstop = () => this.sendRecording();

    console.log("recorder started");
    this.recorder.start();
  }

  /**
   * Stop the active recording, if any; `sendRecording` runs from the
   * recorder's `onstop` handler. No-op when nothing is recording.
   */
  stopRecording() {
    if (this.recorder && this.recorder.state !== 'inactive') this.recorder.stop();
  }

  /**
   * Assemble collected chunks into `recording`, reset `chunks`, and notify
   * subscribers via `registerMedia`.
   */
  sendRecording() {
    // The recorder is created with mimeType "video/webm"; label the blob with
    // the recorder's actual type instead of claiming it is MP4.
    const type = this.recorder?.mimeType || "video/webm";
    this.recording = new Blob(this.chunks, { type });
    console.log("got new video recording: ", this.chunks, this.recording);
    this.chunks = [];
    this.registerMedia.next(true);
  }

  /**
   * Get pose keypoints from detector
   * @returns Detected poses, or `null` if the webcam video is missing or not
   *   yet fully loaded.
   */
  async produceSnapshot(): Promise<Pose[] | null> {
    const video = this.webcam.current?.video;
    // readyState 4 is HAVE_ENOUGH_DATA.
    if (!video || video.readyState !== 4) return null;
    return this.detector.estimatePoses(video, {
      maxPoses: this.config.maxPoses,
      flipHorizontal: true
    });
  }

  /**
   * Set a displacement snapshot for offsetting predictor keypoints.
   * Captures the first detected pose into `displacement` and its
   * classification into `pose`. Does nothing if the video is not ready; clears
   * both when no pose is detected.
   */
  async displace(): Promise<void> {
    const snapshot = await this.produceSnapshot();
    if (!snapshot) return;
    const positions = await this.predictPose(snapshot);
    this.displacement = snapshot[0] ?? null;
    this.pose = positions?.[0] ?? {};
  }

  /**
   * Clear the displacement snapshot
   */
  clearPose() {
    this.pose = {};
    this.displacement = null;
  }

  /** Draw the current webcam frame onto the canvas. */
  drawCtx() {
    const v = this.webcam.current.video as HTMLVideoElement;
    this.ctx.drawImage(v, 0, 0, v.videoWidth, v.videoHeight);
  }

  /** Clear the canvas area covered by the webcam video. */
  clearCtx() {
    const v = this.webcam.current.video as HTMLVideoElement;
    this.ctx.clearRect(0, 0, v.videoWidth, v.videoHeight);
  }

  /**
   * Draw a horizontal guideline with circular end caps on the video canvas.
   * @param x1 X position of the first coordinate
   * @param x2 X position of the second coordinate
   * @param y Y position
   */
  drawGuideline(x1: number, x2: number, y: number) {
    this.ctx.fillStyle = 'White';
    // this.ctx.strokeStyle = '#603489';
    this.ctx.strokeStyle = 'Red';
    this.ctx.lineWidth = 3;

    this.ctx.beginPath();
    this.ctx.moveTo(x1, y);
    this.ctx.lineTo(x2, y);
    this.ctx.stroke();

    for (const x of [x1, x2]) {
      const circle = new Path2D();
      circle.arc(x, y, 7, 0, 2 * Math.PI);
      this.ctx.fill(circle);
      this.ctx.stroke(circle);
    }
  }

  /**
   * Draw a midpoint guideline of the displacement pose on the video canvas,
   * spanning between the shoulders (inset by 20px on each side).
   * No-op when there is no displacement pose or its keypoints are incomplete.
   * @param midpoint Midpoint guideline to render
   */
  drawMidpoint(midpoint: 'midFemurAvg' | 'midFemurL' | 'midFemurR' |
    'midShinL' | 'midShinR' | 'midShinAvg' |
    'midTorsoL' | 'midTorsoR' | 'midTorsoAvg') {
    if (!this.displacement) return;
    const data = PosePredictor.data(this.displacement);
    if (!data) return;
    const { leftShoulder, rightShoulder } = data;
    this.drawGuideline(leftShoulder.x + 20, rightShoulder.x - 20, data[midpoint]);
  }

  /**
   * Draw all skeletons and keypoints on video canvas.
   * Resizes the canvas to the webcam video (which also resets its contents),
   * then draws each pose and, if set, the displacement's femur guideline.
   * @param snapshot Array of poses to render; when omitted a fresh snapshot is
   *   taken. Passing `null` does nothing.
   */
  async drawResults(snapshot?: Pose[] | null) {
    if (snapshot === null) return;
    const poses = snapshot ?? (await this.produceSnapshot());
    const video = this.webcam.current?.video;
    if (!video) return;

    // Set canvas width based on webcam video
    this.canvas.current.width = video.videoWidth;
    this.canvas.current.height = video.videoHeight;

    // Video not ready yet: nothing to draw.
    if (!poses) return;

    // this.drawCtx();
    for (const pose of poses) this.drawResult(pose);
    if (this.displacement) this.drawMidpoint('midFemurAvg');
    // this.drawResult(this.displacement);
  }

  /**
   * Draw a skeleton and keypoints on video canvas
   * @param pose A pose with keypoints to render
   */
  drawResult(pose: Pose) {
    if (pose.keypoints) {
      this.drawKeypoints(pose.keypoints);
      this.drawSkeleton(pose.keypoints, (pose.id as number));
    }
  }

  /**
   * Draw keypoints on video canvas, colored by body side
   * (middle: light gray, left: green, right: orange).
   * @param keypoints Array of keypoints
   */
  drawKeypoints(keypoints: Keypoint[]) {
    const keypointInd = posedetection.util.getKeypointIndexBySide(this.config.model);
    this.ctx.fillStyle = 'LightGray';
    this.ctx.strokeStyle = 'White';
    this.ctx.lineWidth = this.config.lineWidth || 1;

    for (const i of keypointInd.middle) {
      this.drawKeypoint(keypoints[i]);
    }

    this.ctx.fillStyle = 'Green';
    for (const i of keypointInd.left) {
      this.drawKeypoint(keypoints[i]);
    }

    this.ctx.fillStyle = 'Orange';
    for (const i of keypointInd.right) {
      this.drawKeypoint(keypoints[i]);
    }
  }

  /**
   * Draw keypoint on video canvas when its score meets the threshold.
   * @param keypoint Keypoint to render
   */
  drawKeypoint(keypoint: Keypoint) {
    // If score is null, just show the keypoint
    const score = keypoint.score != null ? keypoint.score : 1;
    const scoreThreshold = this.config.scoreThreshold || 0;

    if (score >= scoreThreshold) {
      const circle = new Path2D();
      circle.arc(keypoint.x, keypoint.y, this.config.keypointRadius || 4, 0, 2 * Math.PI);
      this.ctx.fill(circle);
      this.ctx.stroke(circle);
    }
  }

  /**
   * Draw a skeleton on video canvas
   * @param keypoints A list of keypoints
   * @param poseId Tracking id; selects a palette color when tracking is enabled.
   */
  drawSkeleton(keypoints: Keypoint[], poseId: number) {
    // Each poseId is mapped to a color in the color palette
    const color = this.config.enableTracking && poseId != null ?
        COLOR_PALETTE[poseId % 20] :
        'White';
    this.ctx.fillStyle = color;
    this.ctx.strokeStyle = color;
    this.ctx.lineWidth = this.config.lineWidth || 1;

    const scoreThreshold = this.config.scoreThreshold || 0;
    posedetection.util.getAdjacentPairs(this.config.model).forEach(([i, j]) => {
      const kp1 = keypoints[i];
      const kp2 = keypoints[j];

      // If score is null, just show the keypoint
      const score1 = kp1.score != null ? kp1.score : 1;
      const score2 = kp2.score != null ? kp2.score : 1;

      if (score1 >= scoreThreshold && score2 >= scoreThreshold) {
        this.ctx.beginPath();
        this.ctx.moveTo(kp1.x, kp1.y);
        this.ctx.lineTo(kp2.x, kp2.y);
        this.ctx.stroke();
      }
    });
  }

  /**
   * Instance wrapper around `PosePredictor.data`. The cast hides the `null`
   * that `data` returns for incomplete poses, so callers should guard for it.
   */
  getPoseData (pose: Pose): PoseData {
    return PosePredictor.data(pose) as PoseData;
  }

  /**
   * Get an array of "poses" (assumed positions) based on snapshot.
   * @param snapshot Array of detector poses for which to predict positions;
   *   when omitted a fresh snapshot is taken.
   * @returns One `PoseInfo` per pose, or `undefined` if `snapshot` is `null`
   *   or no snapshot could be produced.
   */
  async predictPose(snapshot?: Pose[] | null): Promise<PoseInfo[] | undefined> {
    if (snapshot === null) return;
    const P = PosePredictor;
    const poses = snapshot ?? (await this.produceSnapshot());
    if (!poses) return;

    const model = this.config.model;
    const threshold = this.config.scoreThreshold || 0;
    // A group is registered when every one of its keypoints is present and
    // either unscored or scored at/above the threshold.
    const isRegistered = (pose: Pose, names: string[] | undefined): boolean => {
      if (!names) return false;
      return names.every(name => {
        const kp = pose.keypoints.find(k => k.name === name);
        return kp !== undefined && (kp.score == null || kp.score >= threshold);
      });
    };

    const results: PoseInfo[] = [];
    for (const pose of poses) {
      const faceRegistered = isRegistered(pose, P.facialKeypoints[model]);
      const torsoRegistered = isRegistered(pose, P.torsoKeypoints[model]);
      const legsRegistered = isRegistered(pose, P.legKeypoints[model]);
      const fullBodyRegistered = faceRegistered && torsoRegistered && legsRegistered;

      let name = '';
      if (torsoRegistered && legsRegistered) { // assume that body is in frame
        const data = P.data(pose);
        if (data) name = P.bodyPositionName(data);
      } else if (torsoRegistered) { // assume upper body is in frame only
        name = 'chest forward';
      } else if (legsRegistered && !faceRegistered) { // assume lower body is in frame only
        name = 'legs forward';
      } else if (faceRegistered) { // assume face is in frame only
        const data = P.data(pose);
        if (data) name = P.headPositionName(data);
      }

      results.push({ name, faceRegistered, torsoRegistered, legsRegistered, fullBodyRegistered });
    }

    return results;
  }

  /**
   * Classify a full-body pose as 'standing', 'bending', 'squatting', or ''.
   *
   * Each segment's vertical length (torso, femur, shin) is compared with that
   * side's mean segment length; all ratios within [0.7, 1.3] means standing.
   */
  private static bodyPositionName(data: PoseData): string {
    const {
      leftShoulder, leftHip, leftKnee, leftAnkle, rightShoulder, rightHip, rightKnee, rightAnkle
    } = data;

    const leftTorsoY = leftHip.y - leftShoulder.y;
    const leftFemurY = leftKnee.y - leftHip.y;
    const leftShinY = leftAnkle.y - leftKnee.y;
    const rightTorsoY = rightHip.y - rightShoulder.y;
    const rightFemurY = rightKnee.y - rightHip.y;
    const rightShinY = rightAnkle.y - rightKnee.y;

    const leftAvgY = (leftTorsoY + leftFemurY + leftShinY) / 3;
    const rightAvgY = (rightTorsoY + rightFemurY + rightShinY) / 3;

    const leftTorsoD = leftAvgY / leftTorsoY;
    const leftFemurD = leftAvgY / leftFemurY;
    const leftShinD = leftAvgY / leftShinY;
    const rightTorsoD = rightAvgY / rightTorsoY;
    const rightFemurD = rightAvgY / rightFemurY;
    const rightShinD = rightAvgY / rightShinY;

    const withinTolerance = (d: number) => d >= 0.7 && d <= 1.3;
    if ([leftTorsoD, leftFemurD, leftShinD, rightTorsoD, rightFemurD, rightShinD].every(withinTolerance)) {
      return 'standing';
    }
    if (leftTorsoD > 1.3 || rightTorsoD > 1.3) return 'bending';
    if (leftFemurD > 1.3 || rightFemurD > 1.3) return 'squatting';
    return '';
  }

  /**
   * Classify head orientation from facial keypoints. Checks run in sequence
   * and later matches overwrite earlier ones (e.g. 'head up left' beats
   * 'head left').
   */
  private static headPositionName(data: PoseData): string {
    const { leftEar, leftEye, nosePoint, rightEar, rightEye } = data;
    const turnedRight = rightEye.x > nosePoint.x || rightEar.x > rightEye.x;
    const turnedLeft = leftEye.x < nosePoint.x || leftEar.x < leftEye.x;

    let name = 'head forward';
    if (turnedRight) name = 'head right';
    if (turnedLeft) name = 'head left';
    if (leftEar.y > nosePoint.y || rightEar.y > nosePoint.y) {
      if (leftEar.y > nosePoint.y && rightEar.y < nosePoint.y) name = 'head tilt left';
      else if (rightEar.y > nosePoint.y && leftEar.y < nosePoint.y) name = 'head tilt right';
      else name = 'head up';
      if (turnedRight) name = 'head up right';
      if (turnedLeft) name = 'head up left';
    }
    if (leftEar.y < leftEye.y && rightEar.y < rightEye.y) {
      name = 'head down';
      if (turnedRight) name = 'head down right';
      if (turnedLeft) name = 'head down left';
    }
    return name;
  }

}