import type {
  DocumentNode,
  GraphQLSchema,
  GraphQLType,
  ValueNode,
  VariableDefinitionNode,
} from 'graphql';

import {
  visit,
  visitWithTypeInfo,
  ValidationContext,
  TypeInfo,
  isNonNullType,
  isTypeSubTypeOf,
  Kind,
  typeFromAST,
  parseType,
} from 'graphql';

/**
 * Returns true if the variable is allowed in the location it was found,
 * which includes considering if default values exist for either the variable
 * or the location at which it is located.
 */
function allowedVariableUsage(
  schema: GraphQLSchema,
  varType: GraphQLType,
  varDefaultValue: ValueNode | undefined,
  locationType: GraphQLType,
  locationDefaultValue?: any,
): boolean {
  if (isNonNullType(locationType) && !isNonNullType(varType)) {
    const hasNonNullVariableDefaultValue =
      varDefaultValue != null && varDefaultValue.kind !== Kind.NULL;
    const hasLocationDefaultValue = locationDefaultValue !== undefined;
    if (!hasNonNullVariableDefaultValue && !hasLocationDefaultValue) {
      return false;
    }
    const nullableLocationType = locationType.ofType;
    return isTypeSubTypeOf(schema, varType, nullableLocationType);
  }
  return isTypeSubTypeOf(schema, varType, locationType);
}

export const replaceAutoTypes = (
  schema: GraphQLSchema,
  documentAST: DocumentNode,
): DocumentNode => {
  const typeInfo = new TypeInfo(schema);

  const abortObj = Object.freeze({});
  const context = new ValidationContext(schema, documentAST, typeInfo, () => {
    throw abortObj;
  });

  let varDefMap: Record<string, VariableDefinitionNode> = Object.create(null);
  return visit(
    documentAST,
    visitWithTypeInfo(typeInfo, {
      OperationDefinition: {
        enter() {
          varDefMap = Object.create(null);
        },
        leave(operation) {
          const usages = context.getRecursiveVariableUsages(operation);

          for (const { node, type, defaultValue } of usages) {
            const varName = node.name.value;
            const varDef = varDefMap[varName];
            if (varDef && type) {
              const schema = context.getSchema();

              if (varDef.type.kind === Kind.NAMED_TYPE && varDef.type.name.value === 'auto') {
                const varType = typeFromAST(schema, varDef.type);
                if (
                  varType &&
                  !allowedVariableUsage(schema, varType, varDef.defaultValue, type, defaultValue)
                ) {
                  const typeStr = type.toString();
                  (varDef as any).type = parseType(typeStr);
                }
              }
            }
          }
        },
      },
      VariableDefinition(node) {
        varDefMap[node.variable.name.value] = node;
      },
    }),
  );
};
