DBRX and MLflow
This note shows how to access the Databricks Foundation Model APIs via the open-source MLflow Deployments Server and via the MLflow OpenAI model flavor.
Here are two ways to use the DBRX model with MLflow. Both assume that you have the following environment variables set:
DATABRICKS_TOKEN: a Databricks personal access token (PAT).
DATABRICKS_ENDPOINT: your Databricks workspace model serving endpoint. It will probably have the form https://<workspace_name>.cloud.databricks.com/serving-endpoints.
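Before going further, it can help to fail fast if either variable is missing. A minimal sanity check (illustrative only; the variable names are the ones above):

```python
import os

# Fail fast if either required environment variable is unset.
for var in ("DATABRICKS_TOKEN", "DATABRICKS_ENDPOINT"):
    if not os.getenv(var):
        raise RuntimeError(f"Please set the {var} environment variable.")
```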
Via the MLflow Deployments Server
First, configure the Deployments Server:
```python
import os

import yaml

yaml_content = {
    "endpoints": [
        {
            "name": "dbrx",
            "endpoint_type": "llm/v1/chat",
            "model": {
                "provider": "openai",
                "name": "databricks-dbrx-instruct",
                "config": {
                    "openai_api_key": os.getenv("DATABRICKS_TOKEN"),
                    "openai_api_base": os.getenv("DATABRICKS_ENDPOINT"),
                },
            },
        }
    ]
}

with open("deploy.yml", "w") as file:
    yaml.dump(yaml_content, file, default_flow_style=False)
```
Then start the server with:
```bash
mlflow deployments start-server --config-path deploy.yml
```
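Once it's up, you can optionally confirm that the endpoint registered. A quick check, assuming the server is running on its default local port:

```python
from mlflow.deployments import get_deploy_client

# Assumes the Deployments Server is running locally on the default port.
client = get_deploy_client("http://127.0.0.1:5000")
print(client.list_endpoints())  # should include the "dbrx" endpoint
```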
And query the model with:
```python
from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://127.0.0.1:5000")

name = "dbrx"
data = dict(
    messages=[
        {"role": "user", "content": "Hello, World."},
    ],
    n=1,
    max_tokens=50,
    temperature=0.5,
)

response = client.predict(endpoint=name, inputs=data)
print(response)
```
which will return:
```python
{
    'id': '19c82206-cae5-4d9c-a4f3-676e83281bb8',
    'object': 'chat.completion',
    'created': 1712267435,
    'model': 'dbrx-instruct-032724',
    'choices': [
        {
            'index': 0,
            'message': {
                'role': 'assistant',
                'content': 'Hello, World! How can I assist you today?'
            },
            'finish_reason': 'stop'
        }
    ],
    'usage': {'prompt_tokens': 228, 'completion_tokens': 11, 'total_tokens': 239}
}
```
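The reply text itself is nested inside the payload. Given the structure shown above, you can pull it out directly:

```python
# Extract the assistant's reply from the response payload shown above.
print(response["choices"][0]["message"]["content"])
```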
Via the OpenAI Model Flavor
We can also log the model using the OpenAI model flavor. We just need to be careful to set the appropriate environment variables first.
```python
# Log a model to MLflow using the OpenAI model flavor.
import os

import mlflow
import openai

# Point the OpenAI client at the Databricks serving endpoint.
os.environ["OPENAI_API_KEY"] = os.getenv("DATABRICKS_TOKEN")
os.environ["OPENAI_API_BASE"] = os.getenv("DATABRICKS_ENDPOINT")

# Log the OpenAI model to MLflow.
with mlflow.start_run():
    info = mlflow.openai.log_model(
        model="databricks-dbrx-instruct",
        task=openai.chat.completions,
        artifact_path="dbrx",
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )

dbrx_model = mlflow.pyfunc.load_model(info.model_uri)
print(dbrx_model.predict("Hello, world"))
```
which returns:
[ "Hello! How can I assist you today? I'm here to help answer any questions you might have or provide information on a topic of your choosing. Let me know how I can make your day a little bit easier!" ]