import { ReactNode, useCallback, useMemo } from "react";
import {
  Box,
  IconButton,
  IconButtonProps,
  styled,
  unstable_composeClasses as composeClasses,
} from "@mui/material";
import {
  DataGridProProps,
  getDataGridUtilityClass,
  GridRenderCellParams,
  useGridApiContext,
  useGridRootProps,
} from "@mui/x-data-grid-pro";
import { TreeCollapseIcon } from "assets/icons/treeCollapseIcon";
import { TreeExpandIcon } from "assets/icons/treeExpandIcon";

/**
 * Style-only props accepted by the grouping-cell `Wrapper` styled component.
 */
export interface WrapperTypes {
  // Tree depth of the row; multiplied by the per-level step to get the wrapper's left margin.
  depth: number,
  // Whether the row has descendants; selects the per-level indent step (treated as false when omitted).
  hasChildren?: boolean,
}

// Slot name -> Data Grid utility-class keys, resolved by `useUtilityClasses` via `composeClasses`.
// `root` is applied to the cell wrapper. NOTE(review): `toggle` is resolved but not referenced by
// the JSX in this file — confirm whether the toggle column is meant to receive `classes.toggle`.
const slots = {
  root: ["treeDataGroupingCell"],
  toggle: ["treeDataGroupingCellToggle"],
};

/**
 * Resolves the Data Grid utility class names for each entry in `slots`, combined with the
 * supplied `classes` overrides. The result is memoized and only recomputed when `classes` changes.
 */
const useUtilityClasses = ({ classes }: { classes: DataGridProProps["classes"] }) =>
  useMemo(() => composeClasses(slots, getDataGridUtilityClass, classes), [classes]);

/**
 * Props for `GroupingColDef`: the grid's cell render params plus renderer options.
 */
export type GroupingColDefProps = GridRenderCellParams & {
  // Optional content renderer; when provided, its output replaces the default formatted-value span.
  renderInnerCell?: (props: GridRenderCellParams) => ReactNode;
  // Whether to render the expand/collapse toggle for rows with descendants (defaults to true).
  showExpandIcon?: boolean,
};

/**
 * Tells whether a `KeyboardEvent.key` value should be forwarded to the grid as cell navigation.
 * Matches Home, End, Space, and any key whose name begins with "Arrow" or "Page"
 * (e.g. "ArrowUp", "PageDown").
 */
function isNavigationKey(key: string) {
  if (key === "Home" || key === "End" || key === " ") {
    return true;
  }
  return key.startsWith("Arrow") || key.startsWith("Page");
}

/**
 * Cell renderer for the tree-data grouping column.
 *
 * Inside a depth-indented wrapper it renders:
 * - an expand/collapse toggle, when the row has descendants and `showExpandIcon` is true;
 * - either `renderInnerCell(props)` when provided, or the cell's formatted value
 *   (falling back to the node's `groupingKey`).
 *
 * @param props - Grid cell render params plus the optional `renderInnerCell` renderer and
 *   `showExpandIcon` flag (defaults to `true`).
 */
export function GroupingColDef(props: GroupingColDefProps) {
  const { field, formattedValue, id, renderInnerCell, row, rowNode, showExpandIcon = true } = props;

  const apiRef = useGridApiContext();
  const rootProps = useGridRootProps();
  const classes = useUtilityClasses({ classes: rootProps?.classes });

  // One predicate drives both the indentation step and the toggle's visibility,
  // so the two can never disagree about whether this row has children.
  const hasChildren = row._descendantCount > 0;
  const isExpanded = rowNode.childrenExpanded;
  const displayValue = formattedValue ?? rowNode.groupingKey;
  const Icon = isExpanded ? TreeCollapseIcon : TreeExpandIcon;

  const handleKeyDown: IconButtonProps["onKeyDown"] = useCallback(event => {
    // Stop Space from bubbling to ancestor key handlers; it is still published
    // below because `isNavigationKey` treats " " as a navigation key.
    if (event.key === " ") {
      event.stopPropagation();
    }
    if (isNavigationKey(event.key) && !event.shiftKey) {
      apiRef.current.publishEvent("cellNavigationKeyDown", props, event);
    }
  }, [apiRef, props]);

  const handleClick: IconButtonProps["onClick"] = useCallback(event => {
    apiRef.current.setRowChildrenExpansion(id, !isExpanded);
    apiRef.current.setCellFocus(id, field);
    // Keep the toggle click from reaching ancestor click handlers.
    event.stopPropagation();
  }, [apiRef, field, id, isExpanded]);

  return (
    <Wrapper className={classes.root} depth={rowNode.depth} hasChildren={hasChildren}>
      <IconWrapper>
        {hasChildren && showExpandIcon && (
          <CustomIconButton
            onClick={handleClick}
            onKeyDown={handleKeyDown}
            aria-label={apiRef.current.getLocaleText(isExpanded ? "treeDataCollapse" : "treeDataExpand")}
          >
            <Icon />
          </CustomIconButton>
        )}
      </IconWrapper>
      {renderInnerCell ? renderInnerCell(props) : <span>{displayValue}</span>}
    </Wrapper>
  );
}

/** MUI `IconButton` with its padding removed. */
const CustomIconButton = styled(IconButton)(() => ({
  padding: "0",
}));

/**
 * Non-growing, non-shrinking 1rem flex column that holds the expand/collapse toggle
 * (left empty when no toggle is rendered), followed by a small right gap.
 */
const IconWrapper = styled(Box)(() => ({
  display: "flex",
  flex: "0 0 1rem",
  marginRight: "0.313rem",
}));

/**
 * Props that exist only to drive the `Wrapper` styles. They are excluded from forwarding so
 * Box does not pass them on to the rendered DOM element; the style function still receives them.
 */
const WRAPPER_REMOVE_PROPS: ReadonlySet<PropertyKey> = new Set<PropertyKey>(["hasChildren", "depth"]);

/**
 * Grouping-cell container that indents its content by tree depth.
 * Each level adds 1.125rem when `hasChildren` is true and 1.45rem otherwise.
 */
const Wrapper = styled(Box, {
  shouldForwardProp: prop => !WRAPPER_REMOVE_PROPS.has(prop),
})<WrapperTypes>(({ depth, hasChildren = false }) => ({
  marginLeft: `${depth * (hasChildren ? 1.125 : 1.45)}rem`,
}));
