import * as React from 'react';
import {
  Button,
  CircularProgress,
  FormControl,
  Grid,
  GridProps,
  OutlinedInput,
} from '@mui/material';
import MODEL_TYPE, { isSyntheticCategoryModelType } from 'constants/modelType';
import { RecordNew, Tag, Transform } from 'icons/figma';
import { useCreateRecordHandlerMutation } from 'src/api';
import { Model, RunnerMode } from 'src/api/types/Model';
import DataModal from '../DataModal';

/**
 * Props accepted by `ActionItem`: every MUI `Grid` prop (forwarded to the
 * root `Grid` container that the component renders) plus the model to act on.
 */
type ActionItemProps = GridProps & {
  /** Model whose `model_type` decides which action (if any) is rendered. */
  model: Model;
};

/** Default number of records requested for synthetic model types. */
const DEFAULT_NUM_RECORDS = 5000;

/**
 * gpt_x records are more expensive than those of other model types, so a
 * significantly lower default is used for them.
 */
const GPT_X_DEFAULT_NUM_RECORDS = 100;

/**
 * Converts the raw text of the "Number of records" input into a record count.
 *
 * @param rawValue - Current text content of the input.
 * @returns The parsed base-10 integer, or `0` when the text does not start
 *   with a number or the number is negative. The result is capped at
 *   `Number.MAX_SAFE_INTEGER` so it always round-trips through the input as
 *   plain digits.
 */
const parseRecordCount = (rawValue: string): number => {
  const parsed = Number.parseInt(rawValue, 10);
  if (Number.isNaN(parsed) || parsed < 0) {
    return 0;
  }
  return Math.min(parsed, Number.MAX_SAFE_INTEGER);
};

/**
 * Renders the primary action for a model, chosen by `model.model_type`:
 *
 * - synthetic category models: a record-count input plus a "Generate" button;
 * - classify models: a "Run Classify" button that opens a `DataModal`;
 * - transform models (v1 and v2): a "Run Transform" button that opens a
 *   `DataModal`;
 * - any other model type: nothing (`null`).
 *
 * Every action goes through the `useCreateRecordHandlerMutation` trigger, and
 * the controls are disabled while that mutation is loading.
 *
 * @param props - `model` plus any `Grid` props, which are forwarded to the
 *   root `Grid` container.
 */
const ActionItem = ({ model, ...gridProps }: ActionItemProps) => {
  const [createRecordHandler, createRecordHandlerStatus] =
    useCreateRecordHandlerMutation();
  const { isLoading } = createRecordHandlerStatus;

  const [numRecords, setNumRecords] = React.useState(
    model.model_type === MODEL_TYPE.GPT_X
      ? GPT_X_DEFAULT_NUM_RECORDS
      : DEFAULT_NUM_RECORDS
  );
  const [showClassifyModal, setShowClassifyModal] = React.useState(false);
  const [showTransformModal, setShowTransformModal] = React.useState(false);

  // Form submit handler for synthetic models: requests `numRecords` records.
  const handleGenerateSyntheticData = React.useCallback(
    (event: React.FormEvent<HTMLFormElement>) => {
      // Keep the browser from performing a native form submission.
      event.preventDefault();
      createRecordHandler({
        guid: model.project_guid,
        modelId: model.uid,
        runner_mode: model.runner_mode || RunnerMode.CLOUD,
        params: {
          num_records: numRecords,
        },
      });
    },
    [model, numRecords, createRecordHandler]
  );

  // Shared by the classify and transform flows: starts a record handler
  // against `dataSource`, and does nothing when no data source was provided.
  const runOnDataSource = React.useCallback(
    (dataSource?: string) => {
      if (!dataSource) {
        return;
      }
      createRecordHandler({
        guid: model.project_guid,
        modelId: model.uid,
        runner_mode: model.runner_mode || RunnerMode.CLOUD,
        data_source: dataSource,
      });
    },
    [model, createRecordHandler]
  );

  // `onClose` callback of the classify modal: always hides the modal, then
  // runs the model if a data source was handed back.
  const handleClassify = React.useCallback(
    (dataSource?: string) => {
      setShowClassifyModal(false);
      runOnDataSource(dataSource);
    },
    [runOnDataSource]
  );

  // `onClose` callback of the transform modal: always hides the modal, then
  // runs the model if a data source was handed back.
  const handleTransform = React.useCallback(
    (dataSource?: string) => {
      setShowTransformModal(false);
      runOnDataSource(dataSource);
    },
    [runOnDataSource]
  );

  if (isSyntheticCategoryModelType(model.model_type)) {
    return (
      <Grid container {...gridProps}>
        <form onSubmit={handleGenerateSyntheticData}>
          <FormControl disabled={isLoading}>
            <OutlinedInput
              // The label must go through `inputProps`: props set directly on
              // `OutlinedInput` land on its wrapper element, which would leave
              // the actual <input> without an accessible name.
              inputProps={{
                'aria-label': 'Number of records',
                inputMode: 'numeric',
              }}
              sx={theme => ({
                mr: 2,
                '& input': {
                  height: theme.spacing(6),
                },
              })}
              value={numRecords}
              onChange={({ target: { value } }) =>
                // Never store NaN or a negative count in state.
                setNumRecords(parseRecordCount(value))
              }
            />
          </FormControl>
          <Button
            id="ModelHeader_Generate"
            color="success"
            type="submit"
            endIcon={isLoading ? <CircularProgress size={16} /> : <RecordNew />}
            disabled={isLoading}
          >
            Generate
          </Button>
        </form>
      </Grid>
    );
  }

  if (model.model_type === MODEL_TYPE.CLASSIFY) {
    return (
      <Grid container {...gridProps}>
        <Button
          color="success"
          onClick={() => setShowClassifyModal(true)}
          endIcon={isLoading ? <CircularProgress size={16} /> : <Tag />}
          disabled={isLoading}
        >
          Run Classify
        </Button>
        <DataModal
          title="Classify data"
          description="Automatically discover and label sensitive data types, such as Personally Identifiable Information (PII)."
          open={showClassifyModal}
          onClose={handleClassify}
        />
      </Grid>
    );
  }

  if (
    [MODEL_TYPE.TRANSFORM, MODEL_TYPE.TRANSFORM_V2].includes(model.model_type)
  ) {
    return (
      <Grid container {...gridProps}>
        <Button
          color="success"
          onClick={() => setShowTransformModal(true)}
          endIcon={isLoading ? <CircularProgress size={16} /> : <Transform />}
          disabled={isLoading}
        >
          Run Transform
        </Button>
        <DataModal
          title="Transform data"
          description="Automatically label data and perform privacy-preserving transformations on a dataset."
          open={showTransformModal}
          onClose={handleTransform}
        />
      </Grid>
    );
  }

  // No action is available for any other model type.
  return null;
};

export default ActionItem;
