Want to train an AI agent that can use external tools like search engines, calculators, or APIs? This tutorial shows you how to build tool-calling agents with reinforcement learning: agents that learn when and how to request external information to solve complex problems.

We use VERL (a production-ready RL training framework) to train agents with search tool integration, and SkyPilot to run and scale training on any AI infrastructure, including Kubernetes and clouds. We also propose an architecture that allows the rollout workers and the trainers to scale and recover independently on any AI infrastructure with SkyPilot.

What You’ll Build

[Figure: Tool-calling agent architecture]

By the end of this tutorial, you’ll have:

  1. Set up a search-augmented RL pipeline: brought up a retrieval service, prepared tool-aware data, and configured VERL for tool-calling rollouts.
  2. Trained and evaluated the agent: compared the trained agent’s answers against the base model (no tool-use training) on questions that require external knowledge.
  3. Swapped retrieval backends: ran the same agent with Wikipedia and Google Search to compare behavior and quality.
  4. Scaled the system: separated the retrieval service from RL training across nodes to improve throughput and reliability.

All of this is orchestrated from a single SkyPilot YAML file, with optional multi-node setup that cleanly separates retrieval and training for scale.

Why Tool-Calling Agents Matter

[Figure: Without tool use, the LLM only has knowledge up to its training cutoff and cannot verify any results.]

Traditional language models are limited by their training data cutoff. They can’t:

  • Access real-time information
  • Query specialized databases
  • Perform calculations beyond their learned patterns
  • Retrieve specific documents on demand

Tool-calling agents overcome these limitations by learning to:

  • Recognize when they need external information
  • Formulate appropriate queries to tools (search, APIs, calculators)
  • Integrate retrieved results into their reasoning
  • Optimize for task success using both internal knowledge and external tools

How Tool-Calling Works During Training

During RL training, a trainer and multiple rollout workers coordinate tool calls:

  1. Model requests a tool call (e.g., <search> protein in egg yolks </search>)
  2. Rollout worker processes the request asynchronously by calling the retrieval service
  3. Model receives the response and continues generation with the retrieved information
  4. Trainer calculates the rewards based on the final answer and updates the model

This creates a training loop where the agent learns optimal tool-calling strategies through reinforcement—when to search, how to formulate queries, and how to synthesize retrieved information.
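The four steps above can be sketched as a simple loop. This is a simplified illustration, not VERL's rollout code: it assumes the `<search>`/`<answer>` tag format used in the transcripts later in this tutorial, and `generate` and `retrieve` are hypothetical callables standing in for SGLang generation and the retrieval service.

```python
import re
from typing import Callable, Optional

# Tag format taken from the transcripts in this tutorial (Search-R1 style).
SEARCH_RE = re.compile(r"<search>(.*?)</search>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def rollout(prompt: str,
            generate: Callable[[str], str],
            retrieve: Callable[[str], str],
            max_turns: int = 4) -> str:
    """Run one multi-turn rollout and return the full trajectory text."""
    trajectory = prompt
    for _ in range(max_turns):
        # Step 1: the model generates until it emits a tool call or an answer.
        chunk = generate(trajectory)
        trajectory += chunk
        if ANSWER_RE.search(chunk):
            break  # final answer emitted; the trainer scores it (step 4)
        match = SEARCH_RE.search(chunk)
        if match is None:
            break  # neither a tool call nor an answer: end the rollout
        # Step 2: the rollout worker forwards the query to the retrieval service.
        results = retrieve(match.group(1).strip())
        # Step 3: results are appended and generation continues from them.
        trajectory += f"\n<information>{results}</information>\n"
    return trajectory


def extract_answer(trajectory: str) -> Optional[str]:
    """Return the last <answer> span, or None if the model never answered."""
    answers = ANSWER_RE.findall(trajectory)
    return answers[-1].strip() if answers else None


if __name__ == "__main__":
    # Hypothetical scripted "model" and "retriever" in place of SGLang and the HTTP service.
    scripted = iter([
        "<think>I need data.</think>\n<search> protein in one egg yolk </search>",
        "<think>2.7 g x 4 = 10.8 g.</think>\n<answer> 11 </answer>",
    ])
    traj = rollout("Question: How much protein in four boiled egg yolks?",
                   generate=lambda _: next(scripted),
                   retrieve=lambda q: f'Doc 1(Title: "Egg as food") results for: {q}')
    print(extract_answer(traj))  # 11
```

In real training, many such rollouts run concurrently across rollout workers, and the trainer turns the final answers into rewards.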

[Figure: Agent-tool interaction. The agent decides when to search, retrieves documents, and synthesizes answers.]

Example: Training a Search-Augmented Q&A Agent

Let’s build an agent that uses Wikipedia search to answer questions it couldn’t otherwise answer correctly.

Step 1: Setting Up the Retrieval Service

Before training, we need a retrieval service that the agent can query. VERL’s example uses a local Wikipedia search service based on:

  • E5-base-v2 embeddings (dense retrieval)
  • FAISS indexing for fast similarity search
  • 132GB Wikipedia corpus (full text + indices); the SkyPilot YAMLs in this tutorial download the smaller 100k subset (wiki-18-100k) instead
  • FastAPI server exposing HTTP endpoints
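Conceptually, a flat FAISS index performs exact, brute-force search: every stored passage embedding is scored against the query embedding. The pure-Python sketch below shows that computation, assuming inner-product scoring on L2-normalized vectors (equivalent to cosine similarity). The 3-dimensional vectors and passage names are toy stand-ins; the real service uses FAISS with 768-dimensional E5-base-v2 embeddings.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length so inner product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def flat_search(index: List[List[float]], query: Sequence[float],
                k: int) -> List[Tuple[int, float]]:
    """Exact top-k search: score every stored vector, no approximation."""
    q = normalize(query)
    scores = ((i, sum(a * b for a, b in zip(vec, q))) for i, vec in enumerate(index))
    return heapq.nlargest(k, scores, key=lambda pair: pair[1])


# Toy 3-d "passage embeddings"; a real E5-base-v2 index stores 768-d vectors.
passages = ["Egg as food", "Collagen", "Rolled oats"]
index = [normalize(v) for v in ([0.9, 0.1, 0.0], [0.1, 0.9, 0.1], [0.0, 0.2, 0.9])]

for doc_id, score in flat_search(index, [1.0, 0.2, 0.0], k=2):
    print(passages[doc_id], round(score, 3))
```

Exact search scales linearly with corpus size, which is one reason the full 132GB corpus needs a GPU-backed index while the 100k subset stays cheap.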

The retrieval service runs independently from the training process, communicating via HTTP. This decoupled design means:

  • Retrieval can scale separately from training
  • Multiple training jobs can share one retrieval service
  • Easy to swap different retrieval backends
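Because the service speaks plain HTTP, any client can query it. Below is a minimal standard-library client, assuming the request/response shape of the Search-R1 retrieval server that VERL's example appears to be adapted from: a JSON body with `queries`, `topk`, and `return_scores`, and a `result` list holding `{"document": {"contents": ...}, "score": ...}` hits per query. Verify this against `retrieval_server.py` in your VERL checkout; `build_request`, `format_docs`, and `search` are hypothetical helper names.

```python
import json
import urllib.request
from typing import Any, Dict, List

RETRIEVE_URL = "http://127.0.0.1:8000/retrieve"  # endpoint used throughout this tutorial


def build_request(queries: List[str], topk: int = 3) -> urllib.request.Request:
    """Build a POST request; the JSON body shape is an assumption (Search-R1 style)."""
    body = json.dumps({"queries": queries, "topk": topk, "return_scores": True}).encode()
    return urllib.request.Request(RETRIEVE_URL, data=body,
                                  headers={"Content-Type": "application/json"},
                                  method="POST")


def format_docs(hits: List[Dict[str, Any]]) -> str:
    """Render hits as 'Doc i(Title: ...) text', the layout seen in the transcripts."""
    lines = []
    for i, hit in enumerate(hits, start=1):
        # Assumed corpus convention: first line of `contents` is the title.
        title, _, text = hit["document"]["contents"].partition("\n")
        lines.append(f"Doc {i}(Title: {title}) {text}")
    return "\n".join(lines)


def search(query: str, topk: int = 3, timeout: float = 30.0) -> str:
    """Call the retrieval service for one query and return formatted results."""
    with urllib.request.urlopen(build_request([query], topk), timeout=timeout) as resp:
        payload = json.load(resp)
    return format_docs(payload["result"][0])  # one hit list per query
```

With the retrieval server running, `print(search("protein in one boiled egg yolk"))` should print hits in the same `Doc 1(Title: ...)` layout the agent sees during training.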

Step 2: Preparing Tool-Calling Data

The data format extends VERL’s standard format with tool-calling annotations. Here’s how a training example looks:

Example Question:

Who won the Nobel Prize in Physics in 2023?

Expected Behavior:

  1. Agent recognizes it doesn’t know (post-training-cutoff information)
  2. Agent calls search tool: search("Nobel Prize Physics 2023")
  3. Retrieval service returns Wikipedia excerpt
  4. Agent synthesizes answer from retrieved text

Data Format:

{
  "data_source": "search_qa",
  "prompt": [{
    "role": "user",
    "content": "Who won the Nobel Prize in Physics in 2023?"
  }],
  "ability": "search_augmented_qa",
  "reward_model": {
    "style": "rule",
    "ground_truth": "Pierre Agostini, Ferenc Krausz, and Anne L'Huillier"
  },
  "tool_config": {
    "available_tools": ["search"],
    "retrieval_service_url": "http://127.0.0.1:8000/retrieve"
  },
  "extra_info": {
    "requires_search": true,
    "split": "train",
    "index": 42
  }
}

Key differences from standard RL data:

  • tool_config specifies available tools and endpoints
  • requires_search hint (optional, for analysis)
  • Multi-turn expected (agent may search multiple times)
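`"style": "rule"` means the reward is computed by a rule against `ground_truth` rather than by a learned reward model. Below is a sketch of one such rule: binary exact match after SQuAD-style answer normalization. It mirrors the kind of scoring Search-R1 uses, but treat the details (normalization steps, list-valued ground truth) as assumptions rather than VERL's exact reward function.

```python
import re
import string
from typing import Iterable, Optional, Union

_PUNCT = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def extract_final_answer(response: str) -> Optional[str]:
    """Take the last <answer>...</answer> span the model produced."""
    matches = re.findall(r"<answer>(.*?)</answer>", response, re.DOTALL)
    return matches[-1] if matches else None


def exact_match_reward(response: str,
                       ground_truth: Union[str, Iterable[str]]) -> float:
    """1.0 if the normalized answer matches any normalized target, else 0.0."""
    answer = extract_final_answer(response)
    if answer is None:
        return 0.0  # no answer tag: no reward
    targets = [ground_truth] if isinstance(ground_truth, str) else list(ground_truth)
    normalized_targets = {normalize_answer(t) for t in targets}
    return 1.0 if normalize_answer(answer) in normalized_targets else 0.0


gt = "Pierre Agostini, Ferenc Krausz, and Anne L'Huillier"
print(exact_match_reward("<answer> Pierre Agostini, Ferenc Krausz and Anne L'Huillier </answer>", gt))
```

Normalization makes the reward robust to trivial differences (the missing Oxford comma above still scores 1.0), while a hallucinated answer like the base model's "44 grams" scores 0.0.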

The preprocessing script (preprocess_search_r1_dataset.py) handles:

  • Downloading question-answer pairs
  • Formatting for multi-turn conversations
  • Validating retrieval service accessibility
  • Creating train/validation splits
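These steps can be approximated as follows. This is a hypothetical, simplified sketch, not the contents of preprocess_search_r1_dataset.py: `make_record`, `split_records`, and `retrieval_reachable` are illustrative names, the records follow the data format shown earlier, and writing Parquet files is left out to keep it dependency-free.

```python
import random
import socket
from typing import Dict, List, Tuple

RETRIEVE_URL = "http://127.0.0.1:8000/retrieve"


def make_record(question: str, answer: str, index: int, split: str) -> Dict:
    """Build one example in the tool-calling data format shown above."""
    return {
        "data_source": "search_qa",
        "prompt": [{"role": "user", "content": question}],
        "ability": "search_augmented_qa",
        "reward_model": {"style": "rule", "ground_truth": answer},
        "tool_config": {"available_tools": ["search"],
                        "retrieval_service_url": RETRIEVE_URL},
        "extra_info": {"requires_search": True, "split": split, "index": index},
    }


def split_records(pairs: List[Tuple[str, str]], val_fraction: float = 0.2,
                  seed: int = 0) -> Tuple[List[Dict], List[Dict]]:
    """Shuffle QA pairs deterministically and return (train, validation) records."""
    shuffled = pairs[:]
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    # Validation records are labeled "test" to match the test.parquet naming.
    val = [make_record(q, a, i, "test") for i, (q, a) in enumerate(shuffled[:n_val])]
    train = [make_record(q, a, i, "train") for i, (q, a) in enumerate(shuffled[n_val:])]
    return train, val


def retrieval_reachable(host: str = "127.0.0.1", port: int = 8000,
                        timeout: float = 2.0) -> bool:
    """Cheap TCP check that something is listening where the retriever should be."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

Whatever the preprocessing does internally, the training YAML below expects its output as train.parquet and test.parquet under $HOME/data/searchR1_processed_direct.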

Step 3: Training an Agent with Tool Calling

The training uses SGLang as the rollout backend because it supports multi-turn tool calling. Key configuration differences:

# Tool-calling config
actor_rollout_ref.rollout.name: sglang                 # Rollout backend with tool-calling support
actor_rollout_ref.rollout.multi_turn.enable: True      # Multi-turn conversations
retrieval_service_url: http://127.0.0.1:8000/retrieve  # Tool endpoint, set in the tool config file

Complete SkyPilot YAML (verl-search-tool.yaml in the SkyPilot repo):

# Search Tool Interaction Training with VERL
#
# This example demonstrates multi-turn tool interaction training using VERL with a search/retrieval tool.
# The model learns to use a search tool for answering questions that require external knowledge.
#
# Based on: https://verl.readthedocs.io/en/v0.5.x/sglang_multiturn/search_tool_example.html
#
# Usage:
#   sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
#
# Optional:
#   --secret WANDB_API_KEY  # For logging to Weights & Biases
#
#   sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml --secret WANDB_API_KEY --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
# 
# Requirements:
#   - Docker with SYS_PTRACE capability (for PyTorch multiprocessing CUDA tensor sharing)
#   - Single H100 or equivalent GPU (can be adjusted for other accelerators)

resources:
  accelerators: H100:1
  memory: 128+
  image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
  ports:
    - 8265  # Ray dashboard
    - 8000  # Retrieval service

num_nodes: 1

config:
  docker:
    run_options:
      - --cap-add=SYS_PTRACE  # Required for PyTorch CUDA tensor sharing between Ray workers
      - --ipc=host
      - --shm-size=16g

envs:
  DATASET_SIZE: small  # Options: small (1000 train, 200 test), medium (10k train, 2k test), full
  TOTAL_EPOCHS: 1
  TOTAL_STEPS: 10
  TRAIN_BATCH_SIZE: 512
  VAL_BATCH_SIZE: 256
  SAVE_FREQ: 5  # Save a checkpoint every 5 steps (reduced from 100)
  TEST_FREQ: 5  # Run validation every 5 steps (reduced from 50)
  MODEL_NAME: Qwen/Qwen2.5-3B-Instruct
  WANDB_PROJECT_NAME: search_r1_like_async_rl
  WANDB_EXPERIMENT_NAME: qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn
  CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints

file_mounts:
  /checkpoints:
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT

secrets:
  WANDB_API_KEY: ""

setup: |
  rm -f ~/.pip/pip.conf
  rm -f ~/.config/pip/pip.conf
  
  set -e

  echo "=== VERL Search Tool Interaction Setup ==="

  # System dependencies
  echo "Installing system dependencies..."
  sudo apt update && sudo apt install -y iproute2 npm

  # Python environment
  echo "Setting up Python virtual environment..."
  uv venv --python 3.10 --seed
  source .venv/bin/activate

  # Clone VERL repository
  echo "Cloning VERL repository..."
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout v0.6.0

  # Core dependencies
  echo "Installing PyTorch and VERL..."
  uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
  uv pip install -v -e .
  uv pip install wheel
  uv pip install packaging
  uv pip install -r ./requirements_sglang.txt

  # Search/retrieval specific dependencies
  echo "Installing retrieval service dependencies..."
  uv pip install faiss-gpu-cu12
  # issue with uvloop version https://github.com/volcengine/verl/issues/3806
  uv pip install uvloop==0.21.0

  # Download Wikipedia corpus and FAISS index
  echo "Downloading Wikipedia corpus and FAISS index..."
  export save_path=~/dataset
  mkdir -p $save_path

  huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
  huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path

  # Move files to expected locations
  mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
  mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index

  # Decompress the JSONL file
  gzip -d $save_path/wiki-18.jsonl.gz -f

  # Data preparation
  echo "Preparing search R1 dataset..."
  python3 examples/data_preprocess/preprocess_search_r1_dataset.py

  git clone https://github.com/PeterGriffinJin/Search-R1/  

run: |
  set -e

  echo "=== VERL Search Tool Interaction Training ==="

  # Multi-node setup
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES
  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE

  # Network configuration for distributed training
  NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
  export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
  export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE

  # PyTorch multiprocessing configuration
  export TORCH_MULTIPROCESSING_SHARING_STRATEGY=file_system

  # Activate environment
  source .venv/bin/activate

  # Set up paths
  cd verl
  PROJECT_DIR="$(pwd)"
  export PYTHONPATH="$PROJECT_DIR:$PYTHONPATH"

  # Start retrieval service
  echo "Starting retrieval server..."
  save_path=~/dataset
  index_file=$save_path/e5_Flat.index
  corpus_file=$save_path/wiki-18.jsonl
  retriever_name=e5
  retriever_path=intfloat/e5-base-v2

  python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
    --index_path $index_file \
    --corpus_path $corpus_file \
    --topk 3 \
    --retriever_name $retriever_name \
    --retriever_model $retriever_path &

  RETRIEVAL_PID=$!
  sleep 10  # give the server time to load the FAISS index and E5 encoder

  # WandB login (optional)
  if [ -n "$WANDB_API_KEY" ]; then
    echo "Logging into Weights & Biases..."
    python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
  fi

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    echo "Starting Ray head node on port 6379..."
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port=6379 --dashboard-host=0.0.0.0 --dashboard-port=8265

    # Wait for all nodes to connect
    echo "Waiting for $NUM_NODES nodes to connect..."
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)  # grep -c already prints 0 on no match
      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "✓ All $NUM_NODES nodes connected"
        break
      fi
      retry_count=$((retry_count+1))
      sleep 10
    done

    # Display Ray cluster status
    echo "Ray cluster status:"
    ray status

    echo "Starting search tool interaction training..."
    cd $PROJECT_DIR

    # Increase file descriptor limit
    ulimit -n 65535

    # Set up configuration paths
    CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
    TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet"
    VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet"
    TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml"

    # Training with search tool
    python3 -m verl.trainer.main_ppo \
      --config-path="$CONFIG_PATH" \
      --config-name='search_multiturn_grpo' \
      algorithm.adv_estimator=grpo \
      data.train_batch_size=$TRAIN_BATCH_SIZE \
      data.val_batch_size=$VAL_BATCH_SIZE \
      data.max_prompt_length=4096 \
      data.max_response_length=3000 \
      data.filter_overlong_prompts=True \
      data.truncation='error' \
      data.return_raw_chat=True \
      actor_rollout_ref.model.path=$MODEL_NAME \
      actor_rollout_ref.actor.optim.lr=1e-6 \
      actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
      actor_rollout_ref.model.use_remove_padding=True \
      actor_rollout_ref.actor.ppo_mini_batch_size=16 \
      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
      actor_rollout_ref.actor.use_kl_loss=True \
      actor_rollout_ref.actor.kl_loss_coef=0.001 \
      actor_rollout_ref.actor.kl_loss_type=low_var_kl \
      actor_rollout_ref.actor.entropy_coeff=0 \
      actor_rollout_ref.model.enable_gradient_checkpointing=True \
      actor_rollout_ref.actor.fsdp_config.param_offload=True \
      actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
      actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
      actor_rollout_ref.rollout.max_model_len=15000 \
      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
      actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
      actor_rollout_ref.rollout.name=sglang \
      actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
      actor_rollout_ref.rollout.n=5 \
      actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \
      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
      algorithm.use_kl_in_reward=False \
      trainer.critic_warmup=0 \
      trainer.val_before_train=False \
      trainer.logger='["console","wandb"]' \
      trainer.project_name=$WANDB_PROJECT_NAME \
      trainer.experiment_name=$WANDB_EXPERIMENT_NAME \
      trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
      trainer.nnodes=$NUM_NODES \
      trainer.save_freq=$SAVE_FREQ \
      trainer.test_freq=$TEST_FREQ \
      data.train_files="$TRAIN_DATA" \
      data.val_files="$VAL_DATA" \
      actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \
      trainer.total_epochs=$TOTAL_EPOCHS \
      trainer.total_training_steps=$TOTAL_STEPS \
      trainer.default_local_dir=/checkpoints

    echo "✓ Training complete!"

    # Model checkpoint merging
    echo "Merging model checkpoints..."
    LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
    CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"

    python -m verl.model_merger merge \
      --backend fsdp \
      --tie-word-embedding \
      --local_dir ${CHECKPOINT_DIR} \
      --target_dir /checkpoints/hf_model

    echo "✓ Model saved to /checkpoints/hf_model"
    echo "Training artifacts saved to cloud bucket: ${CHECKPOINT_BUCKET_NAME}"

    # Stop the retrieval service now that training is done
    if [ -n "$RETRIEVAL_PID" ]; then
      echo "Stopping retrieval service..."
      kill $RETRIEVAL_PID 2>/dev/null || true
      sleep 5
    fi

  else
    # Worker node setup
    echo "Worker node (rank $SKYPILOT_NODE_RANK) connecting to head at $HEAD_IP:6379..."
    sleep 15
    ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
    echo "✓ Worker node connected"
    sleep infinity
  fi  
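The training command sets `algorithm.adv_estimator=grpo` with `actor_rollout_ref.rollout.n=5`, so each question is rolled out five times and each response is scored relative to its siblings. Below is a minimal sketch of the standard GRPO outcome advantage, (reward - group mean) / (group std + eps). VERL's implementation may differ in details such as the epsilon value and how single-sample groups are handled.

```python
import statistics
from collections import defaultdict
from typing import Dict, List, Tuple


def grpo_advantages(prompt_ids: List[str], rewards: List[float],
                    eps: float = 1e-6) -> List[float]:
    """Normalize each reward against the other samples drawn for the same prompt."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for pid, r in zip(prompt_ids, rewards):
        groups[pid].append(r)
    stats: Dict[str, Tuple[float, float]] = {}
    for pid, rs in groups.items():
        mean = statistics.fmean(rs)
        # Sample standard deviation; defined as 0 for a single-sample group.
        std = statistics.stdev(rs) if len(rs) > 1 else 0.0
        stats[pid] = (mean, std)
    return [(r - stats[pid][0]) / (stats[pid][1] + eps)
            for pid, r in zip(prompt_ids, rewards)]


# One question, n=5 rollouts: two correct answers (reward 1), three wrong (reward 0).
print([round(a, 3) for a in grpo_advantages(["q42"] * 5, [1, 0, 0, 1, 0])])
```

If all five rollouts for a question earn the same reward (all right or all wrong), every advantage is zero and that question contributes no advantage-weighted policy gradient in that step.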

Launch training:

# Single node (1 GPU)
sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml \
  --secret WANDB_API_KEY --num-nodes 1 -y

# Multi-node for faster training (2+ nodes)
sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml \
  --secret WANDB_API_KEY --num-nodes 2 -y

What does SkyPilot do under the hood?

[Figure: SkyPilot orchestrating the retrieval service and distributed training]

How effective is the RL training for tool use?

Inference YAML (verl-search-interaction-infer.yaml in the SkyPilot repo):

resources:
  accelerators: H100:1
  memory: 128+
  ports:
    - 8000  # Retrieval service

num_nodes: 1

envs:
  MODEL_PATH: ""  # Optional: Path to model checkpoint (defaults to base model)
  RETRIEVAL_TOPK: 3
  RETRIEVER_NAME: e5
  RETRIEVER_MODEL: intfloat/e5-base-v2
  CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints

file_mounts:
  /checkpoints:
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT

setup: |
  set -e

  echo "=== Search Tool Inference Setup ==="

  # System dependencies
  echo "Installing system dependencies..."
  sudo apt update && sudo apt install -y iproute2

  # Python environment
  echo "Setting up Python virtual environment..."
  uv venv --python 3.10 --seed
  source .venv/bin/activate

  # Clone VERL repository first: the editable install and requirements file below live inside it
  echo "Cloning VERL repository..."
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout v0.6.0

  # Install dependencies
  echo "Installing PyTorch and dependencies..."
  uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
  uv pip install -v -e .
  uv pip install wheel
  uv pip install packaging
  uv pip install -r ./requirements_sglang.txt
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
  # Retrieval server dependencies (same as the training YAML)
  uv pip install faiss-gpu-cu12
  uv pip install uvloop==0.21.0
  cd ..

  # Download Wikipedia corpus and FAISS index
  echo "Downloading Wikipedia corpus and FAISS index..."
  export save_path=~/dataset
  mkdir -p $save_path

  huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
  huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path

  # Move files to expected locations
  mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
  mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index

  # Decompress the JSONL file
  gzip -d $save_path/wiki-18.jsonl.gz -f

  # Clone Search-R1 for inference
  echo "Cloning Search-R1 repository..."
  rm -rf Search-R1
  git clone https://github.com/PeterGriffinJin/Search-R1/

  # Install additional inference dependencies if needed
  cd Search-R1
  if [ -f requirements.txt ]; then
    uv pip install -r requirements.txt
  fi
  cd ..

  echo "✓ Inference setup complete!"  

run: |
  set -e

  echo "=== Search Tool Inference ==="

  # Activate environment
  source .venv/bin/activate

  # Set up paths
  save_path=~/dataset
  index_file=$save_path/e5_Flat.index
  corpus_file=$save_path/wiki-18.jsonl

  # Start retrieval server in background
  echo "Starting retrieval server on port 8000..."
  cd verl
  python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
    --index_path $index_file \
    --corpus_path $corpus_file \
    --topk $RETRIEVAL_TOPK \
    --retriever_name $RETRIEVER_NAME \
    --retriever_model $RETRIEVER_MODEL &

  RETRIEVAL_PID=$!
  sleep 10

  # Run inference
  cd ~/sky_workdir/Search-R1
  python infer.py

Launch inference:

sky launch -c verl-infer llm/verl/search-tooling/verl-search-interaction-infer.yaml -y

Comparison 1: Trained Agent vs Base Model

Let’s first compare the trained agent against the base model on questions requiring search:

Test Question: “How much protein in four boiled egg yolks?”

✅ Trained Agent (Wikipedia)

Complete interaction:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside
<think> and </think> first every time you get new information.
After reasoning, if you find you lack some knowledge, you can
call a search engine by <search> query </search> and it will
return the top searched results between <information> and
</information>. You can search as many times as your want. If
you find no further external knowledge needed, you can directly
provide the answer inside <answer> and </answer>, without
detailed illustrations. For example, <answer> Beijing </answer>.
Question: How much protein in four boiled egg yolks?
<|im_end|>
<|im_start|>assistant

<think>I need to find out how much protein is in four boiled
egg yolks. I'll search for it.</think>

<search> how much protein in four boiled egg yolks </search>

<information>Doc 1(Title: "Peptide YY") Peptide YY is a
peptide hormone released from cells in the ileum and colon
in response to feeding. It acts on various receptors in the
brain to reduce appetite.
Doc 2(Title: "Rolled oats") Rolled oats are traditionally
oat groats that have been de-husked and steamed, before being
rolled into flat flakes under heavy rollers and then lightly
toasted.
Doc 3(Title: "Egg sandwich") An egg sandwich is a sandwich
with a cooked egg as the main filling. Fried eggs, poached
eggs, scrambled eggs, and egg salad are popular options.
</information>

<think>These results aren't helpful - they don't contain
information about egg yolk protein content. Let me try a more
specific search for the protein content in one boiled egg
yolk first.</think>

<search> protein content in one boiled egg yolk </search>

<information>Doc 1(Title: "Immunoglobulin superfamily") The
immunoglobulin superfamily is a large protein superfamily of
cell surface and soluble proteins involved in recognition,
binding, and adhesion.
Doc 2(Title: "Collagen") Collagen is the main structural
protein in the extracellular matrix found in the body's
various connective tissues.
Doc 3(Title: "Egg as food") A chicken egg contains
approximately 6 grams of protein. The egg white contains
about 3.6 grams while the yolk contains around 2.7 grams
of protein per large egg.
</information>

<think>Great! I found useful information in Doc 3. One egg
yolk contains approximately 2.7 grams of protein. For four
egg yolks, that would be 2.7 × 4 = 10.8 grams, which I can
round to about 11 grams.</think>

<answer> 11 </answer>

Key behaviors:

  • Recognizes knowledge gap and searches
  • Adapts when first results are irrelevant
  • Refines query and extracts information from Wikipedia
  • Performs calculation from found data

Correct: ~11 grams (2.7g per yolk × 4 = 10.8g)

❌ Base Model (No Tool Training)

Complete interaction:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside
<think> and </think> first every time you get new information.
After reasoning, if you find you lack some knowledge, you can
call a search engine by <search> query </search> and it will
return the top searched results between <information> and
</information>. You can search as many times as your want. If
you find no further external knowledge needed, you can directly
provide the answer inside <answer> and </answer>, without
detailed illustrations. For example, <answer> Beijing </answer>.
Question: How much protein in four boiled egg yolks?
<|im_end|>
<|im_start|>assistant

<think>
1. We know that there is no need for an external search
   for this information, so we don't need to use a
   search engine.
2. We can directly provide the answer using the
   knowledge we have.
</think>

<answer>
44 grams of protein
</answer>

Problems:

  • Never attempts to search despite lacking knowledge
  • Over-confident in incorrect internal knowledge
  • Hallucinates answer (44g) with false certainty

Incorrect: 44 grams (hallucinated)

Key takeaway: The trained agent recognizes when it needs external information and uses tools effectively, while the base model hallucinates with false confidence.


Comparison 2: Wikipedia (Stale Corpus) vs Google Search (Current Data)

The trained agent can work with different retrieval backends. This comparison highlights a key advantage: Google Search provides up-to-date information, while a Wikipedia corpus from 2023 may be outdated.

Test Question: “Who are the latest Nobel Prize winners in Physics?”

✅ Trained Agent + Google Search

Complete interaction:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside
<think> and </think> first every time you get new information.
After reasoning, if you find you lack some knowledge, you can
call a search engine by <search> query </search> and it will
return the top searched results between <information> and
</information>. You can search as many times as your want. If
you find no further external knowledge needed, you can directly
provide the answer inside <answer> and </answer>, without
detailed illustrations. For example, <answer> Beijing </answer>.
Question: Who are the latest Nobel Prize winners in Physics?
<|im_end|>
<|im_start|>assistant

<think>I need to find out who are the latest Nobel Prize
winners in Physics. I'll search for this information.</think>

<search> latest Nobel Prize Physics winner </search>

<information>
Doc 1 (Title: "All Nobel Prizes 2025 – NobelPrize.org")
The 2025 Nobel Prize in Physics has been awarded to 
John Clarke, Michel H. Devoret, and John M. Martinis 
“for the discovery of macroscopic quantum mechanical 
tunnelling and energy quantisation in an electric circuit.”
Doc 2 (Title: "Nobel Prize in Physics 2025 – Summary")
John Clarke, Michel H. Devoret, and John M. Martinis 
share the Nobel Prize in Physics 2025.
Doc 3 (Title: "John Clarke, Michel Devoret and John Martinis 
win the 2025 Nobel Prize for Physics")
The physics prize was awarded jointly to
John Clarke, Michel H. Devoret and John M. Martinis for their
experimental demonstration that quantum mechanical 
tunnelling and energy quantisation can 
occur in a macroscopic electric circuit.
</information>

<think>
I found the answer. The 2025 Nobel Prize in Physics 
was awarded to John Clarke, Michel H. Devoret, and 
John M. Martinis for their discoveries related to 
macroscopic quantum tunneling and quantized
energy states in electric circuits.
</think>

<answer>John Clarke, Michel H. Devoret, 
and John M. Martinis</answer>

Key behaviors:

  • Single search returns current results
  • Google provides up-to-date (2025) information
  • Quickly finds accurate answer

Correct: John Clarke, Michel H. Devoret, and John M. Martinis (2025)

⚠️ Trained Agent + Wikipedia (2023 Corpus)

Complete interaction:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside
<think> and </think> first every time you get new information.
After reasoning, if you find you lack some knowledge, you can
call a search engine by <search> query </search> and it will
return the top searched results between <information> and
</information>. You can search as many times as your want. If
you find no further external knowledge needed, you can directly
provide the answer inside <answer> and </answer>, without
detailed illustrations. For example, <answer> Beijing </answer>.
Question: Who are the latest Nobel Prize winners in Physics?
<|im_end|>
<|im_start|>assistant

<think>I need to find out who are the latest Nobel Prize
winners in Physics. I'll search for this information.</think>

<search> latest Nobel Prize Physics winner </search>

<information>Doc 1(Title: "2023 Nobel Prize in Physics") The
Nobel Prize in Physics 2023 was awarded to Pierre Agostini,
Ferenc Krausz and Anne L'Huillier for experimental methods
that generate attosecond pulses of light.
Doc 2(Title: "List of Nobel laureates in Physics") Most recent
physics laureates include Pierre Agostini, Ferenc Krausz, and
Anne L'Huillier in 2023.
Doc 3(Title: "Attosecond physics") Pierre Agostini, Ferenc
Krausz, and Anne L'Huillier pioneered attosecond physics and
were awarded the Nobel Prize.
</information>

<think>Based on the search results, the most recent Nobel
Prize in Physics information available is from 2023, awarded
to Pierre Agostini, Ferenc Krausz, and Anne L'Huillier.</think>

<answer> Pierre Agostini, Ferenc Krausz, and Anne L'Huillier </answer>

Problems:

  • Wikipedia corpus outdated (stops at 2023)
  • Agent returns 2023 winners as “latest” due to stale data
  • No access to current (2025) information

⚠️ Outdated: Returns 2023 winners (Pierre Agostini, Ferenc Krausz, and Anne L’Huillier) instead of the 2025 winners


Once trained, agents can swap retrieval backends (Wikipedia → Google → custom APIs) without retraining, as they’ve learned the general skill of tool use. This lets you test the same agent with different knowledge sources and compare their effectiveness.
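This works because the agent only ever sees the query it writes and the `<information>` text it gets back. Below is a sketch of keeping that contract fixed while swapping backends. `SearchBackend`, `StaticBackend`, and `make_retrieve` are hypothetical names, and the in-memory backends stand in for the Wikipedia retriever, Google Search, or a custom API.

```python
from typing import Callable, Dict, List, Protocol, Tuple

Hit = Tuple[str, str]  # (title, text)


class SearchBackend(Protocol):
    """Anything that maps a query to a ranked list of (title, text) hits."""
    def search(self, query: str, topk: int) -> List[Hit]: ...


class StaticBackend:
    """Hypothetical in-memory backend standing in for a real search service."""
    def __init__(self, docs: Dict[str, List[Hit]]):
        self.docs = docs

    def search(self, query: str, topk: int) -> List[Hit]:
        return self.docs.get(query, [])[:topk]


def make_retrieve(backend: SearchBackend, topk: int = 3) -> Callable[[str], str]:
    """Wrap any backend so the agent always sees the same result format."""
    def retrieve(query: str) -> str:
        hits = backend.search(query, topk)
        return "\n".join(f"Doc {i}(Title: {title}) {text}"
                         for i, (title, text) in enumerate(hits, start=1))
    return retrieve


query = "latest Nobel Prize Physics winner"
wiki = StaticBackend({query: [('"2023 Nobel Prize in Physics"', "Agostini, Krausz, L'Huillier.")]})
web = StaticBackend({query: [('"Nobel Prize in Physics 2025"', "Clarke, Devoret, Martinis.")]})
for backend in (wiki, web):
    print(make_retrieve(backend)(query))
```

Only the backend changes between the two calls; the text format the policy was trained on stays identical, which is what lets the same checkpoint run against either source.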

Running inference with different retrieval backends:

For Google Search backend:

sky launch -c verl-infer-google llm/verl/search-tooling/verl-search-interaction-google-infer.yaml \
  --env MODEL_PATH=/checkpoints/hf_model \
  --env GOOGLE_API_KEY=your_key_here \
  --env GOOGLE_CSE_ID=your_cse_id_here \
  -y

Google Search YAML (verl-search-interaction-google-infer.yaml in the SkyPilot repo):

# Search Tool Interaction Inference (Google Search backend)
#
# This example demonstrates inference using Search-R1 with a search/retrieval tool.
# The model uses a Google Search–backed tool for answering questions that require external knowledge.
# Both the Google search server and inference run on the same node.
#
# Usage:
#   sky launch -c verl-infer-google llm/verl/search-tooling/verl-search-interaction-google-infer.yaml \
#     --env MODEL_PATH=/checkpoints/hf_model \
#     --env GOOGLE_API_KEY=your_key_here \
#     --env GOOGLE_CSE_ID=your_cse_id_here \
#     -y
#
# Requirements:
#   - Single GPU for inference
#   - Valid Google Programmable Search Engine (CSE) + API key

resources:
  accelerators: H100:1
  memory: 128+
  ports:
    - 8000  # Google search server

num_nodes: 1

envs:
  MODEL_PATH: ""          # Optional: Path to model checkpoint (defaults to base model)
  GOOGLE_API_KEY: ""      # Required: Google API key
  GOOGLE_CSE_ID: ""       # Required: Google Custom Search Engine ID
  CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints

file_mounts:
  /checkpoints:
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT

setup: |
  set -e

  echo "=== Search Tool Inference Setup (Google Search) ==="

  # System dependencies
  echo "Installing system dependencies..."
  sudo apt update && sudo apt install -y iproute2 git

  # Python environment
  echo "Setting up Python virtual environment..."
  uv venv --python 3.10 --seed
  source .venv/bin/activate

  echo "Installing PyTorch..."
  uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

  # Clone VERL repository (if infer.py relies on its code / configs)
  echo "Cloning VERL repository..."
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout v0.6.0

  echo "Installing VERL + SGLang dependencies..."
  uv pip install -v -e .
  uv pip install wheel
  uv pip install packaging
  uv pip install -r ./requirements_sglang.txt
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"

  cd ..

  # Clone Search-R1 for inference
  echo "Cloning Search-R1 repository..."
  rm -rf Search-R1
  git clone https://github.com/PeterGriffinJin/Search-R1.git

  # Install additional inference dependencies
  cd Search-R1
  if [ -f requirements.txt ]; then
    echo "Installing Search-R1 requirements..."
    uv pip install -r requirements.txt
  fi

  # Ensure Google API client is available (if not already pulled in)
  uv pip install google-api-python-client

  cd ..

  echo "✓ Inference setup complete!"  

run: |
  set -e

  echo "=== Search Tool Inference (Google Search backend) ==="

  # Activate environment
  source .venv/bin/activate

  # Sanity check env vars
  if [ -z "$GOOGLE_API_KEY" ] || [ -z "$GOOGLE_CSE_ID" ]; then
    echo "ERROR: GOOGLE_API_KEY and GOOGLE_CSE_ID must be set via --env."
    exit 1
  fi

  echo "Using GOOGLE_API_KEY: (set)"
  echo "Using GOOGLE_CSE_ID:  (set)"

  # Start Google search server in background
  # Search-R1 was cloned into the same working directory as .venv during setup
  cd Search-R1
  echo "Starting Google search server on port 8000..."
  python search_r1/search/google_search_server.py \
    --api_key "$GOOGLE_API_KEY" \
    --cse_id "$GOOGLE_CSE_ID" \
    > google_search_server.log 2>&1 &

  RETRIEVAL_PID=$!
  echo "Google search server PID: $RETRIEVAL_PID"

  # Give the server a moment to start
  sleep 10

  # (Optional) basic health check if the server exposes one
  # curl -f http://127.0.0.1:8000/health || echo "Healthcheck failed (continuing anyway)"

  # Run inference
  echo "Running infer.py..."
  if [ -n "$MODEL_PATH" ]; then
    # If your infer.py supports a flag, use it; otherwise it may read MODEL_PATH from env.
    python infer.py --model_path "$MODEL_PATH" || python infer.py
  else
    python infer.py
  fi

  echo "✓ Inference finished"

  # Stop the search server before the job exits
  if ps -p $RETRIEVAL_PID > /dev/null 2>&1; then
    echo "Stopping Google search server..."
    kill $RETRIEVAL_PID || true
  fi

  echo "=== Done ==="  
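The run section above waits a fixed sleep 10 for the search server and leaves its health check commented out, since it is not confirmed that google_search_server.py exposes a /health route. A minimal alternative, sketched here under that assumption, polls the port until any HTTP status comes back (a 404 or 405 still proves something is listening). wait_for_http is a hypothetical helper, not part of VERL, Search-R1, or SkyPilot.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of VERL, Search-R1, or SkyPilot):
# poll a URL until the server answers with any HTTP status code.
# curl prints "000" when it cannot connect, so any other code
# (200, 404, 405, ...) means a process is listening on the port.
wait_for_http() {
  local url="$1"
  local max_tries="${2:-30}"   # number of attempts
  local delay="${3:-2}"        # seconds to sleep between attempts
  local code i=1
  while [ "$i" -le "$max_tries" ]; do
    code=$(curl -s -o /dev/null -w '%{http_code}' --max-time 5 "$url" || true)
    if [ -n "$code" ] && [ "$code" != "000" ]; then
      echo "Server at $url responded with HTTP $code after $i attempt(s)"
      return 0
    fi
    sleep "$delay"
    i=$((i + 1))
  done
  echo "Server at $url did not respond after $max_tries attempt(s)" >&2
  return 1
}
```

In the run section, wait_for_http http://127.0.0.1:8000/ 30 2 || exit 1 could stand in for the fixed sleep, so inference starts as soon as the server is up and the job fails loudly if it never comes up.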

The agent will interactively answer questions using the configured retrieval backend, demonstrating the search behaviors shown in the comparisons above.

Key takeaway: Once trained on the general skill of tool use, agents can swap retrieval backends without retraining. This flexibility lets you optimize for different use cases: Wikipedia for encyclopedic knowledge, Google for current information, or custom APIs for domain-specific data.


Scaling Up RL Post-Training: Separating Rollout and Training

Distributed RL training architecture with tool calling: rollout workers handle tool invocations asynchronously (figure from the VERL Agentic RL documentation).

For large-scale training, you can distribute the retrieval service and training across different machines to optimize resource utilization and performance.

Why separate retrieval and training:

  1. Resource isolation: The retrieval service (CPU- and memory-intensive) doesn't compete with training (GPU-intensive) for resources
  2. Independent scaling: Scale retrieval throughput and training capacity separately
  3. Better utilization: Dedicate high-memory nodes to retrieval, GPU nodes to training
  4. Fault tolerance: The retrieval service can keep running across training restarts
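Because the trainer reaches the retrieval service only over HTTP, several training runs can share one retrieval deployment, and any of them can be torn down and relaunched without reloading the index. The sketch below shows that pattern using the trainer YAML path and --env overrides that appear later in this section; launch_trainers and DRY_RUN are hypothetical names, and -d is sky launch's detach flag, so each call returns once its job is submitted instead of streaming logs.

```bash
#!/usr/bin/env bash
# Hypothetical helper: start several detached trainer clusters that all point at
# one already-running retrieval service. Set DRY_RUN=1 to print the commands
# instead of running them.
TRAINER_YAML=llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml

launch_trainers() {
  local retrieval_url="$1"; shift
  local name
  local -a cmd
  for name in "$@"; do
    # -d (--detach-run) returns as soon as the job is submitted.
    cmd=(sky launch -c "$name" "$TRAINER_YAML"
         --env "RETRIEVAL_SERVICE_URL=$retrieval_url"
         --env DATASET_SIZE=small
         -d -y)
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "${cmd[*]}"
    else
      "${cmd[@]}" || return 1
    fi
  done
}
```

For example, DRY_RUN=1 launch_trainers http://1.2.3.4:8000 exp-a exp-b prints one sky launch command per experiment. Two caveats: every run in this sketch mounts the same CHECKPOINT_BUCKET_NAME at /checkpoints, so real experiments should each get their own bucket through another --env override; and relaunching a single experiment is just sky down exp-a -y followed by launching it again, with the retrieval cluster left untouched.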

Architecture:

Tool-Calling Agent Architecture

Multi-node scaling with SkyPilot:

Rather than one all-in-one YAML, retrieval and training can run from separate YAMLs on different node types, and the retrieval service can be scaled out with SkyServe (sky serve).
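sky serve up expects the task YAML to contain a service: section, and the standalone retrieval YAML below does not have one. One way to bridge that is to generate a serve-ready copy with a service section appended, as sketched here. The field names (readiness_probe with path, post_data and initial_delay_seconds, plus replicas) follow the SkyServe documentation as we understand it, so check them against your SkyPilot version. The 1800-second delay is an assumed value meant to cover the corpus and index download in setup, and make_serve_yaml is a hypothetical helper.

```bash
#!/usr/bin/env bash
# Hypothetical helper: derive a SkyServe-ready copy of a task YAML by appending
# a `service:` section. Field names follow the SkyServe docs as we understand
# them; verify against your SkyPilot version.
make_serve_yaml() {
  local src="$1" dst="$2" replicas="${3:-2}"
  if grep -q '^service:' "$src"; then
    echo "$src already has a service: section" >&2
    return 1
  fi
  cp "$src" "$dst"
  cat >> "$dst" <<EOF

service:
  readiness_probe:
    path: /retrieve
    # Same request body the trainer uses to check the service.
    post_data: {"queries": ["test query"], "topk": 1, "return_scores": false}
    # Setup downloads the corpus and FAISS index, so allow plenty of time (assumed value).
    initial_delay_seconds: 1800
  replicas: $replicas
EOF
}
```

For example, make_serve_yaml llm/verl/search-tooling/verl-search-interaction-retrieval.yaml retrieval-serve.yaml 2, then sky serve up -n retrieval retrieval-serve.yaml -y. Reusing the trainer's /retrieve request as the readiness probe means a replica only counts as ready once it can actually answer a query.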

Sky serve

verl-search-interaction-retrieval.yaml (in the SkyPilot repo):
# Search Tool Retrieval Service
#
# This service provides Wikipedia retrieval capabilities using FAISS indexing.
# It runs on CPU nodes and exposes a retrieval API on port 8000.
#
# Usage:
#   sky launch -c retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
#
# Get endpoint:
#   sky status retrieval --endpoint 8000

resources:
  cpus: 32+
  memory: 256+
  use_spot: false
  ports:
    - 8000  # Retrieval service API

num_nodes: 1

envs:
  RETRIEVAL_TOPK: 3
  RETRIEVER_NAME: e5
  RETRIEVER_MODEL: intfloat/e5-base-v2

setup: |
  set -e

  echo "=== Retrieval Service Setup ==="

  # System dependencies
  echo "Installing system dependencies..."
  sudo apt update && sudo apt install -y iproute2

  # Python environment
  echo "Setting up Python virtual environment..."
  uv venv --python 3.10 --seed
  source .venv/bin/activate

  # Install retrieval service dependencies
  echo "Installing retrieval service dependencies..."
  uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
  uv pip install transformers datasets huggingface_hub
  uv pip install faiss-cpu
  uv pip install uvicorn fastapi uvloop==0.21.0

  # Download Wikipedia corpus and FAISS index
  echo "Downloading Wikipedia corpus and FAISS index..."
  export save_path=~/dataset
  mkdir -p $save_path

  huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
  huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path

  # Move files to expected locations
  mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
  mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index

  # Decompress the JSONL file
  gzip -d $save_path/wiki-18.jsonl.gz -f

  # Clone VERL repository for retrieval server code
  echo "Cloning repositories..."
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout v0.6.0

  # Patch retrieval server for CPU-only usage (comment out CUDA calls)
  echo "Patching retrieval server for CPU-only usage..."
  sed -i 's/^\(\s*\)\(model\.cuda()\)/\1# \2  # Commented out for CPU-only deployment/' \
    examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
  sed -i 's/^\(\s*\)\(inputs = {k: v\.cuda() for k, v in inputs\.items()}\)/\1# \2  # Commented out for CPU-only deployment/' \
    examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py

  cd ..

  echo "✓ Retrieval service setup complete!"  

run: |
  set -e

  echo "=== Starting Retrieval Service ==="

  # Activate environment
  source .venv/bin/activate

  # Set up paths
  save_path=~/dataset
  index_file=$save_path/e5_Flat.index
  corpus_file=$save_path/wiki-18.jsonl

  # Start retrieval server
  echo "Starting retrieval server on port 8000..."
  cd verl
  python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
    --index_path $index_file \
    --corpus_path $corpus_file \
    --topk $RETRIEVAL_TOPK \
    --retriever_name $RETRIEVER_NAME \
    --retriever_model $RETRIEVER_MODEL &

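  SERVER_PID=$!
  echo "✓ Retrieval service starting on port 8000 (PID $SERVER_PID)"

  # Wait on the server so the job stays alive while it serves
  # (otherwise the run script exits right after launching it)
  wait $SERVER_PID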
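To poke at the retrieval service by hand, you can send the same request the trainer's startup check uses: a POST to /retrieve with a JSON body holding queries, topk and return_scores. The sketch below wraps that request in hypothetical helpers. It escapes only backslashes and double quotes, which covers ordinary questions but is not a full JSON encoder, and it prints the raw response without assuming anything about its shape.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for querying the retrieval service by hand.
# The body mirrors the trainer's health-check request:
#   {"queries": [...], "topk": N, "return_scores": false}

# Escape backslashes and double quotes so a query fits inside a JSON string.
# Control characters such as newlines are NOT handled.
json_escape() {
  local s="$1" bs='\' dq='"'
  s=${s//"$bs"/"$bs$bs"}
  s=${s//"$dq"/"$bs$dq"}
  printf '%s' "$s"
}

# build_retrieve_payload TOPK QUERY [QUERY...]
build_retrieve_payload() {
  local topk="$1"; shift
  local body="" sep="" q
  for q in "$@"; do
    body+="${sep}\"$(json_escape "$q")\""
    sep=", "
  done
  printf '{"queries": [%s], "topk": %d, "return_scores": false}' "$body" "$topk"
}

# query_retrieval BASE_URL TOPK QUERY [QUERY...]
query_retrieval() {
  local base_url="${1%/}" topk="$2"; shift 2
  curl -s -X POST "${base_url}/retrieve" \
    -H "Content-Type: application/json" \
    -d "$(build_retrieve_payload "$topk" "$@")"
}
```

For example: query_retrieval "http://$(sky status retrieval --endpoint 8000)" 3 "Who wrote Hamlet?"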

verl-search-interaction-rl-trainer.yaml (in the SkyPilot repo):
# Search Tool Interaction Training with VERL (RL Trainer)
#
# This example demonstrates multi-turn tool interaction training using VERL with a search/retrieval tool.
# The model learns to use a search tool for answering questions that require external knowledge.
#
# Requires a separate retrieval service running (see verl-search-interaction-retrieval.yaml)
#
# Based on: https://verl.readthedocs.io/en/v0.5.x/sglang_multiturn/search_tool_example.html
#
# Usage:
#   # 1. Launch retrieval service first
#   sky launch -c retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
#
#   # 2. Get retrieval service endpoint
#   RETRIEVAL_IP=$(sky status retrieval --endpoint 8000)
#
#   # 3. Launch training (without WandB)
#   sky launch -c verl-train llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml --env RETRIEVAL_SERVICE_URL=http://$RETRIEVAL_IP --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
#
#   # Or with WandB logging (optional)
#   sky launch -c verl-train llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml --env RETRIEVAL_SERVICE_URL=http://$RETRIEVAL_IP --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 --secret WANDB_API_KEY -y
#
# Requirements:
#   - Docker with SYS_PTRACE capability (for PyTorch multiprocessing CUDA tensor sharing)
#   - H100 GPUs (can be adjusted for other accelerators)
#   - Running retrieval service at RETRIEVAL_SERVICE_URL

resources:
  accelerators: H100:1
  memory: 128+
  image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
  ports:
    - 8265  # Ray dashboard
    - 9090  # vLLM model serving

num_nodes: 1

config:
  docker:
    run_options:
      - --cap-add=SYS_PTRACE  # Required for PyTorch CUDA tensor sharing between Ray workers
      - --ipc=host
      - --shm-size=16g

envs:
  RETRIEVAL_SERVICE_URL: ""  # Required: URL of the retrieval service (e.g., http://retrieval-ip:8000)
  DATASET_SIZE: small  # Options: small (1000 train, 200 test), medium (10k train, 2k test), full
  TOTAL_EPOCHS: 1
  TOTAL_STEPS: 10
  TRAIN_BATCH_SIZE: 512
  VAL_BATCH_SIZE: 256
  SAVE_FREQ: 5  # Save checkpoints every 5 steps
  TEST_FREQ: 5  # Test every 5 steps
  MODEL_NAME: Qwen/Qwen2.5-3B-Instruct
  WANDB_PROJECT_NAME: search_r1_like_async_rl
  WANDB_EXPERIMENT_NAME: qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn
  CHECKPOINT_BUCKET_NAME: nebius://verl-search-interaction-checkpoints

file_mounts:
  /checkpoints:
    source: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT_CACHED

secrets:
  WANDB_API_KEY: ""  # Optional: Set to enable WandB logging. If not set, only console logging will be used.

setup: |
  rm -f ~/.pip/pip.conf
  rm -f ~/.config/pip/pip.conf

  set -e

  echo "=== VERL Search Tool Interaction Training Setup ==="

  # Validate required environment variables
  if [ -z "$RETRIEVAL_SERVICE_URL" ]; then
    echo "ERROR: RETRIEVAL_SERVICE_URL environment variable is required"
    echo "Example: --env RETRIEVAL_SERVICE_URL=http://retrieval-ip:8000"
    exit 1
  fi

  # Python environment
  echo "Setting up Python virtual environment..."
  uv venv --python 3.10 --seed
  source .venv/bin/activate

  # Clone VERL repository
  echo "Cloning VERL repository..."
  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout v0.6.0

  # Core dependencies
  echo "Installing PyTorch and VERL..."
  uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
  uv pip install -v -e .
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
  uv pip install wheel
  uv pip install packaging
  uv pip install -r ./requirements_sglang.txt

  # Install uvloop (required version)
  uv pip install uvloop==0.21.0

  # Data preparation
  echo "Preparing search R1 dataset..."
  python3 examples/data_preprocess/preprocess_search_r1_dataset.py

  # Clone Search-R1 for additional utilities
  git clone https://github.com/PeterGriffinJin/Search-R1/

  # Update tool config to use external retrieval service
  echo "Configuring external retrieval service..."
  TOOL_CONFIG="examples/sglang_multiturn/config/tool_config/search_tool_config.yaml"

  # Backup original config
  cp $TOOL_CONFIG ${TOOL_CONFIG}.bak

  # Update retrieval URL and num_workers in the config
  sed -i 's/num_workers: *120/num_workers: 8/' $TOOL_CONFIG
  sed -i "s|http://127\.0\.0\.1:8000/retrieve|$RETRIEVAL_SERVICE_URL/retrieve|g" $TOOL_CONFIG
  sed -i "s|http://localhost:8000|$RETRIEVAL_SERVICE_URL|g" $TOOL_CONFIG

  echo "✓ Setup complete!"
  echo "Dataset location: ~/data/searchR1_processed_direct/"
  echo "VERL repository: $(pwd)"
  echo "Retrieval service: $RETRIEVAL_SERVICE_URL"  

run: |
  set -e

  echo "=== VERL Search Tool Interaction Training ==="
  sudo apt update && sudo apt install -y iproute2 npm

  # Validate retrieval service
  if [ -z "$RETRIEVAL_SERVICE_URL" ]; then
    echo "ERROR: RETRIEVAL_SERVICE_URL environment variable is required"
    exit 1
  fi

  echo "Testing connection to retrieval service at $RETRIEVAL_SERVICE_URL..."
  # Give it a few retries in case the service is still starting
  max_retries=30
  retry_count=0
  while [ $retry_count -lt $max_retries ]; do
    # Test the /retrieve endpoint with a sample query
    test_response=$(curl -s -X POST "${RETRIEVAL_SERVICE_URL}/retrieve" \
      -H "Content-Type: application/json" \
      -d '{"queries": ["test query"], "topk": 1, "return_scores": false}' \
      -w "\n%{http_code}" 2>&1)

    http_code=$(echo "$test_response" | tail -n1)

    if [ "$http_code" = "200" ]; then
      echo "✓ Successfully connected to retrieval service"
      echo "✓ /retrieve endpoint is responding correctly"
      break
    fi
    retry_count=$((retry_count+1))
    if [ $retry_count -eq $max_retries ]; then
      echo "WARNING: Could not connect to retrieval service at $RETRIEVAL_SERVICE_URL"
      echo "Make sure the retrieval service is running and accessible"
      echo "Last response code: $http_code"
    fi
    sleep 5
  done

  # Multi-node setup
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES
  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE

  # Network configuration for distributed training
  NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
  export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
  export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE

  # PyTorch multiprocessing configuration
  export TORCH_MULTIPROCESSING_SHARING_STRATEGY=file_system

  # Activate environment
  source .venv/bin/activate

  # Set up paths
  cd verl
  PROJECT_DIR="$(pwd)"
  export PYTHONPATH="$PROJECT_DIR:$PYTHONPATH"

  # WandB login (optional)
  if [ -n "$WANDB_API_KEY" ]; then
    echo "Logging into Weights & Biases..."
    python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
  fi

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    echo "Starting Ray head node on port 6379..."
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port=6379 --dashboard-host=0.0.0.0 --dashboard-port=8265

    # Wait for all nodes to connect
    echo "Waiting for $NUM_NODES nodes to connect..."
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      # grep -c prints 0 itself when nothing matches; "|| true" only masks its non-zero exit code
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "✓ All $NUM_NODES nodes connected"
        break
      fi
      retry_count=$((retry_count+1))
      sleep 10
    done

    # Display Ray cluster status
    echo "Ray cluster status:"
    ray status

    echo "Starting search tool interaction training..."
    cd $PROJECT_DIR

    # Increase file descriptor limit
    ulimit -n 65535

    # Set up configuration paths
    CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
    TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet"
    VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet"
    TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml"

    # Configure logging based on WANDB_API_KEY availability
    if [ -n "$WANDB_API_KEY" ]; then
      LOGGER_CONFIG='["console","wandb"]'
      WANDB_ARGS="trainer.project_name=$WANDB_PROJECT_NAME trainer.experiment_name=$WANDB_EXPERIMENT_NAME"
      echo "✓ WandB logging enabled"
    else
      LOGGER_CONFIG='["console"]'
      WANDB_ARGS=""
      echo "ℹ WandB logging disabled (no API key provided)"
    fi

    # Training with search tool
    python3 -m verl.trainer.main_ppo \
      --config-path="$CONFIG_PATH" \
      --config-name='search_multiturn_grpo' \
      algorithm.adv_estimator=grpo \
      data.train_batch_size=$TRAIN_BATCH_SIZE \
      data.val_batch_size=$VAL_BATCH_SIZE \
      data.max_prompt_length=4096 \
      data.max_response_length=3000 \
      data.filter_overlong_prompts=True \
      data.truncation='error' \
      data.return_raw_chat=True \
      actor_rollout_ref.model.path=$MODEL_NAME \
      actor_rollout_ref.actor.optim.lr=1e-6 \
      actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
      actor_rollout_ref.model.use_remove_padding=True \
      actor_rollout_ref.actor.ppo_mini_batch_size=16 \
      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
      actor_rollout_ref.actor.use_kl_loss=True \
      actor_rollout_ref.actor.kl_loss_coef=0.001 \
      actor_rollout_ref.actor.kl_loss_type=low_var_kl \
      actor_rollout_ref.actor.entropy_coeff=0 \
      actor_rollout_ref.model.enable_gradient_checkpointing=True \
      actor_rollout_ref.actor.fsdp_config.param_offload=True \
      actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
      actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
      actor_rollout_ref.rollout.max_model_len=15000 \
      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
      actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
      actor_rollout_ref.rollout.name=sglang \
      actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
      actor_rollout_ref.rollout.n=5 \
      actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \
      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
      algorithm.use_kl_in_reward=False \
      trainer.critic_warmup=0 \
      trainer.val_before_train=False \
      trainer.logger="$LOGGER_CONFIG" \
      $WANDB_ARGS \
      trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
      trainer.nnodes=$NUM_NODES \
      trainer.save_freq=$SAVE_FREQ \
      trainer.test_freq=$TEST_FREQ \
      data.train_files="$TRAIN_DATA" \
      data.val_files="$VAL_DATA" \
      actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \
      trainer.total_epochs=$TOTAL_EPOCHS \
      trainer.total_training_steps=$TOTAL_STEPS \
      trainer.default_local_dir=/checkpoints

    echo "✓ Training complete!"

    # Model checkpoint merging
    echo "Merging model checkpoints..."
    LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
    CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"

    python -m verl.model_merger merge \
      --backend fsdp \
      --tie-word-embedding \
      --local_dir ${CHECKPOINT_DIR} \
      --target_dir /checkpoints/hf_model

    echo "✓ Model saved to /checkpoints/hf_model"
    echo "Training artifacts saved to cloud bucket: ${CHECKPOINT_BUCKET_NAME}"

  else
    # Worker node setup
    echo "Worker node (rank $SKYPILOT_NODE_RANK) connecting to head at $HEAD_IP:6379..."
    sleep 15
    ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
    echo "✓ Worker node connected"
    sleep infinity
  fi  
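The merge step at the end of the trainer's run section reads /checkpoints/latest_checkpointed_iteration.txt and assumes the matching global_step_N/actor directory exists; if either is missing, the job ends with a terse error from cat or the merger. A small guard makes that failure explicit. latest_actor_dir is a hypothetical helper that relies only on the file and directory layout the YAML itself uses.

```bash
#!/usr/bin/env bash
# Hypothetical guard (not part of VERL): print the actor checkpoint directory
# for the latest saved step, using the layout the trainer YAML relies on:
#   <root>/latest_checkpointed_iteration.txt   holds the step number
#   <root>/global_step_<step>/actor            holds the FSDP actor shards
latest_actor_dir() {
  local root="${1:-/checkpoints}"
  local tracker="$root/latest_checkpointed_iteration.txt"
  local step dir
  if [ ! -f "$tracker" ]; then
    echo "No $tracker found; no checkpoint seems to have been saved" >&2
    return 1
  fi
  # Strip whitespace such as the trailing newline.
  step=$(tr -d '[:space:]' < "$tracker")
  case "$step" in
    ''|*[!0-9]*)
      echo "Unexpected contents in $tracker: '$step'" >&2
      return 1
      ;;
  esac
  dir="$root/global_step_${step}/actor"
  if [ ! -d "$dir" ]; then
    echo "Checkpoint directory $dir does not exist" >&2
    return 1
  fi
  printf '%s\n' "$dir"
}
```

In the run section it could replace the two lines that compute LATEST_STEP and CHECKPOINT_DIR with CHECKPOINT_DIR=$(latest_actor_dir /checkpoints) || exit 1.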
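To run the split setup, bring up the retrieval service with SkyServe, then point the trainer at its endpoint:

# 1. Launch retrieval service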
sky serve up -n retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml \
  --cpus 32+ \
  --memory 256+ \
  -y

# 2. Get retrieval endpoint
RETRIEVAL_IP=$(sky serve status retrieval --endpoint)

# 3. Launch training
sky launch -c verl-train llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml \
  --env RETRIEVAL_SERVICE_URL=http://$RETRIEVAL_IP \
  --env DATASET_SIZE=small \
  -y
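The launch commands above build RETRIEVAL_SERVICE_URL by prefixing http:// to whatever the endpoint command prints, and the trainer's setup then appends /retrieve by plain string concatenation, so the value needs exactly one scheme and no trailing slash. We have not checked whether every SkyPilot version prints a bare host:port or a full URL, so a defensive normalizer such as the hypothetical normalize_retrieval_url below can absorb either form.

```bash
#!/usr/bin/env bash
# Hypothetical helper: normalize an endpoint string (bare host:port or full URL)
# into a base URL that "/retrieve" can be appended to.
normalize_retrieval_url() {
  local url="$1"
  # Trim leading and trailing whitespace (e.g., a newline from CLI output).
  url="${url#"${url%%[![:space:]]*}"}"
  url="${url%"${url##*[![:space:]]}"}"
  if [ -z "$url" ]; then
    echo "Empty retrieval endpoint" >&2
    return 1
  fi
  # Add a scheme only if one is missing.
  case "$url" in
    http://*|https://*) ;;
    *) url="http://$url" ;;
  esac
  # Drop trailing slashes, then a trailing /retrieve if the full path was given.
  while [ "${url%/}" != "$url" ]; do
    url="${url%/}"
  done
  url="${url%/retrieve}"
  printf '%s\n' "$url"
}
```

For example: --env RETRIEVAL_SERVICE_URL="$(normalize_retrieval_url "$RETRIEVAL_IP")".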

Benefits of this architecture:

  • Use every spare GPU: Rollout workers are just HTTP clients, so you can run them on any available GPU (on-prem, k8s, or cloud) and point them at the same trainer + retrieval endpoints.
  • Decouple sampling from training: Scale rollout workers for more trajectories, or trainer GPUs for faster updates; each can scale independently.

Screenshots:

RL trainer node

Retrieval node

Why SkyPilot for Tool-Calling Agents

Infrastructure Challenges:

  1. Need to run retrieval service + training simultaneously
  2. Retrieval service requires persistent state (loaded indices)
  3. Multi-node coordination with tool access
  4. Checkpointing both model and tool usage stats

SkyPilot Solutions:

  • Multi-process orchestration: Manages retrieval service + training
  • Persistent storage: Mounts Wikipedia corpus efficiently
  • Auto-recovery: Restarts all services on spot preemption
  • Cross-cloud: Same YAML works on any infrastructure
The same YAML launches unchanged on different infrastructure:

# Works identically on Kubernetes
sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml \
  --cloud kubernetes \
  --secret WANDB_API_KEY \
  -y

# Or AWS
sky launch -c verl-search llm/verl/search-tooling/verl-search-tool.yaml \
  --cloud aws \
  --secret WANDB_API_KEY \
  -y

Conclusion

Tool-calling agents learn to access external knowledge beyond their training data. This tutorial showed how to train them with VERL and scale them with SkyPilot.

Key takeaway: In RL training, the rollout phase (where agents generate responses and call tools) and the training phase (where models are updated) have different resource needs. The retrieval service that rollouts query is CPU- and memory-heavy but needs no GPU, while generation and model updates need GPU compute. SkyPilot lets you separate these cleanly: run the retrieval service on cheaper CPU nodes that rollout workers query over HTTP, and reserve expensive GPU nodes for the RL loop itself. You can then keep a persistent retrieval service shared across experiments while training clusters scale up or down as needed, optimizing both cost and performance.

Everything, from single-node prototypes to multi-node production runs, is orchestrated with SkyPilot YAML on any infrastructure.

Next Steps

Ready to build tool-calling agents? Launch the example and join the community for support!