import { openai } from "@/schemas";
import { z } from "zod";

/** Template syntaxes accepted in a plain prompt's `template_format` field. */
const templateFormats = ["f-string", "jinja2"] as const;

/** Fields shared by every prompt-template variant (`prompt` and `chatPrompt`). */
const base = z.object({
  input_variables: z.string().array(),
});

/**
 * Plain-text prompt template, tagged with `_type: "prompt"`.
 *
 * `template_format` must be one of `templateFormats`; `validate_template`
 * defaults to `true` and `_type` defaults to `"prompt"` when omitted.
 */
const prompt = z.object({
  // Shared fields first, so key order matches the previous `.extend` form.
  ...base.shape,
  template: z.string(),
  template_format: z.enum(templateFormats),
  validate_template: z.boolean().default(true),
  _type: z.literal("prompt").default("prompt"),
});

/**
 * A chat message whose content is a plain `prompt` template.
 *
 * Only the first three entries of `openai.roles` are accepted as `role`.
 * NOTE(review): presumably this deliberately excludes the `"function"` role —
 * confirm against the definition of `openai.roles`.
 */
const langchainMessage = z.object({
  role: z.enum([openai.roles[0], openai.roles[1], openai.roles[2]]),
  prompt,
});

// TODO: Explain why a LangChain-style message (role + prompt template) exists
// alongside `openai.message`, then consolidate both into a single message schema.
/** A chat message: tried first as a LangChain-style message, then as `openai.message`. */
const message = z.union([langchainMessage, openai.message]);

/**
 * String values accepted for a chat prompt's `function_call`; the schema
 * additionally accepts an object of the form `{ name }`.
 */
export const functionCall = ["none", "auto"] as const;

/**
 * Chat prompt template, tagged with `_type: "chat_promptlayer_langchain"`.
 *
 * Holds a list of `message`s and a list of `openai._function` definitions.
 * `function_call` is `"none"`, `"auto"`, or `{ name }`, and defaults to `"none"`.
 */
const chatPrompt = z.object({
  // Shared fields first, so key order matches the previous `.extend` form.
  ...base.shape,
  messages: message.array(),
  functions: openai._function.array(),
  _type: z
    .literal("chat_promptlayer_langchain")
    .default("chat_promptlayer_langchain"),
  function_call: z
    .union([z.enum(functionCall), z.object({ name: z.string() })])
    .default("none"),
});

/**
 * A prompt template: either a plain `prompt` or a `chatPrompt`, selected by `_type`.
 *
 * Plain prompts pass all refinements unconditionally. Chat prompts must satisfy:
 * - `function_call: "auto"` requires at least one entry in `functions`;
 * - `function_call: { name }` must name one of the entries in `functions`;
 * - every message with role `"function"` must have a non-empty `name` that
 *   names one of the entries in `functions`.
 * The refinements specify no `path`, so their issues are attached to the
 * template itself rather than to a specific field.
 *
 * NOTE(review): zod picks the union variant by reading `_type` from the raw
 * input, so the `.default(...)` on each variant's `_type` likely does not make
 * the tag optional here — confirm against the zod version in use.
 */
const promptTemplate = z
  .discriminatedUnion("_type", [chatPrompt, prompt])
  .refine(
    (template) =>
      template._type === "prompt" ||
      template.function_call !== "auto" ||
      template.functions.length > 0,
    {
      message: "function_call is auto but no functions are defined",
    },
  )
  .refine(
    (template) => {
      if (template._type === "prompt") return true;
      const call = template.function_call;
      // Only the `{ name }` form references a specific function.
      if (typeof call !== "object") return true;
      const wanted = call.name;
      return template.functions.some((fn) => fn.name === wanted);
    },
    { message: "function_call.name must be a function name in functions" },
  )
  .refine(
    (template) => {
      if (template._type === "prompt") return true;
      const declared = template.functions.map((fn) => fn.name);
      // Non-function messages are unconstrained; function messages must name
      // a declared function.
      return template.messages.every(
        (msg) =>
          msg.role !== "function" ||
          (!!msg.name && declared.includes(msg.name)),
      );
    },
    {
      message: "message.name must be a function name in functions",
    },
  );

/**
 * Model selection stored in prompt metadata.
 *
 * `provider` defaults to `"openai"` and `name` to `"gpt-3.5-turbo"`.
 * `parameters` is not validated (`z.any()`); when it is `undefined`, the
 * default parameter object below is used.
 */
const model = z.object({
  provider: z.string().default("openai"),
  name: z.string().default("gpt-3.5-turbo"),
  // A factory, not an object literal: given a literal, zod returns the *same*
  // object instance on every parse, so mutating one parsed result would
  // silently change the default seen by every later parse.
  parameters: z.any().default(() => ({
    temperature: 1.0,
    max_tokens: 256,
    top_p: 1.0,
    frequency_penalty: 0.0,
    presence_penalty: 0.0,
  })),
});

/** Prompt metadata. `model` may be `null`, and is `null` when omitted. */
const metadata = z.object({
  model: z.nullable(model).default(null),
});

/**
 * Top-level prompt record: a name, a validated `promptTemplate`, a list of
 * string tags, and `metadata` that may be `null` (and is `null` when omitted).
 */
const schema = z.object({
  prompt_name: z.string(),
  prompt_template: promptTemplate,
  tags: z.string().array(),
  metadata: z.nullable(metadata).default(null),
});

// Static types of the *parsed* values, i.e. with defaults already applied.
// `z.output` is the same alias as `z.infer`; it is spelled out to make that explicit.
type Schema = z.output<typeof schema>;
type Chat = z.output<typeof chatPrompt>;
type Message = z.output<typeof message>;
type TemplateFormat = (typeof templateFormats)[number];
type Model = z.output<typeof model>;
type Metadata = z.output<typeof metadata>;

// Type-only symbols are exported with `export type` so it is explicit that they
// are erased at compile time; runtime schemas and constants are exported separately.
export type { Chat, Message, Metadata, Model, Schema, TemplateFormat };

export { chatPrompt, model, prompt, schema, templateFormats };
