import React, { useEffect, useState } from "react";
import {
  Box,
  Typography,
  Dialog,
  DialogTitle,
  DialogContent,
  DialogActions,
  IconButton,
} from "@mui/material";
import SelectButton from "../../buttons/SelectButton";
import UploadButton from "../../buttons/UploadButton";
import TooltipWrapper from "../../assets/TooltipWrapper";
import styles from "./css/ModelModal.module.css";
import MainButton from "../../buttons/MainButton";
import { isValidNumberArray } from "../../../utils/arrayNumbersCheck";
import { ModelData , CreatModel} from "@/types";
import { convertToCreatModel } from "../../../utils/convertConfigData";
import { checkErrors } from "../../../utils/checkErrors";
import { useGenerateURL } from "../../../hooks/configMng/useGenerateURL";
import {
  isValidPositiveInteger,
  isValidPositiveFloat,
} from "../../../utils/numberCheck";
import { TbDownload } from "react-icons/tb";

/** Props accepted by {@link ModelModal}. */
interface ModelModalProps {
  // ---- Dialog state -------------------------------------------------------

  /** Whether the dialog is shown. */
  open: boolean;
  /** While true, the Save button shows its loading state and is disabled. */
  loading: boolean;
  /** Message rendered in the dialog footer next to the Save button. */
  errorMessage: string;

  // ---- Form data ----------------------------------------------------------

  /**
   * The model form state. The modal reads fields such as `framework`,
   * `ml_task`, `optimizer`, `learning_rate`, `class_name`, `clip_values`,
   * `num_of_classes`, `model_parameters(_name)` and `model_definition(_name)`.
   */
  data: any;
  /** State setter for `data`. */
  setData: React.Dispatch<React.SetStateAction<any>>;
  /** Existing models offered in the "Model Parameters" picker. */
  modelsData: ModelData[];
  /** Id of the currently selected existing model; `""` when none is selected. */
  selectedExistingModel: string;
  /** Updates `selectedExistingModel`. */
  setSelectedExistingModel: (id: string) => void;

  // ---- Callbacks ----------------------------------------------------------

  /** Invoked when the Save button is clicked. */
  onSave: () => void;
  /** Invoked after the modal has reset its local state on close. */
  onClose: () => void;
  /** Invoked when the "Loss" button is clicked. */
  onLossPressed: () => void;
  /** Result is combined with the local field errors to enable/disable Save. */
  checkDataForSave: () => boolean;
  /** Called when the user switches to any framework other than PyTorch. */
  clearLossData: () => void;
  /** Called when the user switches to a framework other than PyTorch/TensorFlow/Keras. */
  clearFrameworkData: () => void;
  /** Called on every user-initiated framework change. */
  clearOptimizer: () => void;
}

const ModelModal: React.FC<ModelModalProps> = ({
  data,
  setData,
  modelsData,
  onSave,
  onClose,
  open,
  onLossPressed,
  clearLossData,
  clearFrameworkData,
  checkDataForSave,
  selectedExistingModel,
  setSelectedExistingModel,
  clearOptimizer,
  loading,
  errorMessage,
}) => {
  // Frameworks offered in the "Framework" select.
  const frameworks = [
    "TensorFlow",
    "Keras",
    "PyTorch",
    "Sklearn",
    "CatBoost",
    "XGBoost",
  ];
  // Optimizers offered when the selected framework is TensorFlow or Keras.
  // Entries are used verbatim as the stored optimizer value, so they must not
  // carry stray whitespace.
  const optimizerTensorKeras = [
    "Adadelta",
    "Adafactor",
    "Adagrad",
    "Adam",
    "AdamW",
    "Adamax",
    "Ftrl",
    "Lion",
    "NAdam",
    "Rmsprop",
    "SGD",
  ];
  // Optimizers offered when the selected framework is PyTorch.
  const optimizerPyTorch = [
    "Adam",
    "Adadelta",
    "Adagrad",
    "AdamW",
    "Adamax",
    "ASGD",
    "NAdam",
    "RAdam",
    "LBFGS",
    "SparseAdam",
    "RMSprop",
    "Rprop",
    "SGD",
  ];
  // ML tasks offered in the "ML Task" select.
  const mlTasks = ["Classification"];
  // Local UI state. Hooks must keep this exact order between renders.

  // Disables the "Loss" button; starts disabled until a framework enables it.
  const [disableLoss, setDisableLoss] = useState(true);
  // Disables the Model Definition upload, the Optimizer select and the
  // Class Name / Learning Rate inputs.
  const [
    disableOptLearningClassDefinition,
    setDisableOptLearningClassDefinition,
  ] = useState(true);
  // Passed to the "Loss" button as its `required` prop.
  const [requiredLoss, setRequiredLoss] = useState(false);
  // Shows the required marker (*) next to "Clip Values".
  const [requiredClipValues, setRequiredClipValues] = useState(false);
  // Shows the required marker (*) next to "Optimizer" and "Learning Rate".
  const [requiredLearningOpt, setRequiredLearningOpt] = useState(false);
  // Shows the required marker (*) next to "Model Definition" and "Class Name".
  const [requiredClassDefinition, setRequiredClassDefeinition] =
    useState(false);
  // Per-field validation messages keyed by field name; "" means no error.
  const [errors, setErrors] = useState<any>({});
  // Disables the Save button; recomputed by the effect below.
  const [disableSave, setDisableSave] = useState(true);

  /** Frameworks for which a model-parameters file extension is defined. */
  type Framework =
    | "TensorFlow"
    | "Keras"
    | "PyTorch"
    | "Sklearn"
    | "CatBoost"
    | "XGBoost";

  /** Expected extension (without the dot) of the model-parameters file, per framework. */
  const fileExtensions: Record<Framework, string> = {
    PyTorch: "pth",
    TensorFlow: "keras",
    Keras: "keras",
    Sklearn: "pickle",
    CatBoost: "cbm",
    XGBoost: "json",
  };

  /** Path prefix prepended to a template file name when requesting its download URL. */
  const templatesPath = "model_def/";

  /** Model-definition template file per framework; only these frameworks have one. */
  const files = {
    PyTorch: "pytorch_model_def.py",
    TensorFlow: "tensorflow_model_def.py",
    Keras: "tensorflow_model_def.py",
  };

  /**
   * Builds the validation message shown when the uploaded model-parameters
   * file does not carry the extension expected for `framework`.
   *
   * @param framework - Framework name inserted into the message.
   * @param extension - Expected extension, without the leading dot.
   * @returns The user-facing error message.
   */
  const getErrorMessage = (framework: Framework, extension: string) =>
    "The model parameters file for " +
    framework +
    " should be a ." +
    extension +
    " file.";

  /**
   * Starts a browser download/navigation for `url` by clicking a temporary
   * anchor element.
   *
   * @param url - The URL assigned to the anchor's `href`.
   */
  const handleDownload = (url: string) => {
    const anchor = Object.assign(document.createElement("a"), { href: url });
    // The anchor has to be attached to the document for the click to work in Firefox.
    document.body.appendChild(anchor);
    anchor.click();
    document.body.removeChild(anchor);
  };

  // URL-generation hook: on success the returned `url` is handed straight to
  // `handleDownload`; failures are only logged to the console.
  const generateURL = useGenerateURL({
    onSuccess: (response) => {
      handleDownload(response.url);
    },
    onError: (error) => {
      console.log(`Error fetching URL: ${error}`);
    },
  });

  // Keep the Save button in sync: it is enabled only while the caller-supplied
  // completeness check passes and the local field errors pass `checkErrors`.
  useEffect(() => {
    const canSave = checkDataForSave() && checkErrors(errors);
    setDisableSave(!canSave);
  }, [data, checkDataForSave, errors]);

  /**
   * Applies a framework selection to the modal.
   *
   * Always updates the framework-dependent UI flags (which inputs are enabled
   * and which fields show a required marker). For a user-initiated change
   * (`fromExisting` falsy) it additionally:
   *  - invokes the caller's clear callbacks relevant to the new framework,
   *  - re-validates the extension of an already chosen model-parameters file,
   *  - stores the new framework in `data`.
   *
   * @param value - The selected framework name; `""` (or any unknown value)
   *   is handled like a framework without definition/optimizer support.
   * @param fromExisting - When true, only the UI flags are updated; nothing is
   *   cleared, validated or written to `data`.
   */
  const handleFrameworkChange = (value: string, fromExisting?: boolean) => {
    const isPyTorch = value === "PyTorch";
    const isTensorFlowOrKeras = value === "TensorFlow" || value === "Keras";
    const usesModelDefinition = isPyTorch || isTensorFlowOrKeras;

    setRequiredClassDefeinition(usesModelDefinition);
    setDisableLoss(!usesModelDefinition);
    setDisableOptLearningClassDefinition(!usesModelDefinition);
    setRequiredClipValues(usesModelDefinition);
    // Only PyTorch requires a loss to be provided.
    setRequiredLoss(isPyTorch);

    if (fromExisting) {
      return;
    }

    if (!isPyTorch) {
      clearLossData();
    }
    if (!usesModelDefinition) {
      clearFrameworkData();
    }
    clearOptimizer();

    const expectedExtension = fileExtensions[value as Framework];
    // `name?.` guards against a model_parameters value that carries no name.
    const fileExtension = data?.model_parameters?.name?.split(".").pop();
    const extensionMismatch =
      value !== "" &&
      fileExtension !== undefined &&
      fileExtension !== expectedExtension;

    if (extensionMismatch) {
      setErrors((prevErrors: any) => ({
        ...prevErrors,
        model_parameters: getErrorMessage(
          value as Framework,
          expectedExtension
        ),
      }));
    } else if (value !== "") {
      setErrors((prevErrors: any) => ({
        ...prevErrors,
        model_parameters: "",
      }));
    } else if (data.model_parameters_name) {
      // Framework was deselected while a file is still chosen. A missing
      // (null/undefined) name is treated like "", i.e. no file.
      setErrors((prevErrors: any) => ({
        ...prevErrors,
        model_parameters: "Please choose the framework type first.",
      }));
    }

    // Functional update: the clear callbacks above may already have queued
    // changes to `data`; spreading the render-time `data` snapshot here would
    // overwrite them with stale values.
    setData((prevData: any) => ({ ...prevData, framework: value }));
  };

  /**
   * Handles a value change coming from one of the select controls.
   *
   * Framework changes are delegated to `handleFrameworkChange`. For every
   * other field the value is written to `data`; an optimizer change (while
   * the optimizer controls are enabled) also toggles the required marker
   * shared by "Optimizer" and "Learning Rate".
   *
   * @param field - Key of `data` being changed.
   * @param value - The newly selected value (`""` when cleared).
   */
  const handleSelectChange = (field: string, value: string) => {
    if (field === "framework") {
      handleFrameworkChange(value);
      return;
    }

    const optimizerChanged =
      field === "optimizer" && !disableOptLearningClassDefinition;
    if (optimizerChanged) {
      if (value !== "") {
        setRequiredLearningOpt(true);
      } else if (data.learning_rate === "") {
        // Both optimizer and learning rate are empty: neither is required.
        setRequiredLearningOpt(false);
      }
    }

    setData({ ...data, [field]: value });
  };

  /**
   * Handles typing in one of the free-text inputs: stores the value in `data`
   * and, for the validated fields (`clip_values`, `learning_rate`,
   * `num_of_classes`), records or clears that field's error message.
   * An empty value is never reported as an error.
   *
   * @param field - Key of `data` being edited.
   * @param value - The current text of the input.
   */
  const handleTextFieldChange = (field: string, value: string) => {
    // Writes `message` ("" clears) as the error for the edited field.
    const putError = (message: string) =>
      setErrors((prevErrors: any) => ({
        ...prevErrors,
        [field]: message,
      }));

    setData({ ...data, [field]: value });

    switch (field) {
      case "clip_values": {
        const accepted = isValidNumberArray(value) || value === "";
        putError(
          accepted
            ? ""
            : "Invalid array format. example of valid array, for example: [1,2]"
        );
        break;
      }
      case "learning_rate": {
        if (value === "") {
          // Marker is dropped only when the optimizer is empty as well.
          if (data.optimizer === "") {
            setRequiredLearningOpt(false);
          }
          putError("");
        } else {
          setRequiredLearningOpt(true);
          putError(
            isValidPositiveFloat(value)
              ? ""
              : "Invalid number, please insert only positive float."
          );
        }
        break;
      }
      case "num_of_classes": {
        const accepted = isValidPositiveInteger(value) || value === "";
        putError(
          accepted ? "" : "Invalid number, please insert only positive int."
        );
        break;
      }
      default:
        // Other fields (e.g. class_name) are stored without validation.
        break;
    }
  };

  /**
   * Handles a file picked in one of the upload controls.
   *
   * `model_parameters`: requires a framework to be selected and the file
   * extension to match the one expected for it. A rejected file clears any
   * previously stored file for the field. If an existing model was selected,
   * accepting a new file detaches from that model and resets the rest of the
   * form (including the framework), so the framework-driven UI flags are
   * reset as well.
   *
   * `model_definition`: requires a `.py` file.
   *
   * Any other `field` is ignored.
   *
   * @param field - The upload target (`"model_parameters"` or `"model_definition"`).
   * @param file - The file chosen by the user.
   */
  const handleUpload = (field: string, file: File) => {
    const fileExtension = file.name.split(".").pop();

    // Writes `message` ("" clears) as the error for the uploaded field.
    const setFieldError = (message: string) =>
      setErrors((prevErrors: any) => ({
        ...prevErrors,
        [field]: message,
      }));

    // Drops whatever file was previously stored for the uploaded field.
    const discardStoredFile = () =>
      setData((prevData: CreatModel) => ({
        ...prevData,
        [field]: null,
        [`${field}_name`]: "",
      }));

    if (field === "model_parameters") {
      if (data.framework === "") {
        setFieldError("Please choose the framework type first.");
        return;
      }

      const expectedExtension = fileExtensions[data.framework as Framework];
      if (fileExtension !== expectedExtension) {
        setFieldError(getErrorMessage(data.framework, expectedExtension));
        discardStoredFile();
        return;
      }

      if (selectedExistingModel !== "") {
        // Replacing the file of an existing model: start a fresh form that
        // only keeps the project id and the new file.
        setData({
          ...data,
          model_parameters: file,
          model_parameters_name: file.name,
          model_definition: null,
          model_definition_name: "",
          loss: null,
          loss_name: "",
          built_in_loss: "",
          project_id: data.project_id,
          framework: "",
          ml_task: "",
          num_of_classes: "",
          class_name: "",
          clip_values: "",
          optimizer: "",
          learning_rate: "",
        });
        setSelectedExistingModel("");
        // The framework was just reset to "", so bring the framework-driven
        // flags back to their "no framework" state (flags only, no data
        // writes), and drop the optimizer/learning-rate marker since both
        // fields are now empty.
        handleFrameworkChange("", true);
        setRequiredLearningOpt(false);
      } else {
        setData({
          ...data,
          model_parameters: file,
          model_parameters_name: file.name,
        });
      }
      setFieldError("");
      return;
    }

    if (field === "model_definition") {
      if (fileExtension !== "py") {
        discardStoredFile();
        setFieldError("The model definition file should be a .py file.");
        return;
      }
      setData({
        ...data,
        model_definition: file,
        model_definition_name: file.name,
      });
      setFieldError("");
    }
  };

  /**
   * Returns the optimizer options matching the currently selected framework.
   * (Despite its name, this yields optimizers, not frameworks.)
   *
   * @returns The TensorFlow/Keras list, the PyTorch list, or an empty list for
   *   any other framework value.
   */
  const getFilteredFrameworks = () => {
    switch (data.framework) {
      case "TensorFlow":
      case "Keras":
        return optimizerTensorKeras;
      case "PyTorch":
        return optimizerPyTorch;
      default:
        return [];
    }
  };

  /**
   * Handles a selection change in the "Model Parameters" picker.
   *
   * - `id` and `name` both `""` means the selection was cleared:
   *   - with no existing model selected, only the chosen file is dropped;
   *   - with an existing model selected, the model is deselected and the whole
   *     form (except `project_id`) is reset, together with the
   *     framework-driven UI flags and the optimizer/learning-rate marker.
   * - Otherwise the existing model with the given `id` is loaded into `data`
   *   and the UI flags are set for its framework. An unknown `id` is ignored.
   *
   * In every handled case the `model_parameters` error is cleared.
   *
   * @param id - Id of the picked existing model, or `""` when clearing.
   * @param name - Name reported by the picker; only checked for emptiness.
   */
  const handleSelectionExistingModel = (id: string, name: string) => {
    const isClearing = id === "" && name === "";

    if (!isClearing) {
      const existingModel = modelsData.find((model) => model.id === id);
      if (!existingModel) {
        // Nothing to load; leave the current form untouched rather than
        // passing `undefined` on to the converter.
        console.warn(`Existing model not found: ${id}`);
        return;
      }
      setSelectedExistingModel(id);
      setData(convertToCreatModel(existingModel));
      handleFrameworkChange(existingModel.framework as string, true);
    } else if (selectedExistingModel === "") {
      setData({
        ...data,
        model_parameters: null,
        model_parameters_name: "",
      });
    } else {
      setSelectedExistingModel("");
      setData({
        ...data,
        model_parameters: null,
        model_parameters_name: "",
        model_definition: null,
        model_definition_name: "",
        loss: null,
        loss_name: "",
        built_in_loss: "",
        project_id: data.project_id,
        framework: "",
        ml_task: "",
        num_of_classes: "",
        class_name: "",
        clip_values: "",
        optimizer: "",
        learning_rate: "",
      });
      // "empty" is not a known framework, so this resets the flags to the
      // "no framework" state without touching `data`.
      handleFrameworkChange("empty", true);
      // Optimizer and learning rate were both emptied above.
      setRequiredLearningOpt(false);
    }

    setErrors((prevErrors: any) => ({ ...prevErrors, model_parameters: "" }));
  };

  /**
   * Requests a download URL for the model-definition template of the
   * currently selected framework. Does nothing when the framework has no
   * template, so no request is ever made for a non-existent file.
   */
  const handleDownloadTemplate = () => {
    const templateFile: unknown = files[data.framework as keyof typeof files];
    if (typeof templateFile !== "string") {
      return;
    }
    generateURL.mutate({ file_url: templatesPath + templateFile });
  };

  /**
   * Closes the dialog: returns the local UI flags and field errors to their
   * initial values, then notifies the caller via `onClose`.
   */
  const handleClose = () => {
    // Same values the state hooks are initialised with.
    setDisableLoss(true);
    setDisableOptLearningClassDefinition(true);
    // Also drop the required markers so they do not linger on the next open.
    setRequiredLoss(false);
    setRequiredClipValues(false);
    setRequiredLearningOpt(false);
    setRequiredClassDefeinition(false);
    setErrors({});
    onClose();
  };

  // Layout: a fixed-size dialog with four columns of inputs (framework/task,
  // files/optimizer/loss, class name/learning rate, classes/clip values) and a
  // footer holding the Save button and the caller-supplied error message.
  return (
    <Dialog
      open={open}
      // Intentionally a no-op: the dialog is only closed through the X button.
      onClose={()=>{}}
      maxWidth={false}
      fullWidth
      PaperProps={{
        sx: {
          width: "72rem",
          height: "45rem",
          maxWidth: "none",
          margin: "auto",
        },
      }}
    >
      <DialogTitle>
        Model Configuration
        <IconButton onClick={handleClose} className={styles.closeButton}>
          X
        </IconButton>
      </DialogTitle>
      <DialogContent>
        <Box className={styles.container}>
          {/* Column 1: framework and ML task */}
          <Box className={styles.column}>
            <TooltipWrapper tooltipText="Select the framework your model uses">
              <Typography>
                <span className={styles.required}>*</span>Framework
              </Typography>
            </TooltipWrapper>
            <SelectButton
              data={frameworks}
              selectedValue={data.framework}
              updateFunction={(value) => handleSelectChange("framework", value)}
              isUpload={false}
              disabledState={selectedExistingModel !== ""}
            />
            <div className={styles.placeholderMin}></div>
            <TooltipWrapper tooltipText="Only classification is supported at this time">
              <Typography>
                <span className={styles.required}>*</span>ML Task
              </Typography>
            </TooltipWrapper>
            <SelectButton
              className={styles.ml_task}
              data={mlTasks}
              selectedValue={data.ml_task}
              updateFunction={(value) => handleSelectChange("ml_task", value)}
              isUpload={false}
            />
          </Box>
          <div className={styles.columnBorder}></div>
          {/* Column 2: model files, optimizer and loss */}
          <Box className={styles.column}>
            <TooltipWrapper tooltipText="Upload the saved model (e.g., weights)">
              <Typography>
                <span className={styles.required}>*</span>Model Parameters
              </Typography>
            </TooltipWrapper>
            <UploadButton
              data={modelsData.map((model) => ({
                name: model.model_parameters_name as string,
                id: model.id,
              }))}
              updateFunction={handleSelectionExistingModel}
              selectedValue={data.model_parameters_name || ""}
              onUpload={(file) => handleUpload("model_parameters", file)}
              selection={true}
            />
            <div className={styles.errorMessage}>{errors.model_parameters}</div>

            <Typography>Template Definition File</Typography>
            <MainButton
              theme="black"
              label="Download Template"
              onClick={handleDownloadTemplate}
              IconComponent={TbDownload}
              className={styles.template}
              disabled={data.framework !== "PyTorch" && data.framework !== "TensorFlow" && data.framework !== "Keras"}
            />

            <TooltipWrapper tooltipText="Upload the python file with the model definition">
              <Typography>
                <span
                  className={`${
                    !requiredClassDefinition
                      ? styles.requiredHidden
                      : styles.required
                  }`}
                >
                  *
                </span>
                Model Definition
              </Typography>
            </TooltipWrapper>
            <UploadButton
              updateFunction={(value) =>
                handleSelectChange("model_definition_name", value)
              }
              selectedValue={data.model_definition_name || ""}
              onUpload={(file) => handleUpload("model_definition", file)}
              selection={false}
              disabledState={
                disableOptLearningClassDefinition ||
                selectedExistingModel !== ""
              }
            />
            <div className={styles.errorMessageDefinition}>
              {errors.model_definition}
            </div>

            <Typography>
              <span
                className={`${
                  !requiredLearningOpt ? styles.requiredHidden : styles.required
                }`}
              >
                *
              </span>
              Optimizer
            </Typography>
            <SelectButton
              className={styles.optimizer}
              listClassName={styles.optimizerList}
              data={getFilteredFrameworks()}
              selectedValue={data.optimizer}
              updateFunction={(value) => handleSelectChange("optimizer", value)}
              isUpload={false}
              disabledState={disableOptLearningClassDefinition}
            />
            <div className={styles.lossPlaceholder}></div>
            <MainButton
              label="Loss"
              theme="blue"
              onClick={() => {
                onLossPressed();
              }}
              required={requiredLoss}
              className={styles.lossButton}
              disabled={disableLoss || selectedExistingModel !== ""}
            />
          </Box>

          {/* Column 3: class name and learning rate */}
          <Box className={styles.column}>
            <div className={styles.placeholder}></div>
            <TooltipWrapper tooltipText="The name of the model class in the definition file">
              <Typography>
                <span
                  className={`${
                    !requiredClassDefinition
                      ? styles.requiredHidden
                      : styles.required
                  }`}
                >
                  *
                </span>
                Class Name
              </Typography>
            </TooltipWrapper>
            <input
              type="text"
              value={data.class_name}
              onChange={(e) =>
                handleTextFieldChange("class_name", e.target.value)
              }
              className={`${styles.textFiled} ${
                disableOptLearningClassDefinition ? styles.disableInput : ""
              } ${errors.class_name ? styles.errorBorder : ""}`}
              disabled={disableOptLearningClassDefinition}
            />

            <Typography>
              <span
                className={`${
                  !requiredLearningOpt ? styles.requiredHidden : styles.required
                }`}
              >
                *
              </span>
              Learning Rate
            </Typography>
            <input
              type="text"
              value={data.learning_rate}
              onChange={(e) =>
                handleTextFieldChange("learning_rate", e.target.value)
              }
              className={`${styles.textFiled} ${
                disableOptLearningClassDefinition ? styles.disableInput : ""
              } ${errors.learning_rate ? styles.errorBorder : ""}`}
              disabled={disableOptLearningClassDefinition}
            />
            <div className={styles.errorMessageLearning}>
              {errors.learning_rate}
            </div>
          </Box>
          <div className={styles.columnBorder}></div>
          {/* Column 4: number of classes and clip values */}
          <Box className={styles.column}>
            <TooltipWrapper tooltipText="The number of classes your model predicts">
              <Typography>
                <span className={`${styles.required}`}>*</span>Number of Classes
              </Typography>
            </TooltipWrapper>
            <input
              type="text"
              value={data.num_of_classes}
              onChange={(e) =>
                handleTextFieldChange("num_of_classes", e.target.value)
              }
              className={`${styles.textFiled} ${
                errors.num_of_classes ? styles.errorBorder : ""
              }`}
            />
            <div className={styles.errorMessage}>{errors.num_of_classes}</div>
            <TooltipWrapper tooltipText="The min and max values of the model's inputs (used to enforce realistic attacks)">
              <Typography>
                <span
                  className={`${
                    !requiredClipValues
                      ? styles.requiredHidden
                      : styles.required
                  }`}
                >
                  *
                </span>
                Clip Values
              </Typography>
            </TooltipWrapper>
            <input
              type="text"
              value={data.clip_values}
              onChange={(e) =>
                handleTextFieldChange("clip_values", e.target.value)
              }
              className={`${styles.textFiled} ${
                errors.clip_values ? styles.errorBorder : ""
              }`}
            />
            {errors.clip_values && (
              <div className={styles.errorMessage}>{errors.clip_values}</div>
            )}
          </Box>
        </Box>
      </DialogContent>
      {/* Footer: Save button and the caller-supplied error message */}
      <Box className={styles.dialogActionsContainer}>
        <DialogActions>
          <MainButton
            onClick={onSave}
            theme="blue"
            label="Save"
            className={styles.saveButton}
            isLoading={loading}
            disabled={disableSave || loading}
          ></MainButton>
        </DialogActions>
        <div className={styles.errorMessageMain}>{errorMessage}</div>
      </Box>
    </Dialog>
  );
};

export default ModelModal;
