
Databricks RAG Chatbot Series – Step 4: Deploying the Full RAG Chain

Author: Zohar Kapach

Published: March 7, 2025

Create and Deploy a RAG Chain

To enable full functionality for the RAG chatbot, we need to build a complete RAG chain and expose it through a serving endpoint the app can call. This lets the chatbot perform end-to-end retrieval and generation using the models and data pipelines configured in the earlier steps.
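For reference, the finished chain accepts the chat-completions-style payload used throughout this series; the same shape appears later as the logged input example and in the test cells (the question here is just the toy query reused from those cells):

{"messages": [{"role": "user", "content": "coke or pepsi?"}]}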

Install Libraries

This notebook has been tested on Databricks Runtime 16.2 ML and Serverless (Environment version 2).

%load_ext autoreload
%autoreload 2
# To disable autoreload, run %autoreload 0
%pip install --quiet -U databricks-agents mlflow-skinny==2.20.3 mlflow==2.20.3 mlflow[gateway]==2.20.3 langchain==0.2.1 langchain_core==0.2.5 langchain_community==0.2.4 databricks-vectorsearch databricks-sdk==0.23.0
dbutils.library.restartPython()
Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.

Imports and Variables

%run ./00_setup
import os
from operator import itemgetter
import sys
# Add the project root to sys.path to make raglib importable
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from databricks import agents
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
import mlflow

from raglib.core.augmentation_formatters import extract_user_query_string, combine_all_messages_for_vector_search, format_context, extract_previous_messages
from raglib.models import MODEL_CONFIGS
(mlflow may print a UserWarning here suggesting type hints on the predict method for automatic signature inference; it can be safely ignored for this chain.)
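The cells below lean heavily on MODEL_CONFIGS from raglib.models. Its real contents live in the raglib package and are not shown in this series; the following is a plausible sketch of its shape, inferred purely from the keys accessed in chain.py, with every value an illustrative assumption:

# Hypothetical shape of raglib.models.MODEL_CONFIGS, inferred from the keys
# accessed in chain.py below. All values are illustrative placeholders.
MODEL_CONFIGS = {
    "embedding": {
        "params": {
            "primary_key": "id",       # assumed primary-key column of the index
            "source_col": "summary",   # assumed text column in the index
        }
    },
    "retrieval": {
        "params": {"search_kwargs": {"k": 3}}  # assumed top-k for retrieval
    },
    "llm": {
        "llm_endpoint_name": "databricks-meta-llama-3-3-70b-instruct",  # assumed endpoint
        "llm_parameters": {"temperature": 0.0, "max_tokens": 512},
        "llm_prompt_template": "...",  # prompt using {question}, {context}, {chat_history}
        "llm_prompt_template_variables": ["question", "context", "chat_history"],
        "llm_input_example": {"messages": [{"role": "user", "content": "coke or pepsi?"}]},
        "llm_reviewer_instructions": "...",  # guidance shown in the Review App
    },
}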

Create and Write Chain

This cell creates the chain, which can be tested here with the .invoke() method. The chain is written to a file to avoid serialization issues: in the next steps, when we log the model, we point MLflow at this file (the following cell) as a LangChain model. MLflow detects the chain because we call mlflow.models.set_model. Alternatively, you could log the chain object itself, but then the retriever triggers serialization errors and would need to be passed to log_model() as a separate input.


%%writefile chain.py

# This cell is written to disk as chain.py via %%writefile, so it must be self-contained (imports, variables, and functions). The file is used to log the model in the following cells of this notebook.

import os
from operator import itemgetter
import sys
# Add the project root to sys.path to make raglib importable
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
import mlflow

from raglib.core.augmentation_formatters import extract_user_query_string, combine_all_messages_for_vector_search, format_context, extract_previous_messages
from raglib.models import MODEL_CONFIGS


UC_NAME = "ragchat"
PREFIX = ""
VECTOR_SEARCH_ENDPOINT_NAME = f"{PREFIX}vs_endpoint"

VECTOR_SEARCH_INDEX_NAME = f"{UC_NAME}.gold.vector_storage"
VECTOR_SEARCH_SOURCE_TABLE_NAME = f"{UC_NAME}.silver.ragchat_summarized_intercom_conversations"

## Enable MLflow Tracing
mlflow.langchain.autolog()

# Connect to the Vector Search Index
vsc = VectorSearchClient(disable_notice=True)
vs_index = vsc.get_index(
    endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
    index_name=VECTOR_SEARCH_INDEX_NAME,
)

# Turn the Vector Search index into a LangChain retriever
vector_search_as_retriever = DatabricksVectorSearch(
    vs_index,
    text_column=MODEL_CONFIGS["embedding"]["params"]["source_col"],
    columns=[
        MODEL_CONFIGS["embedding"]["params"]["primary_key"],
        MODEL_CONFIGS["embedding"]["params"]["source_col"],
    ],
).as_retriever(search_kwargs=MODEL_CONFIGS["retrieval"]["params"]["search_kwargs"])

# Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
mlflow.models.set_retriever_schema(
    primary_key=MODEL_CONFIGS["embedding"]["params"]["primary_key"],
    text_column=MODEL_CONFIGS["embedding"]["params"]["source_col"],
)

# Prompt Template for generation
prompt = PromptTemplate(
    template=MODEL_CONFIGS["llm"]["llm_prompt_template"],
    input_variables=MODEL_CONFIGS["llm"]["llm_prompt_template_variables"],
)

# FM for generation
model = ChatDatabricks(
    endpoint=MODEL_CONFIGS["llm"]["llm_endpoint_name"],
    extra_params=MODEL_CONFIGS["llm"]["llm_parameters"],
)

# RAG Chain
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        "context": itemgetter("messages")
        | RunnableLambda(combine_all_messages_for_vector_search)
        | vector_search_as_retriever
        | RunnableLambda(format_context),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_previous_messages)
    }
    | prompt
    | model
    | StrOutputParser()
)

# Tell MLflow logging where to find your chain.
mlflow.models.set_model(model=chain)

# COMMAND ----------
# The invocation below stays commented out so that %%writefile captures this cell as a clean module; the chain is tested after logging in the cells that follow.
# chain.invoke({"messages": [{"content": "coke or pepsi?", "role": "user"}]})
Overwriting chain.py
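The chain above delegates all message handling to four helpers imported from raglib.core.augmentation_formatters, which are not shown in this notebook. Here is a minimal sketch of what they might look like, assuming the chat-completions message format used throughout (the real implementations in raglib may differ):

# Hypothetical sketches of the raglib.core.augmentation_formatters helpers.
# The real implementations live in the raglib package.

def extract_user_query_string(messages):
    # Pull the content of the most recent (user) message as the question.
    return messages[-1]["content"]

def combine_all_messages_for_vector_search(messages):
    # Concatenate the whole conversation so retrieval sees the full context,
    # not just the latest turn.
    return "\n".join(m["content"] for m in messages)

def format_context(docs):
    # Join the retrieved LangChain Documents into a single context block
    # for the prompt template.
    return "\n\n".join(d.page_content for d in docs)

def extract_previous_messages(messages):
    # Render everything before the latest message as plain-text chat history.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages[:-1])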

Log and Test the Chain

# Log the model to MLflow
with mlflow.start_run(run_name="dbrx_ragchat"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model=os.path.join(os.getcwd(), 'chain.py'),
        artifact_path="rag_chain",  # Required by MLflow
        input_example=MODEL_CONFIGS["llm"]["llm_input_example"],  # Save the chain's input schema. MLflow executes the chain before logging and captures its output schema.
        code_paths=[
            '../raglib'
        ],
    )
# Test the chain locally
load_chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
load_chain.invoke({"messages": [{"content": "coke or pepsi?", "role": "user"}]})
2025/03/05 01:52:54 INFO mlflow: Attempting to auto-detect Databricks resource dependencies for the current langchain model. Dependency auto-detection is best-effort and may not capture all dependencies of your langchain model, resulting in authorization errors when serving or querying your model. We recommend that you explicitly pass `resources` to mlflow.langchain.log_model() to ensure authorization to dependent resources succeeds when the model is deployed.
The logged model is compatible with the Mosaic AI Agent Framework.
'I do not know.'
Trace(request_id=tr-dc8b469fa5b842808825cf6237e97b49)
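Because the chain feeds chat history and the combined conversation into retrieval, a multi-turn request can be tested the same way. A hypothetical example using the same schema (all message content here is made up):

# Multi-turn smoke test: the last user message becomes the question,
# earlier turns become chat history and retrieval context.
load_chain.invoke({
    "messages": [
        {"role": "user", "content": "What plans do you offer?"},
        {"role": "assistant", "content": "We offer Basic and Pro plans."},
        {"role": "user", "content": "How much does Pro cost?"},
    ]
})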

Register and Deploy the Chain

def wait_for_model_serving_endpoint_to_be_ready(endpoint_name):
    '''Wait for a model serving endpoint to be ready.'''
    import time

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.serving import EndpointStateConfigUpdate, EndpointStateReady

    w = WorkspaceClient()
    state = ""
    for i in range(400):
        state = w.serving_endpoints.get(endpoint_name).state
        if state.config_update == EndpointStateConfigUpdate.IN_PROGRESS:
            if i % 40 == 0:
                print(f"Waiting for endpoint to deploy {endpoint_name}. Current state: {state}")
            time.sleep(10)
        elif state.ready == EndpointStateReady.READY:
            print('endpoint ready.')
            return
        else:
            break
    raise Exception(f"Endpoint did not become ready in time; check the endpoint for details: {state}")

# Register the chain to UC
mlflow.set_registry_uri('databricks-uc') 
uc_registered_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=MODEL_NAME_FQN)

# Deploy to enable the Review App and create an API endpoint
deployment_info = agents.deploy(model_name=MODEL_NAME_FQN, model_version=uc_registered_model_info.version, scale_to_zero=True)

# Add the user-facing instructions to the Review App
agents.set_review_instructions(MODEL_NAME_FQN, MODEL_CONFIGS["llm"]["llm_reviewer_instructions"])

wait_for_model_serving_endpoint_to_be_ready(deployment_info.endpoint_name)
Registered model 'cma_ragchat.gold.dbrx_ragchat' already exists. Creating a new version of this model...
Created version '2' of model 'cma_ragchat.gold.dbrx_ragchat'.
(Several FutureWarning messages follow, noting that the mlflow.models.rag_signatures classes used by databricks-agents are deprecated in favor of their mlflow.types.llm equivalents.)

    Deployment of cma_ragchat.gold.dbrx_ragchat version 2 initiated.  This can take up to 15 minutes and the Review App & Query Endpoint will not work until this deployment finishes.

    View status: https://dbc-464ba720-0425.cloud.databricks.com/ml/endpoints/agents_cma_ragchat-gold-dbrx_ragchat
    Review App: https://dbc-464ba720-0425.cloud.databricks.com/ml/review/cma_ragchat.gold.dbrx_ragchat/2?o=3203210999074204
    
Waiting for endpoint to deploy agents_cma_ragchat-gold-dbrx_ragchat. Current state: EndpointState(config_update=<EndpointStateConfigUpdate.IN_PROGRESS: 'IN_PROGRESS'>, ready=<EndpointStateReady.READY: 'READY'>)
Waiting for endpoint to deploy agents_cma_ragchat-gold-dbrx_ragchat. Current state: EndpointState(config_update=<EndpointStateConfigUpdate.IN_PROGRESS: 'IN_PROGRESS'>, ready=<EndpointStateReady.READY: 'READY'>)
endpoint ready.
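With the endpoint ready, the app (or a quick smoke test) can call it like any other serving endpoint. A minimal sketch using the MLflow Deployments client, assuming the endpoint name from deployment_info above:

from mlflow.deployments import get_deploy_client

# Smoke-test the deployed agent endpoint; the payload shape matches the
# input example logged with the chain.
client = get_deploy_client("databricks")
response = client.predict(
    endpoint=deployment_info.endpoint_name,
    inputs={"messages": [{"role": "user", "content": "coke or pepsi?"}]},
)
print(response)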