import { z } from "zod";

import {
  CompletionUsageSchema,
  CreateChatCompletionRequestSchema,
  CreateChatCompletionResponseSchema,
} from "../lib/openai.models";
import {
  authedServiceRequest,
  serviceRequest,
  serviceResponse,
} from "./BaseService";

/**
 * Allowed values for a generation's `state`, in declaration order.
 * These exact strings are what the schema below accepts, so renaming any of
 * them changes what validates.
 */
const INFERENCE_STATE_VALUES = [
  "Queued",
  "In Progress",
  "Success",
  "Failed",
] as const;

/** Zod schema for the lifecycle state of a generation (`Generation.state`). */
export const InferenceStateSchema = z.enum(INFERENCE_STATE_VALUES);

/**
 * Lanes a generation can be assigned to (the `lane` field on generations and
 * on create-generation params, both validated via `z.nativeEnum`).
 * String enum: the runtime values are "fast" and "slow".
 */
export enum GenerationLanes {
  FAST = "fast",
  SLOW = "slow",
}

/** Shared prefix of every inference-related NATS stream name. */
export const NATS_INFERENCE_STREAM_NAME_PREFIX = "inference";

// Each stream name below is "<prefix>-<suffix>". `as const` keeps the exact
// string literal type (e.g. "inference-fast") rather than widening to string.

/** Stream name "inference-fast". */
export const NATS_FAST_INFERENCE_STREAM_NAME =
  `${NATS_INFERENCE_STREAM_NAME_PREFIX}-fast` as const;
/** Stream name "inference-benchmarking". */
export const NATS_BENCHMARKING_INFERENCE_STREAM_NAME =
  `${NATS_INFERENCE_STREAM_NAME_PREFIX}-benchmarking` as const;
/** Stream name "inference-slow". */
export const NATS_SLOW_INFERENCE_STREAM_NAME =
  `${NATS_INFERENCE_STREAM_NAME_PREFIX}-slow` as const;
/** Stream name "inference-result". */
export const NATS_INFERENCE_RESULT_STREAM_NAME =
  `${NATS_INFERENCE_STREAM_NAME_PREFIX}-result` as const;
/** Stream name "inference-failure-handler". */
export const NATS_INFERENCE_FAILURE_HANDLER_STREAM_NAME =
  `${NATS_INFERENCE_STREAM_NAME_PREFIX}-failure-handler` as const;

/**
 * Zod schema for a single generation record: ownership, worker assignment,
 * lifecycle timestamps and state, and the chat-completion request/response.
 * This is the shape behind the `Generation` type used by the storage and
 * service contracts in this file.
 */
export const GenerationSchema = z.object({
  _id: z.string(),
  userId: z.string(),
  // Worker/instance assignment is nullable; presumably null until the
  // generation is dispatched to a worker — TODO confirm.
  workerId: z.string().nullable(),
  workerUserId: z.string().nullable(),
  workerTeamId: z.string().nullable(),
  instanceId: z.string().nullable(),
  createdAt: z.date(),
  updatedAt: z.date(),
  // Lifecycle timestamps; null while the corresponding milestone has not
  // been reached (assumption based on nullability — verify).
  dispatchedAt: z.date().nullable(),
  firstChunkAt: z.date().nullable(),
  finishedAt: z.date().nullable(),
  state: InferenceStateSchema,
  stateMessage: z.string(),
  // TODO[@sean]: Make this more consistent once the migration is complete.
  model: z.string().optional(),
  request: CreateChatCompletionRequestSchema.nullable(),
  response: CreateChatCompletionResponseSchema.nullable(),
  // `.partial()` makes every usage field optional.
  usage: CompletionUsageSchema.partial(),
  isSystem: z.boolean().optional(),
  // NOTE(review): `.nullish()` wraps the `.default()` here. Under zod v3
  // semantics the outer optional returns `undefined` as-is before the
  // default is consulted, so an absent `lane` parses to undefined, not FAST.
  // `.nullish().default(GenerationLanes.FAST)` would apply the default, but
  // it also changes the inferred `Generation` type. Confirm the intent and
  // the installed zod major version before changing.
  lane: z.nativeEnum(GenerationLanes).default(GenerationLanes.FAST).nullish(),
  // `.nullable()` wraps `.default(1)`: null stays null, undefined becomes 1.
  version: z.number().default(1).nullable(),
});

/** Union of the generation lifecycle state literals. */
export type InferenceState = z.output<typeof InferenceStateSchema>;
/** Parsed (output) shape of a generation record. */
export type Generation = z.output<typeof GenerationSchema>;

/**
 * Pairs a generation id with its (nullable) chat-completion request and
 * response. Both fields use the same schemas and nullability as the
 * corresponding fields of `GenerationSchema`.
 */
export const GenerationRequestResponseSchema = z.object({
  generationId: z.string(),
  request: CreateChatCompletionRequestSchema.nullable(),
  response: CreateChatCompletionResponseSchema.nullable(),
});

/** Inferred type of `GenerationRequestResponseSchema`. */
export type GenerationRequestResponse = z.infer<
  typeof GenerationRequestResponseSchema
>;

/** Value resolved by `GenerationStorageInterface.create`: the stored generation. */
export type CreateGenerationResult = {
  generation: Generation;
};

/** ******************************************************************************
 *  #handleCreateGeneration
 ******************************************************************************* */

/** Unsuccessful create: only the service response is available. */
export type HandleCreationGenerationFailure = {
  success: false;
  response: CreateGenerationResponse;
};

/** Successful create: the service response plus the created generation. */
export type HandleCreationGenerationSuccess = {
  success: true;
  response: CreateGenerationResponse;
  generation: Generation;
};

/**
 * Result of `GenerationService.create`, discriminated on `success`.
 * Narrow on `success === true` before reading `generation`.
 */
export type HandleCreationGenerationResponse =
  | HandleCreationGenerationSuccess
  | HandleCreationGenerationFailure;

/**
 * Service-level contract for generations. Every method takes one of the
 * request types defined in this file and resolves to the matching response
 * type. Implementations live outside this file.
 */
export type GenerationService = {
  // Resolves to a discriminated union; check `success` before using `generation`.
  create(
    request: CreateGenerationRequest,
  ): Promise<HandleCreationGenerationResponse>;
  getResult(
    request: GetGenerationResultRequest,
  ): Promise<GetGenerationResultResponse>;
  list(request: ListGenerationsRequest): Promise<ListGenerationsResponse>;
  listByWorker(
    request: ListGenerationsByWorkerRequest,
  ): Promise<ListGenerationsByWorkerResponse>;
  count(request: CountGenerationsRequest): Promise<CountGenerationsResponse>;
  rpm(request: GenerationRPMRequest): Promise<GenerationRPMResponse>;
};

/** Arguments to `GenerationStorageInterface.create`. */
export type CreateGenerationInStoreParams = {
  // The full generation record to store.
  attributes: Generation;
  // TODO[@sean]: Should this be required for all inference requests?
  apiKey: string | undefined;
  userId: string;
  // Explicitly accepts both undefined and null, presumably meaning "no limit
  // requested" — TODO confirm against the storage implementation.
  maxTokens: number | undefined | null;
};

/**
 * Persistence contract for generations. Implementations live outside this
 * file; the comments below describe only what the signatures show.
 */
export type GenerationStorageInterface = {
  create(
    params: CreateGenerationInStoreParams,
  ): Promise<CreateGenerationResult>;
  deleteById(_id: string): Promise<void>;
  // `usePrimary` is presumably a hint to read from the primary node rather
  // than a replica — TODO confirm against the implementation.
  byId(_id: string, usePrimary?: boolean): Promise<Generation | null>;
  // Lookup scoped by `teamId`; may resolve to null.
  get(teamId: string, _id: string): Promise<Generation | null>;
  findUserIdByGenerationId(_id: string): Promise<string | null>;
  list(teamId: string, params: ListGenerationsParams): Promise<Generation[]>;
  // NOTE(review): `Omit` only removes top-level keys, and "request.messages"
  // and "response.choices" are not keys of `Generation`, so this return type
  // is effectively `Generation[]`. If nested fields are meant to be stripped
  // from the result type, a dedicated type is needed.
  listByWorker(
    teamId: string,
    params: ListGenerationsByWorkerParams,
  ): Promise<Omit<Generation, "request.messages" | "response.choices">[]>;
  // Every filter is optional; how omitted filters are handled is up to the
  // implementation.
  count(params: {
    instanceId?: string;
    userId?: string;
    workerId?: string;
    workerUserId?: string;
  }): Promise<number>;
  // NOTE(review): the "ByUser" methods take a parameter named `teamId`;
  // confirm which identifier implementations actually filter on.
  countSuccessfulByUser(teamId: string): Promise<number>;
  countByUser(teamId: string): Promise<number>;
  countByState(state: InferenceState): Promise<number>;
  countRpm(): Promise<number>;
  // Partial update by id; may resolve to null.
  update(_id: string, params: Partial<Generation>): Promise<Generation | null>;
  // Terminal-state writers: each takes the generation id (plus, for success and
  // failure, the fields to write) and resolves to nothing.
  finalizeSuccessfulGeneration(
    id: string,
    params: Partial<Generation>,
  ): Promise<void>;
  finalizeFailedGeneration(
    id: string,
    params: Partial<Generation>,
  ): Promise<void>;
  finalizeAbandonedGeneration(id: string): Promise<void>;
};

/** ******************************************************************************
 *  Create Generation
 ******************************************************************************* */

/**
 * Caller-supplied parameters for creating a generation through the
 * authenticated request (`createGenerationRequest`). `input` is the
 * chat-completion request; worker/instance targeting is optional.
 */
export const createGenerationParams = z.object({
  workerId: z.string().optional(),
  instanceId: z.string().optional(),
  input: CreateChatCompletionRequestSchema,
  isSystem: z.boolean().optional(),
  // NOTE(review): same pattern as `GenerationSchema.lane`. Under zod v3 the
  // outer `.nullish()` returns undefined before the FAST default is reached,
  // so an omitted `lane` stays undefined. Confirm the intended behavior.
  lane: z.nativeEnum(GenerationLanes).default(GenerationLanes.FAST).nullish(),
});

/**
 * "NA" variant of the create params. Unlike `createGenerationParams` it
 * carries an explicit `userId`, and its request (`createGenerationRequestNA`)
 * is built on `serviceRequest` instead of `authedServiceRequest`.
 * NOTE(review): "NA" presumably means "non-authenticated" — confirm.
 */
export const createGenerationParamsNA = z.object({
  input: CreateChatCompletionRequestSchema,
  userId: z.string(),
});

/** Create-generation request: the shared authed request envelope plus `params`. */
export const createGenerationRequest = authedServiceRequest.merge(
  z.object({
    params: createGenerationParams,
  }),
);

/** Create-generation request built on the plain `serviceRequest` envelope. */
export const createGenerationRequestNA = serviceRequest.merge(
  z.object({
    params: createGenerationParamsNA,
  }),
);

/** Wraps a completed chat-completion response as `output`. */
export const generationResult = z.object({
  output: CreateChatCompletionResponseSchema,
});

/**
 * Response of a create-generation call: the shared service response fields
 * plus the optional generation id, result and stream handle.
 */
export const createGenerationResponse = serviceResponse.merge(
  z.object({
    generationId: z.string().optional(),
    result: generationResult.nullable().optional(),
    // Untyped (`z.any()`): the schema places no constraint on the stream
    // value. Its concrete type is decided by the producer — TODO document.
    stream: z.any().nullish(),
  }),
);

/** Inferred type of `createGenerationParams`. */
export type CreateGenerationParams = z.infer<typeof createGenerationParams>;
/** Inferred type of `createGenerationRequest`. */
export type CreateGenerationRequest = z.infer<typeof createGenerationRequest>;
// Companion types for the "NA" schemas, following the schema/type pairing
// used for every other request and params schema in this file.
/** Inferred type of `createGenerationParamsNA`. */
export type CreateGenerationParamsNA = z.infer<typeof createGenerationParamsNA>;
/** Inferred type of `createGenerationRequestNA`. */
export type CreateGenerationRequestNA = z.infer<
  typeof createGenerationRequestNA
>;

/** Inferred type of `generationResult`. */
export type GenerationResult = z.infer<typeof generationResult>;
/** Inferred type of `createGenerationResponse`. */
export type CreateGenerationResponse = z.infer<typeof createGenerationResponse>;

/** ******************************************************************************
 *  Get Generation Result By ID
 ******************************************************************************* */

/** Identifies the generation whose result is requested. */
export const getGenerationResultParams = z.object({
  generationId: z.string(),
});

/** Get-result request: the authed request envelope plus `params`. */
export const getGenerationResultRequest = authedServiceRequest.merge(
  z.object({
    params: getGenerationResultParams,
  }),
);

/**
 * Details reported for a generation's result. Every field is optional.
 * Note that `request`/`response` are `.optional()` here, whereas
 * `GenerationSchema` declares them `.nullable()`. A null from a stored
 * generation therefore would not validate against this schema as-is.
 */
export const FinalGenerationDetailsSchema = z.object({
  state: InferenceStateSchema.optional(),
  stateMessage: z.string().optional(),
  info: z.string().optional(),
  request: CreateChatCompletionRequestSchema.optional(),
  response: CreateChatCompletionResponseSchema.optional(),
  dispatchedAt: z.date().nullish(),
  finishedAt: z.date().nullish(),
});

/** Inferred type of `FinalGenerationDetailsSchema`. */
export type FinalGenerationDetails = z.infer<
  typeof FinalGenerationDetailsSchema
>;

/** Get-result response: shared service response fields plus the final details. */
export const getGenerationResultResponse = serviceResponse.merge(
  FinalGenerationDetailsSchema,
);

/** Parsed shape of the get-result request's `params`. */
export type GetGenerationResultParams = z.output<
  typeof getGenerationResultParams
>;
/** Parsed shape of a full get-result request. */
export type GetGenerationResultRequest = z.output<
  typeof getGenerationResultRequest
>;
/** Parsed shape of a get-result response. */
export type GetGenerationResultResponse = z.output<
  typeof getGenerationResultResponse
>;

/** ******************************************************************************
 *  List Generations By User
 ******************************************************************************* */

/** Offset/page-size pagination for listing generations. */
export const listGenerationsParams = z.object({
  offset: z.number().optional(),
  pageSize: z.number().optional(),
  cursor: z.any().nullish(), // not read by this service, but required by tRPC
});

/** List request: the authed request envelope plus `params`. */
export const listGenerationsRequest = authedServiceRequest.merge(
  z.object({
    params: listGenerationsParams,
  }),
);

/** List response: one page of generations and an optional total count. */
export const listGenerationsResponse = serviceResponse.merge(
  z.object({
    generations: z.array(GenerationSchema).optional(),
    total: z.number().optional(),
  }),
);

export type ListGenerationsParams = z.infer<typeof listGenerationsParams>;
export type ListGenerationsRequest = z.infer<typeof listGenerationsRequest>;
export type ListGenerationsResponse = z.infer<typeof listGenerationsResponse>;

/** ******************************************************************************
 *  List Generations By Worker
 ******************************************************************************* */

/**
 * Filters, pagination and sort key for listing generations by worker.
 * Every filter is optional. `cursor`, when present, must carry both
 * `pageSize` and `offset`.
 */
export const listGenerationsByWorkerParams = z.object({
  workerId: z.string().optional(),
  instanceId: z.string().optional(),
  state: InferenceStateSchema.optional(),
  model: z.string().optional(),
  isSystem: z.boolean().optional(),
  cursor: z
    .object({
      pageSize: z.number(),
      offset: z.number(),
    })
    .nullish(),
  orderBy: z.enum(["createdAt", "updatedAt"]).optional(),
});

/** List-by-worker request: the authed request envelope plus `params`. */
export const listGenerationsByWorkerRequest = authedServiceRequest.merge(
  z.object({
    params: listGenerationsByWorkerParams,
  }),
);

/**
 * List-by-worker response. Same shape as `listGenerationsResponse`.
 * NOTE(review): it validates entries against the full `GenerationSchema`,
 * even though `GenerationStorageInterface.listByWorker` names
 * "request.messages" / "response.choices" as omitted. Confirm whether those
 * nested fields are actually stripped.
 */
export const listGenerationsByWorkerResponse = serviceResponse.merge(
  z.object({
    generations: z.array(GenerationSchema).optional(),
    total: z.number().optional(),
  }),
);

export type ListGenerationsByWorkerParams = z.infer<
  typeof listGenerationsByWorkerParams
>;
export type ListGenerationsByWorkerRequest = z.infer<
  typeof listGenerationsByWorkerRequest
>;
export type ListGenerationsByWorkerResponse = z.infer<
  typeof listGenerationsByWorkerResponse
>;

/** ******************************************************************************
 *  Count Generations
 ******************************************************************************* */

/** Counts generations in a single lifecycle state. */
export const countGenerationsParams = z.object({
  state: InferenceStateSchema,
});

// Built on `serviceRequest`, not `authedServiceRequest` like the list and
// get-result requests above.
export const countGenerationsRequest = serviceRequest.merge(
  z.object({
    params: countGenerationsParams,
  }),
);

/** Count response: shared service response fields plus an optional count. */
export const countGenerationsResponse = serviceResponse.merge(
  z.object({
    count: z.number().optional(),
  }),
);

export type CountGenerationsParams = z.infer<typeof countGenerationsParams>;
export type CountGenerationsRequest = z.infer<typeof countGenerationsRequest>;
export type CountGenerationsResponse = z.infer<typeof countGenerationsResponse>;

/** ******************************************************************************
 *  Generation RPM
 ******************************************************************************* */

/**
 * Params for the generation RPM query.
 * NOTE(review): `state` is required here, but
 * `GenerationStorageInterface.countRpm()` takes no arguments. Confirm whether
 * the service actually uses `state`.
 */
export const generationRPMParams = z.object({
  state: InferenceStateSchema,
});

// Built on `serviceRequest`, like `countGenerationsRequest`.
export const generationRPMRequest = serviceRequest.merge(
  z.object({
    params: generationRPMParams,
  }),
);

/** RPM response: shared service response fields plus an optional `rpm` value. */
export const generationRPMResponse = serviceResponse.merge(
  z.object({
    rpm: z.number().optional(),
  }),
);

export type GenerationRPMParams = z.infer<typeof generationRPMParams>;
export type GenerationRPMRequest = z.infer<typeof generationRPMRequest>;
export type GenerationRPMResponse = z.infer<typeof generationRPMResponse>;
