import type { Queue } from "bullmq";
import { JSONCodec } from "nats.ws";
import { z } from "zod";

import {
  AmdGpuInfoSchema,
  GpuTypeSchema,
  MacSmiSchema,
  NvidiaSmiSchema,
} from "../lib/gpu.models";
import {
  ChatCompletionChunkSchema,
  CreateChatCompletionRequestSchema,
} from "../lib/openai.models";
import { WorkerSchema } from "../schema/WorkerSchema";
import {
  authedServiceRequest,
  serviceRequest,
  serviceResponse,
} from "./BaseService";
import { BenchmarkSchema } from "./benchmark.models";
import {
  GenerationLanes,
  generationResult,
  InferenceStateSchema,
} from "./generation.models";
import { WorkerConfigSchema } from "./worker-config.models";

/** Allowed values for `Instance.status`. */
export const InstanceStatusSchema = z.enum([
  "Initializing",
  "Running",
  "Paused",
  "Expired",
]);

/**
 * Discriminator for inference messages. It is used as both
 * `InferenceInstanceMessage.messageType` and
 * `DispatchInstanceInferenceResponse.msgType`.
 */
export const InferenceMessageTypeEnum = z.enum([
  "LOCK",
  "CHUNK",
  "SUCCESS",
  "ERROR",
  "FINAL-RESULT",
  "PING",
]);

/**
 * Geographic metadata attached to an instance's info (`InstanceInfo.location`).
 * Every field is `nullish`, so a partially resolved or entirely empty location
 * still validates.
 */
export const InstanceLocationSchema = z.object({
  country: z.string().nullish(),
  city: z.string().nullish(),
  continent: z.string().nullish(),
  latitude: z.number().nullish(),
  longitude: z.number().nullish(),
  // Postal and metro codes are kept as strings, not numbers.
  postalCode: z.string().nullish(),
  metroCode: z.string().nullish(),
  region: z.string().nullish(),
  regionCode: z.string().nullish(),
  timezone: z.string().nullish(),
});

export type InstanceLocation = z.infer<typeof InstanceLocationSchema>;

/**
 * Host/system information for an instance. It is stored as `Instance.info`
 * and sent as `initializeInstanceParams.info`.
 *
 * Field presence follows three rules:
 * - Only `ipAddress` and `systemName` are required non-null strings.
 * - Fields declared `.nullable()` must be present but may be `null`.
 * - `location`, `version` and the vendor-specific GPU reports are `.nullish()`
 *   and may be omitted entirely.
 */
export const InstanceInfoSchema = z.object({
  ipAddress: z.string(),
  location: InstanceLocationSchema.nullish(),
  arch: z.string().nullable(),
  platform: z.string().nullable(),
  type: z.string().nullable(),
  totalMemoryBytes: z.number().nullable(),
  totalSwapBytes: z.number().nullable(),
  systemName: z.string(),
  kernelVersion: z.string().nullable(),
  osVersion: z.string().nullable(),
  hostName: z.string().nullable(),
  cpus: z.number().nullable(),
  version: z.string().nullish(),
  // GPU reports per vendor. Presumably only the one matching the host's
  // hardware is populated — verify against the reporting client.
  nvidiasmi: NvidiaSmiSchema.nullish(),
  amdsmi: AmdGpuInfoSchema.nullish(),
  macsmi: MacSmiSchema.nullish(),
  isSystemd: z.boolean().nullable(),
});

/**
 * Compact per-instance location record, returned in
 * `instanceLocationsResponse.locations`.
 *
 * The key is the lowercase `instanceid`, unlike `instanceId` elsewhere in
 * this file. Renaming it would change the wire format.
 */
export const InstanceLocationInfoSchema = z.object({
  instanceid: z.string(),
  // Reuses the `InstanceInfo.location` schema so the two stay in sync. That
  // schema is already nullish, so the extra `.optional()` is redundant but
  // harmless.
  location: InstanceInfoSchema.shape.location.optional(),
  // NOTE(review): free-form device label; its format is not defined here.
  device: z.string().nullable(),
});

export type InstanceLocationInfo = z.infer<typeof InstanceLocationInfoSchema>;

/**
 * Throughput tiers used by `InstancePoolAssignment.throughput`.
 *
 * Values are range labels in steps of 10, and `TIER_10` is open-ended ("90+").
 *
 * NOTE(review): two things are not stated here and should be confirmed where
 * tiers are assigned:
 * - the unit (e.g. tokens/s);
 * - whether range bounds are inclusive (adjacent labels share endpoints).
 */
export enum ThroughputRange {
  TIER_1 = "0-10",
  TIER_2 = "10-20",
  TIER_3 = "20-30",
  TIER_4 = "30-40",
  TIER_5 = "40-50",
  TIER_6 = "50-60",
  TIER_7 = "60-70",
  TIER_8 = "70-80",
  TIER_9 = "80-90",
  TIER_10 = "90+",
}

/**
 * One pool membership of an instance (an entry of `Instance.poolAssignments`).
 * It records the worker, instance and model, together with the throughput
 * tier and generation lane.
 *
 * Every field must be present but may be `null`.
 */
export const InstancePoolAssignmentSchema = z.object({
  workerId: z.string().nullable(),
  instanceId: z.string().nullable(),
  model: z.string().nullable(),
  throughput: z.nativeEnum(ThroughputRange).nullable(),
  lane: z.nativeEnum(GenerationLanes).nullable(),
});

/**
 * An instance record as exchanged through `InstanceService`.
 *
 * The timestamp fields are validated with `z.date()`, which accepts `Date`
 * objects only. After a JSON round-trip (`JSON.stringify` / `JSON.parse`, as
 * done by `JSONCodec`) they arrive as ISO strings, and this schema would
 * reject them.
 */
export const InstanceSchema = z.object({
  _id: z.string(),
  userId: z.string(),
  teamId: z.string(),
  workerId: z.string(),
  name: z.string(),
  status: InstanceStatusSchema,
  info: InstanceInfoSchema,
  poolAssignments: z.array(InstancePoolAssignmentSchema),
  createdAt: z.date(),
  updatedAt: z.date(),
  expiredAt: z.date().nullable(),
  lastHeartbeatAt: z.date().nullable(),
});

/**
 * A subscription handed to an instance in the `subscriptionTopics` field of
 * the initialize, ready and heartbeat responses. Presumably it describes a
 * NATS JetStream consumer (stream, consumer name, durability, acks) — verify
 * against the consumer setup code.
 */
export const SubscriptionTopicSchema = z.object({
  subject: z.string(),
  streamName: z.string(),

  // Max wait time in milliseconds for the consumer to wait for a message before proceeding to the next topic
  maxWaitTimeInMs: z.number(),

  // How long to wait for the consumer to ack a message before sending another
  // NOTE(review): in NATS, `max_ack_pending` is a message count, not a
  // duration. Confirm how this value is actually applied.
  maxAckPendingInMs: z.number(),

  consumerName: z.string(),
  isDurable: z.boolean().optional(),
  ackRequired: z.boolean().optional(),

  options: z.object({
    // Presumably a queue-group name for load-balanced delivery — verify.
    queue: z.string().optional(),
  }),
});

/**
 * Upper bound on instances.
 *
 * NOTE(review): the scope of this limit (per worker, team or user) is not
 * visible here — confirm at its usage sites.
 */
export const MAX_INSTANCES = 50;

// Static types inferred from the schemas above.
export type Instance = z.infer<typeof InstanceSchema>;
export type InstanceInfo = z.infer<typeof InstanceInfoSchema>;
export type InstanceStatus = z.infer<typeof InstanceStatusSchema>;
export type InstancePoolAssignment = z.infer<
  typeof InstancePoolAssignmentSchema
>;
export type SubscriptionTopic = z.infer<typeof SubscriptionTopicSchema>;

/** ******************************************************************************
 *  Instance Service
 ******************************************************************************* */
/**
 * Contract implemented by the instance service. Each request/response type
 * is declared in its own section of this file.
 */
export type InstanceService = {
  // BullMQ queue; presumably used to schedule instance expiration — verify.
  expirationQueue: Queue;
  get(request: GetInstanceRequest): Promise<GetInstanceResponse>;
  list(request: ListInstancesRequest): Promise<ListInstancesResponse>;
  count(request: CountInstancesRequest): Promise<CountInstancesResponse>;
  locations(
    request: InstanceLocationsRequest,
  ): Promise<InstanceLocationsResponse>;
  hardware(request: InstanceHardwareRequest): Promise<InstanceHardwareResponse>;
  calibrate(
    request: CalibrateInstanceRequest,
  ): Promise<CalibrateInstanceResponse>;
  // Streams `InferenceInstanceMessage`s for a single inference request.
  streamInference(
    params: InferenceInstanceRequest,
  ): AsyncGenerator<InferenceInstanceMessage>;
  // Resolves with the generationId without waiting for the inference result.
  submitAsyncInference(
    request: InferenceInstanceSubmitRequest,
  ): Promise<InferenceInstanceSubmitResponse>;
  initialize(
    request: InitializeInstanceRequest,
  ): Promise<InitializeInstanceResponse>;
  ready(request: ReadyInstanceRequest): Promise<ReadyInstanceResponse>;
  heartbeat(
    request: HeartbeatInstanceRequest,
  ): Promise<HeartbeatInstanceResponse>;
  // NOTE(review): behaviour inferred from the name only; see the implementation.
  cleanupUnhealthyInstances(): Promise<void>;
  ensureInstanceNatsConsumers(instance: Instance): Promise<void>;
};

/** ******************************************************************************
 *  Get Instance
 ******************************************************************************* */

/** Params for `InstanceService.get`: the id of the instance to fetch. */
export const getInstanceParams = z.object({
  instanceId: z.string(),
});

/** Request for `get`: `authedServiceRequest` fields plus `params`. */
export const getInstanceRequest = authedServiceRequest.merge(
  z.object({
    params: getInstanceParams,
  }),
);

/**
 * Response for `get`. `instance` may be `null` or absent. Presumably that
 * happens when the instance is not found or the request failed — verify in
 * the service.
 */
export const getInstanceResponse = serviceResponse.merge(
  z.object({
    instance: InstanceSchema.nullable().optional(),
  }),
);

export type GetInstanceParams = z.infer<typeof getInstanceParams>;
export type GetInstanceRequest = z.infer<typeof getInstanceRequest>;
export type GetInstanceResponse = z.infer<typeof getInstanceResponse>;

/** ******************************************************************************
 *  List Instances
 ******************************************************************************* */

/** Params for `InstanceService.list`: instances are listed for one worker. */
export const listInstancesParams = z.object({
  workerId: z.string(),
});

/** Request for `list`: `authedServiceRequest` fields plus `params`. */
export const listInstancesRequest = authedServiceRequest.merge(
  z.object({
    params: listInstancesParams,
  }),
);

/** Response for `list`; `instances` may be omitted. */
export const listInstancesResponse = serviceResponse.merge(
  z.object({
    instances: z.array(InstanceSchema).optional(),
  }),
);

export type ListInstancesParams = z.infer<typeof listInstancesParams>;
export type ListInstancesRequest = z.infer<typeof listInstancesRequest>;
export type ListInstancesResponse = z.infer<typeof listInstancesResponse>;

/** ******************************************************************************
 *  Count Instances
 ******************************************************************************* */

// `count` takes no parameters: `params` must be `undefined`.
export const countInstancesParams = z.undefined();

/** Request for `count`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const countInstancesRequest = serviceRequest.merge(
  z.object({
    params: countInstancesParams,
  }),
);

/** Response for `count`; `count` may be omitted. */
export const countInstancesResponse = serviceResponse.merge(
  z.object({
    count: z.number().optional(),
  }),
);

export type CountInstancesParams = z.infer<typeof countInstancesParams>;
export type CountInstancesRequest = z.infer<typeof countInstancesRequest>;
export type CountInstancesResponse = z.infer<typeof countInstancesResponse>;

/** ******************************************************************************
 *  Instance Locations
 ******************************************************************************* */

// `locations` takes no parameters: `params` must be `undefined`.
export const instanceLocationsParams = z.undefined();

/** Request for `locations`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const instanceLocationsRequest = serviceRequest.merge(
  z.object({
    params: instanceLocationsParams,
  }),
);

/**
 * Response for `locations`.
 *
 * NOTE(review): `locations` is `.nullable()` rather than `.optional()` like the
 * sibling responses. A response that omits the key entirely (e.g. an error
 * response) will therefore fail validation; it must carry `locations: null`.
 */
export const instanceLocationsResponse = serviceResponse.merge(
  z.object({
    locations: z.array(InstanceLocationInfoSchema).nullable(),
  }),
);

export type InstanceLocationsParams = z.infer<typeof instanceLocationsParams>;
export type InstanceLocationsRequest = z.infer<typeof instanceLocationsRequest>;
export type InstanceLocationsResponse = z.infer<
  typeof instanceLocationsResponse
>;

/** ******************************************************************************
 *  Instance Hardware
 ******************************************************************************* */

/**
 * One row of the hardware breakdown: a GPU model with its instance count,
 * share and VRAM. Presumably the rows are aggregated across instances —
 * verify in the service.
 */
export const instanceHardwareRow = z.object({
  gpu: z.string(),
  count: z.number(),
  // NOTE(review): scale (0–1 vs 0–100) is not defined here — confirm.
  percentage: z.number(),
  // VRAM is carried as a (nullable) preformatted string, not a number.
  totalVRAM: z.string().nullable(),
  type: GpuTypeSchema,
});

// `hardware` takes no parameters: `params` must be `undefined`.
export const instanceHardwareParams = z.undefined();

/** Request for `hardware`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const instanceHardwareRequest = serviceRequest.merge(
  z.object({
    params: instanceHardwareParams,
  }),
);

/** Response for `hardware`: per-GPU rows plus overall totals, all optional. */
export const instanceHardwareResponse = serviceResponse.merge(
  z.object({
    gpus: z.array(instanceHardwareRow).optional(),
    totalCount: z.number().optional(),
    totalVRAM: z.string().optional(),
  }),
);

export type InstanceHardwareRow = z.infer<typeof instanceHardwareRow>;
export type InstanceHardwareParams = z.infer<typeof instanceHardwareParams>;
export type InstanceHardwareRequest = z.infer<typeof instanceHardwareRequest>;
export type InstanceHardwareResponse = z.infer<typeof instanceHardwareResponse>;

/** ******************************************************************************
 *  Calibrate Instance
 ******************************************************************************* */

/** Params for `calibrate`: the target instance and the benchmark to apply to it. */
export const calibrateInstanceParams = z.object({
  instanceId: z.string(),
  benchmark: BenchmarkSchema,
});

/** Request for `calibrate`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const calibrateInstanceRequest = serviceRequest.merge(
  z.object({
    params: calibrateInstanceParams,
  }),
);

/** Response for `calibrate`; `instance` may be `null` or absent. */
export const calibrateInstanceResponse = serviceResponse.merge(
  z.object({
    instance: InstanceSchema.nullable().optional(),
  }),
);

export type CalibrateInstanceParams = z.infer<typeof calibrateInstanceParams>;
export type CalibrateInstanceRequest = z.infer<typeof calibrateInstanceRequest>;
export type CalibrateInstanceResponse = z.infer<
  typeof calibrateInstanceResponse
>;

/** ******************************************************************************
 *  Initialize Instance
 ******************************************************************************* */

/**
 * Params for `initialize`: identifies the worker, carries a registration
 * code and the host's `InstanceInfo`.
 */
export const initializeInstanceParams = z.object({
  name: z.string().nullable(),
  workerId: z.string(),
  // Presumably `null` on first registration, set when re-initializing — verify.
  instanceId: z.string().nullable(),
  registrationCode: z.string(),
  info: InstanceInfoSchema,
});

/** Request for `initialize`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const initializeInstanceRequest = serviceRequest.merge(
  z.object({
    params: initializeInstanceParams,
  }),
);

/** Response for `initialize`: the worker, the instance and its subscriptions (all optional). */
export const initializeInstanceResponse = serviceResponse.merge(
  z.object({
    worker: WorkerSchema.optional(),
    instance: InstanceSchema.optional(),
    subscriptionTopics: z.array(SubscriptionTopicSchema).optional(),
  }),
);

export type InitializeInstanceParams = z.infer<typeof initializeInstanceParams>;
export type InitializeInstanceRequest = z.infer<
  typeof initializeInstanceRequest
>;
export type InitializeInstanceResponse = z.infer<
  typeof initializeInstanceResponse
>;

// Plain JSON codecs for NATS. The type parameter is compile-time only:
// decoding does not validate, and `Date` fields come back as ISO strings.
export const InitializeInstanceRequestCodec =
  JSONCodec<InitializeInstanceRequest>();
export const InitializeInstanceResponseCodec =
  JSONCodec<InitializeInstanceResponse>();

/** ******************************************************************************
 *  Ready Instance
 ******************************************************************************* */

/** Params for `ready`: the worker and instance reporting readiness. */
export const readyInstanceParams = z.object({
  workerId: z.string(),
  instanceId: z.string(),
});

/** Request for `ready`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const readyInstanceRequest = serviceRequest.merge(
  z.object({
    params: readyInstanceParams,
  }),
);

/** Response for `ready`: the instance and its subscriptions (both optional). */
export const readyInstanceResponse = serviceResponse.merge(
  z.object({
    instance: InstanceSchema.optional(),
    subscriptionTopics: z.array(SubscriptionTopicSchema).optional(),
  }),
);

export type ReadyInstanceParams = z.infer<typeof readyInstanceParams>;
export type ReadyInstanceRequest = z.infer<typeof readyInstanceRequest>;
export type ReadyInstanceResponse = z.infer<typeof readyInstanceResponse>;

// Plain JSON codecs for NATS. Decoding does not validate, and `Date` fields
// come back as ISO strings.
export const ReadyInstanceRequestCodec = JSONCodec<ReadyInstanceRequest>();
export const ReadyInstanceResponseCodec = JSONCodec<ReadyInstanceResponse>();

/** ******************************************************************************
 *  Heartbeat Instance
 ******************************************************************************* */

/** Params for `heartbeat`: the reporting worker/instance and an optional binary `info` payload. */
export const heartbeatInstanceParams = z.object({
  workerId: z.string(),
  instanceId: z.string(),
  // NOTE(review): `HeartbeatInstanceRequestCodec` is a plain JSONCodec.
  // `JSON.stringify` serialises a Uint8Array as an object with numeric keys
  // (`{"0":…,"1":…}`), so after decoding, `info` is a plain object, not a
  // Uint8Array, and would fail this schema unless the receiver converts it
  // back. Confirm how the server handles it.
  info: z.union([z.instanceof(Uint8Array), z.null()]),
});

/** Request for `heartbeat`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const heartbeatInstanceRequest = serviceRequest.merge(
  z.object({
    params: heartbeatInstanceParams,
  }),
);

/** Response for `heartbeat`. Every field is optional. */
export const heartbeatInstanceResponse = serviceResponse.merge(
  z.object({
    // Untyped here, unlike `WorkerSchema` in the initialize response.
    worker: z.any().optional(),
    instance: InstanceSchema.optional(),
    subscriptionTopics: z.array(SubscriptionTopicSchema).optional(),

    workerConfig: WorkerConfigSchema.optional().nullable(),

    expectedCliVersion: z.string().optional().nullable(),
  }),
);

export type HeartbeatInstanceParams = z.infer<typeof heartbeatInstanceParams>;
export type HeartbeatInstanceRequest = z.infer<typeof heartbeatInstanceRequest>;
export type HeartbeatInstanceResponse = z.infer<
  typeof heartbeatInstanceResponse
>;

// Plain JSON codec for heartbeat requests (see the NOTE on `info` above).
export const HeartbeatInstanceRequestCodec =
  JSONCodec<HeartbeatInstanceRequest>();

/** An `Instance` with its `info` payload stripped for transport. */
type InstanceWithoutInfo = Omit<Instance, "info"> & { info?: null };

/** Wire shape of a heartbeat response: identical except that `instance.info` is dropped. */
type HeartbeatInstanceResponseWithoutInfo = Omit<
  HeartbeatInstanceResponse,
  "instance"
> & {
  instance?: InstanceWithoutInfo | null;
};

// JSONCodec is stateless, so build it once instead of on every encode/decode call.
const heartbeatResponseJsonCodec =
  JSONCodec<HeartbeatInstanceResponseWithoutInfo>();

/**
 * NATS codec for heartbeat responses that omits `instance.info` from the payload.
 *
 * `encode` serialises the response to JSON with `instance.info` removed. An
 * absent `instance` stays absent: `undefined` keys are dropped by
 * `JSON.stringify`. It is not turned into `null`, which
 * `instance: InstanceSchema.optional()` would reject.
 *
 * `decode` parses the JSON and sets `instance.info` to `null` when an instance
 * is present. The result is therefore NOT a fully valid
 * `HeartbeatInstanceResponse`, despite the return type:
 * - `info` is `null`;
 * - the `Date` fields of `instance` are ISO strings after the JSON round-trip;
 * - no schema validation is performed.
 *
 * Callers must not rely on either `info` or those `Date` fields.
 */
export const HeartbeatInstanceResponseCodec = {
  encode: (data: HeartbeatInstanceResponse): Uint8Array => {
    const cleanedData: HeartbeatInstanceResponseWithoutInfo = {
      ...data,
      instance: data.instance
        ? { ...data.instance, info: undefined }
        : data.instance,
    };
    return heartbeatResponseJsonCodec.encode(cleanedData);
  },
  decode: (data: Uint8Array): HeartbeatInstanceResponse => {
    const decodedData = heartbeatResponseJsonCodec.decode(data);
    // Older encoders may still send `instance: null`; the truthiness check covers both.
    if (decodedData.instance) {
      decodedData.instance.info = null;
    }
    // Deliberate type widening: see the caveats in the doc comment above.
    return decodedData as HeartbeatInstanceResponse;
  },
};

/** ******************************************************************************
 *  Inference Instance
 ******************************************************************************* */

/**
 * Params for a streamed inference (`InstanceService.streamInference`).
 *
 * `workerId` and `instanceId` are optional. Presumably they pin the request to
 * a specific worker/instance, and omitting them lets the dispatcher choose —
 * verify in the service.
 */
export const inferenceInstanceParams = z.object({
  workerId: z.string().nullish(),
  instanceId: z.string().nullish(),
  generationId: z.string(),
  lane: z.nativeEnum(GenerationLanes).nullable().optional(),
  input: CreateChatCompletionRequestSchema,
});

/** Request for `streamInference`: `authedServiceRequest` fields plus `params`. */
export const inferenceInstanceRequest = authedServiceRequest.merge(
  z.object({
    params: inferenceInstanceParams,
  }),
);

/**
 * One message yielded by `streamInference`. `messageType` is the
 * discriminator, and `chunk` may be omitted.
 */
export const inferenceInstanceMessage = z.object({
  generationId: z.string(),
  instance: InstanceSchema.nullable(),
  // NOTE(review): numeric timestamps; the epoch unit is not defined here.
  dispatchedAt: z.number().nullable(),
  finishedAt: z.number().nullable(),
  state: InferenceStateSchema,
  stateMessage: z.string(),
  result: generationResult.nullable(),
  chunk: ChatCompletionChunkSchema.nullish(),
  messageType: InferenceMessageTypeEnum,
});

export type InferenceInstanceParams = z.infer<typeof inferenceInstanceParams>;
export type InferenceInstanceRequest = z.infer<typeof inferenceInstanceRequest>;
export type InferenceInstanceMessage = z.infer<typeof inferenceInstanceMessage>;

/**
 * When submitting an inference, we return the generationId immediately.
 * The inference is then dispatched to a worker, and its result is stored
 * in the database, from which it can later be retrieved via the API.
 */
export const inferenceInstanceSubmitResponse = serviceResponse.merge(
  z.object({
    generationId: z.string(),
  }),
);

/**
 * Params for `submitAsyncInference`. Unlike the streamed variant, these carry
 * no `workerId`/`instanceId`.
 */
export const inferenceInstanceSubmitParams = z.object({
  generationId: z.string(),
  input: CreateChatCompletionRequestSchema,
  lane: z.nativeEnum(GenerationLanes).nullable().optional(),
});

/** Request for `submitAsyncInference`. It is built on `serviceRequest`, not `authedServiceRequest`. */
export const inferenceInstanceSubmitRequest = serviceRequest.merge(
  z.object({
    params: inferenceInstanceSubmitParams,
  }),
);

export type InferenceInstanceSubmitParams = z.infer<
  typeof inferenceInstanceSubmitParams
>;
export type InferenceInstanceSubmitResponse = z.infer<
  typeof inferenceInstanceSubmitResponse
>;
export type InferenceInstanceSubmitRequest = z.infer<
  typeof inferenceInstanceSubmitRequest
>;

/** ******************************************************************************
 *  Dispatch Instance Inference
 ******************************************************************************* */

/** Payload dispatched to an instance: the completion input and a correlation `id`. */
export const dispatchInstanceInferenceParams = z.object({
  input: CreateChatCompletionRequestSchema,
  id: z.string(),
});

/**
 * Message sent back for a dispatched inference. `msgType` is the
 * discriminator, and `id` echoes the dispatched `id`.
 * NOTE(review): the echo is inferred from the matching field name — verify.
 */
export const dispatchInstanceInferenceResponse = z.object({
  workerId: z.string(),
  instanceId: z.string(),
  msgType: InferenceMessageTypeEnum,
  result: generationResult.nullable(),
  chunk: ChatCompletionChunkSchema.nullish(),
  id: z.string(),
  // NOTE(review): numeric timestamps; the epoch unit is not defined here.
  dispatchedAt: z.number().nullable().optional(),
  finishedAt: z.number().nullable().optional(),
  firstChunkAt: z.number().nullable().optional(),
});

export type DispatchInstanceInferenceParams = z.infer<
  typeof dispatchInstanceInferenceParams
>;
export type DispatchInstanceInferenceResponse = z.infer<
  typeof dispatchInstanceInferenceResponse
>;

// Plain JSON codecs for NATS; decoding does not run the zod schemas.
export const DispatchInstanceInferenceParamsCodec =
  JSONCodec<DispatchInstanceInferenceParams>();
export const DispatchInstanceInferenceResponseCodec =
  JSONCodec<DispatchInstanceInferenceResponse>();

/**
 * When an inference fails, we publish a message to the inference failure handler stream.
 *
 * `workerId`/`instanceId` are `null` when the failure cannot be attributed to
 * a specific worker/instance. NOTE(review): this is inferred from the
 * nullable types — verify against the publisher.
 */
export type FailedInstanceInference = {
  generationId: string;
  workerId: string | null;
  instanceId: string | null;
  errorMessage: string;
};

// Plain JSON codec for failure messages.
export const FailedInstanceInferenceCodec =
  JSONCodec<FailedInstanceInference>();
