import React, { useState } from "react";
import { Box, Grid } from "@mui/material";
import MainButton from "../../buttons/MainButton";
import ModelModal from "../modals/ModelModal";
import LossModal from "../modals/LossModal";
import { CreatModel, ModelData } from "@/types";
import { useCreateModel } from "../../../hooks/configMng/useCreateModel";
import { useUpdateModel } from "../../../hooks/configMng/useUpdateModel";
import { useSelector } from "react-redux";
import { RootState } from "../../../store/store";
import { useFetchProjectModels } from "../../../utils/projectUtils/getAllModels";
import styles from "../css/NewAnalysis.module.css";

/** Props for the `ModelSection` component. */
interface ModelSectionProps {
  /** Project the model belongs to; used as `project_id` in the model requests. */
  projectId: string;
  /** Whether a model save request is in flight (state owned by the parent). */
  saveLoading: boolean;
  setSaveLoading: (loading: boolean) => void;
  /** Whether a model has been saved successfully; switches the button theme. */
  modelSaved: boolean;
  setModelSaved: (saved: boolean) => void;
  /** Disables the "model" button while an evaluation is running. */
  runningEval: boolean;
  /** Receives the framework of the model after a successful save. */
  setFramework: (value: string) => void;
}

/** Model form fields except the parameters file, which the component tracks separately. */
type ModelFormFieldDefaults = Omit<CreatModel, "model_parameters">;

/**
 * Empty model form used to seed the component's form state.
 * `project_id` is left blank here; the component fills in the real project id.
 */
const initialModelData: ModelFormFieldDefaults = {
  // parameters, model definition and loss: each value has a companion *_name field
  model_parameters_name: "",
  model_definition: null,
  model_definition_name: "",
  loss: null,
  loss_name: "",
  built_in_loss: "",
  // ownership
  project_id: "",
  // model description
  framework: "",
  ml_task: "",
  num_of_classes: "",
  class_name: "",
  clip_values: "",
  // optional training settings
  optimizer: "",
  learning_rate: "",
};

/** Form state edited by the modals: the create payload, with the parameters as a client-side `File`. */
type ModelFormData = Omit<CreatModel, "model_parameters"> & {
  model_parameters: File | null;
};

/** Frameworks whose models also need clip values, a model-definition file and a class name. */
const FRAMEWORKS_WITH_MODEL_DEFINITION: readonly string[] = [
  "TensorFlow",
  "Keras",
  "PyTorch",
];

/** Shown when a failed request carries no usable server message. */
const DEFAULT_SAVE_ERROR = "Failed to save model.";

/**
 * Returns `error.response.data.message` when it is a string, otherwise a generic message.
 * Takes `unknown` so it never throws on network errors that have no `response`.
 */
const getServerErrorMessage = (error: unknown): string => {
  const message = (
    error as { response?: { data?: { message?: unknown } } } | null | undefined
  )?.response?.data?.message;
  return typeof message === "string" ? message : DEFAULT_SAVE_ERROR;
};

/**
 * "Model" step of the new-analysis screen.
 *
 * Renders a button that opens `ModelModal` (and, from it, `LossModal`). Saving either
 * updates the selected existing model or creates a new one. Two copies of the form are
 * kept: `modelData` (being edited) and `savedModelData` (last successful save), so
 * closing the model modal discards unsaved edits.
 */
const ModelSection: React.FC<ModelSectionProps> = ({
  projectId,
  saveLoading,
  setSaveLoading,
  modelSaved,
  setModelSaved,
  runningEval,
  setFramework,
}) => {
  const projectModels = useSelector<RootState, ModelData[] | null>(
    (state) => state.project.projectModels
  );
  const [errorMessage, setErrorMessage] = useState("");
  const [modelOpen, setModelOpen] = useState(false);
  const [lossOpen, setLossOpen] = useState(false);
  const [selectedExistingModel, setSelectedExistingModel] = useState("");
  const [modelData, setModelData] = useState<ModelFormData>({
    ...initialModelData,
    model_parameters: null,
    project_id: projectId,
  });
  const [savedModelData, setSavedModelData] = useState<ModelFormData>({
    ...initialModelData,
    model_parameters: null,
    project_id: projectId,
  });

  const projectBody = { project_id: projectId };
  const getModels = useFetchProjectModels();

  /** Bookkeeping shared by a successful create and a successful update. */
  const markModelSaved = () => {
    setModelSaved(true);
    // Snapshot a copy: sharing the object would let later in-place edits of
    // modelData silently change the "saved" state as well.
    setSavedModelData({ ...modelData });
    setModelOpen(false);
    setSaveLoading(false);
    setErrorMessage("");
    setFramework(modelData.framework);
    getModels.mutate(projectBody);
  };

  const updateModelMutation = useUpdateModel({
    onSuccess() {
      markModelSaved();
    },
    onError(error) {
      setErrorMessage(getServerErrorMessage(error));
      setSaveLoading(false);
    },
  });

  const createModelMutation = useCreateModel({
    onSuccess(data) {
      setLossOpen(false);
      // The freshly created model becomes the selected one, so later saves update it.
      setSelectedExistingModel(data.model_id);
      markModelSaved();
    },
    onError(error) {
      const message = getServerErrorMessage(error);
      // Messages mentioning 422 are replaced with a short generic text.
      setErrorMessage(message.includes("422") ? "invalid body." : message);
      setSaveLoading(false);
    },
  });

  /**
   * Resets the given fields of the form being edited.
   *
   * The fields are written onto the current `modelData` object in place, as before,
   * because the modals may call a clear helper and then immediately spread their
   * `data` prop (this same object) into `setData`.
   * NOTE(review): confirm in ModelModal/LossModal; if nothing relies on it, switch to
   * `setModelData((prev) => ({ ...prev, ...fields }))`.
   * A shallow copy is then committed so React re-renders. Passing the same reference
   * lets React skip the update.
   */
  const resetFields = (fields: Partial<ModelFormData>) => {
    Object.assign(modelData, fields);
    setModelData({ ...modelData });
  };

  const handleClearLossData = () => {
    resetFields({ built_in_loss: "", loss: null, loss_name: "" });
  };

  const handleClearFrameworkData = () => {
    resetFields({
      optimizer: "",
      learning_rate: "",
      model_definition: null,
      model_definition_name: "",
      class_name: "",
    });
  };

  const handleClearOptimizer = () => {
    resetFields({ optimizer: "" });
  };

  /** Closes both modals and restores the form to the last saved state. */
  const handleModelClosed = () => {
    // An empty framework in the snapshot presumably means nothing was saved yet,
    // so an existing model picked in the modal is deselected as well.
    if (savedModelData.framework === "") {
      setSelectedExistingModel("");
    }
    setModelData({ ...savedModelData });
    setModelOpen(false);
    setLossOpen(false);
    setErrorMessage("");
  };

  /** Cancelling the loss modal returns to the model modal and discards loss input. */
  const handleLossClosed = () => {
    setModelOpen(true);
    setLossOpen(false);
    handleClearLossData();
  };

  /** Updates the selected existing model, or creates a new one when none is selected. */
  const handleSaveModel = () => {
    setSaveLoading(true);
    if (selectedExistingModel !== "") {
      const body = {
        model_id: selectedExistingModel,
        project_id: projectId,
        framework: modelData.framework || "",
        ml_task: modelData.ml_task || "",
        num_of_classes: modelData.num_of_classes || "",
        class_name: modelData.class_name || "",
        optimizer: modelData.optimizer || "",
        clip_values: modelData.clip_values || "",
        learning_rate: modelData.learning_rate || "",
      };
      updateModelMutation.mutate(body);
    } else {
      createModelMutation.mutate(modelData as CreatModel);
    }
  };

  /** Confirms the loss modal: the loss fields stay in modelData; go back to the model modal. */
  const handleSaveLoss = () => {
    setLossOpen(false);
    setModelOpen(true);
  };

  /** Whether the form currently holds enough information to be saved. */
  const handleCheckSaveModal = (): boolean => {
    if (modelData == null) {
      return false;
    }
    const {
      num_of_classes,
      model_parameters,
      model_parameters_name,
      framework,
      ml_task,
      optimizer,
      learning_rate,
      clip_values,
      loss_name,
      built_in_loss,
      model_definition_name,
      class_name,
    } = modelData;

    // Required for every framework: parameters are either a new upload or a non-empty name.
    const hasParameters =
      model_parameters !== null || Boolean(model_parameters_name);
    if (
      num_of_classes === "" ||
      !hasParameters ||
      framework === "" ||
      ml_task === ""
    ) {
      return false;
    }

    // Optimizer and learning rate must be given together or both left empty.
    if ((optimizer !== "") !== (learning_rate !== "")) {
      return false;
    }

    if (FRAMEWORKS_WITH_MODEL_DEFINITION.includes(framework)) {
      if (clip_values === "") {
        return false;
      }
      if (model_definition_name === "" || class_name === "") {
        return false;
      }
    }

    // PyTorch also needs a loss: an uploaded one or a built-in choice.
    if (framework === "PyTorch" && loss_name === "" && built_in_loss === "") {
      return false;
    }
    return true;
  };

  return (
    <>
      <Grid item xs={2} md={2} lg={2}>
        <Box className={styles.buttonBox}>
          <MainButton
            label="model"
            theme={modelSaved ? "white" : "black"}
            className={styles.button}
            onClick={() => {
              setModelOpen(true);
            }}
            disabled={runningEval}
          />
        </Box>
      </Grid>
      <ModelModal
        data={modelData}
        setData={setModelData}
        modelsData={projectModels || []}
        onClose={handleModelClosed}
        open={modelOpen}
        onSave={handleSaveModel}
        onLossPressed={() => {
          setModelOpen(false);
          setLossOpen(true);
        }}
        clearLossData={handleClearLossData}
        clearFrameworkData={handleClearFrameworkData}
        checkDataForSave={handleCheckSaveModal}
        selectedExistingModel={selectedExistingModel}
        setSelectedExistingModel={setSelectedExistingModel}
        clearOptimizer={handleClearOptimizer}
        loading={saveLoading}
        errorMessage={errorMessage}
      />
      <LossModal
        data={modelData}
        setData={setModelData}
        onClose={handleLossClosed}
        open={lossOpen}
        onSave={handleSaveLoss}
        loading={saveLoading}
      />
    </>
  );
};

export default ModelSection;
