Daniel Liden

DBRX with MLflow

This note shows how to access the Databricks Foundation Model APIs through the OSS MLflow Deployments Server and through the MLflow OpenAI model flavor.

Here are two ways to use the DBRX model with MLflow. Both assume that the following environment variables are set:

  • DATABRICKS_TOKEN: a Databricks personal access token (PAT)
  • DATABRICKS_ENDPOINT: the model serving endpoint of your Databricks workspace. It will probably have the form https://<workspace_name>.cloud.databricks.com/serving-endpoints.
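
Before going further, it can help to confirm that both variables are actually set, since an unset variable tends to surface later as a confusing authentication or URL error. A minimal sketch; the helper name `missing_env_vars` is made up for this note and is not part of MLflow:

```python
import os

# The two variables this note relies on.
REQUIRED_VARS = ("DATABRICKS_TOKEN", "DATABRICKS_ENDPOINT")


def missing_env_vars(names=REQUIRED_VARS, environ=None):
    """Return the names from `names` that are unset or empty in `environ`.

    `environ` defaults to os.environ; passing a plain dict makes the
    function easy to try out without touching the real environment.
    """
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]


# An empty list means both variables are present.
print(missing_env_vars())
```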

Via the MLflow Deployments Server

  1. Configure the Deployments Server
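import os

import yaml

# Endpoint definition for the MLflow Deployments Server. The Databricks
# serving endpoint is addressed through the "openai" provider by pointing
# openai_api_base at DATABRICKS_ENDPOINT.
yaml_content = {
    "endpoints": [
        {
            "name": "dbrx",
            "endpoint_type": "llm/v1/chat",
            "model": {
                "provider": "openai",
                "name": "databricks-dbrx-instruct",
                "config": {
                    "openai_api_key": os.getenv("DATABRICKS_TOKEN"),
                    "openai_api_base": os.getenv("DATABRICKS_ENDPOINT")
                }
            }
        }
    ]
}

# Note: this writes the token value into deploy.yml in plain text.
with open("deploy.yml", "w") as file:
    yaml.dump(yaml_content, file, default_flow_style=False)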

Then start the server with:

mlflow deployments start-server --config-path deploy.yml
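
The server takes a moment to come up, so a script that starts it and queries it right away may hit a connection error. One way to handle that is to wait until the port accepts TCP connections. This is a minimal sketch using only the standard library; `wait_for_port` is a hypothetical helper, and the default port 5000 is taken from the client URL used in this note. It checks only that something is listening, not that the endpoint is configured correctly.

```python
import socket
import time


def wait_for_port(host="127.0.0.1", port=5000, timeout=30.0, interval=0.5):
    """Return True once a TCP connection to host:port succeeds.

    Retries every `interval` seconds and returns False if no connection
    could be made before `timeout` seconds have passed.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            # A successful connect means a process is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

Calling `wait_for_port()` before creating the client gives the server up to 30 seconds to start.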

And query the model with:

from mlflow.deployments import get_deploy_client

# Connect to the locally running Deployments Server.
client = get_deploy_client("http://127.0.0.1:5000")
name = "dbrx"  # must match the endpoint name in deploy.yml
data = dict(
    messages=[
        {"role": "user", "content": "Hello, World."},
    ],
    n=1,
    max_tokens=50,
    temperature=0.5,
)

response = client.predict(endpoint=name, inputs=data)
print(response)

which returns a response like:

{
    'id': '19c82206-cae5-4d9c-a4f3-676e83281bb8',
    'object': 'chat.completion',
    'created': 1712267435,
    'model': 'dbrx-instruct-032724',
    'choices': [
        {
            'index': 0,
            'message': {'role': 'assistant', 'content': 'Hello, World! How can I assist you today?'},
            'finish_reason': 'stop'
        }
    ],
    'usage': {'prompt_tokens': 228, 'completion_tokens': 11, 'total_tokens': 239}
}
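
In practice we usually want only the reply text and perhaps the token count rather than the whole dictionary. A small sketch, assuming the response keeps the chat-completion shape shown above; `first_reply` and `total_tokens` are illustrative names, not MLflow functions:

```python
def first_reply(response):
    """Return the assistant text of the first choice in the response."""
    return response["choices"][0]["message"]["content"]


def total_tokens(response):
    """Return the total token count reported in the response's usage block."""
    return response["usage"]["total_tokens"]


# Trimmed copy of the response shown above.
sample = {
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "Hello, World! How can I assist you today?",
            },
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 228, "completion_tokens": 11, "total_tokens": 239},
}

print(first_reply(sample))   # Hello, World! How can I assist you today?
print(total_tokens(sample))  # 239
```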

Using the OpenAI Model Flavor

We can also log the model using the OpenAI model flavor. We just need to set the appropriate environment variables first.

# log a model to MLflow using the OpenAI Model Flavor

import os
import mlflow
import openai


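# Point the OpenAI settings at Databricks. Indexing os.environ raises a
# KeyError naming the missing variable, whereas assigning the None returned
# by os.getenv() would fail with a less helpful TypeError.
os.environ["OPENAI_API_KEY"] = os.environ["DATABRICKS_TOKEN"]
os.environ["OPENAI_API_BASE"] = os.environ["DATABRICKS_ENDPOINT"]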

# Log the OpenAI model to MLflow
with mlflow.start_run():
    info = mlflow.openai.log_model(
        model="databricks-dbrx-instruct",
        task=openai.chat.completions,
        artifact_path="dbrx",
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )

dbrx_model = mlflow.pyfunc.load_model(info.model_uri)

print(dbrx_model.predict("Hello, world"))

Which returns:

[
    "Hello! How can I assist you today? I'm here to help answer any questions you might have or provide information on a topic of your choosing. Let me know how I can make your day a little bit easier!"
]
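
To make it clearer what the loaded model sends for a call like this, here is a rough sketch of how the logged system message and the input string combine. It rests on an assumption: when the logged messages contain no template variables, the flavor appends the input as a new user message. That is my reading of the behavior and it is consistent with the reply above, but it is not a documented guarantee. `build_messages` is a hypothetical helper, not an MLflow function.

```python
def build_messages(logged_messages, user_input):
    """Approximate the chat messages sent for one input string.

    Assumption: the logged messages have no template variables, so the
    input is appended as a user message after them.
    """
    return [*logged_messages, {"role": "user", "content": user_input}]


logged = [{"role": "system", "content": "You are a helpful assistant."}]
for message in build_messages(logged, "Hello, world"):
    print(message["role"], "->", message["content"])
```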

Date: 2024-04-04 Thu 00:00
