Want to train an AI agent with RL that can solve math problems or write code? This tutorial walks you through building your own math and coding agents, with step-by-step examples and plenty of screenshots to help you along the way.

We use VERL (a production-ready training framework) to apply RL post-training to LLMs, and SkyPilot to run and scale the training on any of your own AI infrastructure, including Kubernetes and clouds.

What You’ll Build

By the end of this tutorial, you’ll have:

  1. Prepared your data - Download and format training examples for math and coding tasks
  2. Launched training - Start training with a single command on any AI infra, including Kubernetes or clouds.
  3. Served your agent - Deploy your trained model as an API endpoint you can query
  4. Compared against base model - See how your trained agent performs vs the original model
  5. Monitored progress - Track your agent’s improvement in real-time with WandB
  6. (Optional) Shared your model - Upload to HuggingFace for others to use

All of this happens automatically from a single SkyPilot YAML file. Just one command and the entire pipeline runs end-to-end.

TL;DR: Complete examples with screenshots showing exactly how to train math and coding agents. One SkyPilot YAML does everything: finds GPUs, prepares data, trains, and serves your agent.

Understanding VERL: Production-Ready Agent Framework

VERL (Volcano Engine Reinforcement Learning) is an open-source framework that makes RL training practical and scalable. Instead of building your own training infrastructure from scratch, VERL handles the complexity for you.

VERL’s distributed architecture: Actor and Critic models train on dedicated GPUs while the Rollout model generates experiences asynchronously (image from https://tongyx361.github.io/blogs/posts/verl-intro)

Why SkyPilot: Run and Scale AI on Any Infra

SkyPilot Architecture

  • Run anywhere: One YAML works across Kubernetes, AWS, GCP, Azure, Lambda, Nebius, and on-prem
  • Auto-provision: SkyPilot finds GPUs and sets up your environment automatically
  • Scale easily: Go from 1 GPU to multi-node
  • Infrastructure as code: Your entire pipeline in one file

Let’s see this in action with two complete examples.

Example: Training a Math Agent for Formatted Step-by-Step Problem Solving

Step 1: Preparing the Math Dataset

Before we can train our agent, we need to prepare our “practice problems.” We’ll use the GSM8K dataset—a collection of elementary school math word problems.

Here’s a sample question:

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Why format matters: VERL needs to know 3 things for each problem:

  1. The question to ask the model
  2. The correct answer to check against
  3. How to calculate the reward (did the model get it right?)

Here’s how we transform GSM8K into VERL’s format (see the VERL documentation for details):

instruction_following = "Let's think step by step and output the final answer after \"####\"."

def make_map_fn(split):
    def process_fn(example, idx):
        question = example.pop('question')
        question = question + ' ' + instruction_following

        answer = example.pop('answer')
        solution = extract_solution(answer)  # Extracts the final number after "####"

        data = {
            "data_source": "gsm8k",
            "prompt": [{
                "role": "user",
                "content": question
            }],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": solution
            },
            "extra_info": {
                'split': split,
                'index': idx
            }
        }
        return data

    return process_fn
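
For reference, extract_solution isn’t shown above; it just pulls out the number that follows "####" in the GSM8K reference answer. A minimal sketch (verl’s actual helper may differ in details):

import re

def extract_solution(answer: str) -> str:
    # GSM8K reference answers always end with "#### <final number>", e.g. "... #### 72"
    match = re.search(r"####\s*(-?[0-9\.,]+)", answer)
    assert match is not None, "no '####' marker found in the answer"
    # Drop thousands separators so "1,000" becomes "1000"
    return match.group(1).replace(",", "")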

Example transformation:

  • Question (original): Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
  • Question (transformed): Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let’s think step by step and output the final answer after "####".
  • Answer (original): Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. #### 72
  • Answer (transformed): Extracted to ground_truth: 72

Resulting JSON format:

{
  "data_source": "gsm8k",
  "prompt": [
    {
      "role": "user",
      "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\"."
    }
  ],
  "ability": "math",
  "reward_model": {
    "style": "rule",
    "ground_truth": "72"
  },
  "extra_info": {
    "split": "train",
    "index": 0
  }
}

Key points:

  • reward_model.style: "rule" means we use simple rule-based checking: if the model’s answer matches the ground truth (72), it gets a reward (a simplified sketch of this check follows below)
  • The prompt instructs the model to output its final answer after #### so we can easily extract it
  • VERL handles downloading and processing automatically: you just run the preprocessing script (the YAML below uses verl’s examples/data_preprocess/math_dataset.py; gsm8k.py in the same directory does the same for GSM8K)
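
Conceptually, the rule-based check boils down to an exact match on the extracted answer. A simplified stand-in for the scoring function (not verl’s exact implementation):

import re

def compute_score(solution_str: str, ground_truth: str) -> float:
    # Find the model's final answer: the number after "####" in its response
    match = re.search(r"####\s*(-?[0-9\.,]+)", solution_str)
    if match is None:
        return 0.0  # no parsable final answer, no reward
    answer = match.group(1).replace(",", "")
    # An exact match against the ground truth earns the full reward
    return 1.0 if answer == ground_truth else 0.0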

Step 2: Training the Math Agent

We’ll use PPO (Proximal Policy Optimization), a popular RL algorithm, to teach our model through trial and error with rewards.
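
For intuition, PPO updates the policy by maximizing a clipped surrogate objective, which keeps each update close to the current policy:

\[ L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \]

Here \(\hat{A}_t\) is the advantage estimated with the help of a learned critic, which is why the YAML below configures both an actor and a critic model (critic.model.path).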

Click to see verl-ppo.yaml
# Usage:
#   sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
#
#   sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y


resources:
  infra: nebius
  accelerators: H100:1
  memory: 128+
  image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
  ports:
    - 8265
    - 9090

num_nodes: 1

envs:
  TOTAL_EPOCHS: 1
  WANDB_PROJECT_NAME: skypilot-verl
  WANDB_EXPERIMENT_NAME: ppo-math
  CHECKPOINT_BUCKET_NAME: sky-verl-ppo-checkpoints
  HF_UPLOAD_MODEL_NAME: "maknee/verl-ppo-math"
  SAVE_FINAL_MODEL_HF_PATH: /checkpoints/hf_model

file_mounts:
  /checkpoints:
    store: nebius
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT_CACHED

secrets:
  HF_TOKEN: null
  WANDB_API_KEY: null 

setup: |  
  rm -f ~/.pip/pip.conf
  rm -f ~/.config/pip/pip.conf

  sudo apt install iproute2 -y

  uv venv --python 3.10 --seed
  source .venv/bin/activate

  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout 83aebcc133663c12ac33ea3d5ba5c5c5b4687286
  uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
  uv pip install -v -e .
  uv pip install hf_transfer
  uv pip install flashinfer-python
  uv pip install "vllm==0.10.0" --torch-backend=auto
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"

  uv pip install datasets

  uv pip install "ray[train]" "click<8.2.0"
  uv pip install tqdm

  echo "Downloading Math dataset..."
  mkdir -p ~/data/math
  python3 "$(pwd)/examples/data_preprocess/math_dataset.py" --local_dir ~/data/math
  echo "Math dataset download completed"

  uv pip install zmq

run: |
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES
  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE

  NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
  export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
  export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE

  export VLLM_USE_V1=1

  source .venv/bin/activate

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    echo "Starting Ray head node..."
    ps aux | grep ray | grep 6379 &> /dev/null ||  ray start --head --disable-usage-stats \
          --port=6379 \
          --dashboard-host=0.0.0.0 \
          --dashboard-port=8265

    # Wait for all worker nodes to join
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || echo "0")
      echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"

      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "All nodes connected to Ray cluster"
        break
      fi

      retry_count=$((retry_count+1))
      sleep 10
    done

    python3 -m verl.trainer.main_ppo \
     data.train_files=$HOME/data/math/train.parquet \
     data.val_files=$HOME/data/math/test.parquet \
     data.train_batch_size=256 \
     data.max_prompt_length=1024 \
     data.max_response_length=1024 \
     data.filter_overlong_prompts=True \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=64 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
     actor_rollout_ref.rollout.trace.backend=weave \
     actor_rollout_ref.rollout.trace.token2text=True \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
     critic.optim.lr=1e-5 \
     critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     critic.ppo_micro_batch_size_per_gpu=4 \
     critic.model.fsdp_config.model_dtype=bfloat16 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.logger=[console,wandb] \
     trainer.val_before_train=False \
     trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
     trainer.nnodes=$NUM_NODES \
     trainer.save_freq=10 \
     trainer.test_freq=1 \
     trainer.default_local_dir=/checkpoints \
     trainer.total_epochs=${TOTAL_EPOCHS} \
     trainer.project_name=$WANDB_PROJECT_NAME \
     trainer.experiment_name=$WANDB_EXPERIMENT_NAME

    LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
    CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"

    if [ -n "$HF_TOKEN" ]; then
      python -m verl.model_merger merge \
        --backend fsdp \
        --tie-word-embedding \
        --local_dir ${CHECKPOINT_DIR} \
        --target_dir ${SAVE_FINAL_MODEL_HF_PATH} \
        --hf_upload_path ${HF_UPLOAD_MODEL_NAME}
    else
      python -m verl.model_merger merge \
        --backend fsdp \
        --tie-word-embedding \
        --local_dir ${CHECKPOINT_DIR} \
        --target_dir ${SAVE_FINAL_MODEL_HF_PATH}
    fi

    vllm serve /checkpoints/hf_model \
      --host 0.0.0.0 \
      --port 9090

  else
    sleep 15
    echo "Starting Ray worker node..."
    ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
    sleep 10
  fi

  echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."  

As you can see, the run section of the YAML is plain bash scripting, giving you the same flexibility you would have in a Slurm batch script.

Launch training with a single command:

# Single node (1 GPU)
sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --num-nodes 1 -y

# Multi-node for faster training (2 nodes, multiple GPUs)
sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --num-nodes 2 -y

That’s it! SkyPilot will:

  • Find available GPUs on any of your AI infra, including Kubernetes and clouds.
  • Set up the Ray cluster automatically
  • Download the dataset and start training
  • Save checkpoints to cloud storage

Logs while training

You will see a constant stream of logs when you launch:

log

NOTE: Setting up the Docker image may take several minutes - a good time to grab some coffee!

And when running

log

Tip: To access the node for debugging or monitoring, simply run ssh verl-ppo

If you ever lose connection, you can always stream the logs back to your local machine!

sky logs verl-ppo

At the end you should see vLLM serving your model on port 9090!

log

Let’s make a request to it (from the local machine we used to launch the SkyPilot YAML):

ENDPOINT=$(sky status --endpoint 9090 verl-ppo)
curl http://$ENDPOINT/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
    "model": "/checkpoints/hf_model",
    "messages": [
      {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?  Let'\''s think step by step and output the final answer within \\boxed{}"
      }
    ]
  }' | jq -r '.choices[0]["message"]["content"]'
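
If you prefer Python to curl, the same OpenAI-compatible endpoint served by vLLM can be queried with the openai client. A minimal sketch (assumes pip install openai and that the endpoint is fetched as above):

import os
from openai import OpenAI

# e.g. "1.2.3.4:9090", obtained via `sky status --endpoint 9090 verl-ppo`
endpoint = os.environ["ENDPOINT"]
client = OpenAI(base_url=f"http://{endpoint}/v1", api_key="EMPTY")  # vLLM ignores the key

response = client.chat.completions.create(
    model="/checkpoints/hf_model",
    messages=[
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": "A robe takes 2 bolts of blue fiber and half that much white fiber. "
                                    "How many bolts in total does it take? "
                                    "Let's think step by step and output the final answer within \\boxed{}"},
    ],
)
print(response.choices[0].message.content)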

We get this response, which breaks the problem down step by step:

To determine the total number of bolts needed for the robe, we need to follow these steps:

1. Identify the amount of blue fiber required.
2. Calculate the amount of white fiber required.
3. Add the amounts of blue and white fibers together.

From the problem statement:
- It takes 2 bolts of blue fiber.
- It takes half as much white fiber as blue fiber.

Let's calculate the amount of white fiber first:
\[ \text{Amount of white fiber} = \frac{\text{Amount of blue fiber}}{2} = \frac{2}{2} = 1 \text{ bolt} \]

Now, let's add the amounts of blue and white fibers together:
\[ \text{Total bolts} = \text{Amount of blue fiber} + \text{Amount of white fiber} = 2 + 1 = 3 \]

Therefore, the total number of bolts needed is \(\boxed{3}\).

Step 3: Comparing against the base model

Let’s compare how the trained model responds versus the base model on the same problem:

Test Problem: “A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?”

✓ Trained Model (RL-Enhanced)

To determine the total number of bolts needed for the robe, we need to follow these steps:

  1. Identify the amount of blue fiber required.
  2. Calculate the amount of white fiber required.
  3. Add the amounts of blue and white fibers together.

From the problem statement:

  • It takes 2 bolts of blue fiber.
  • It takes half as much white fiber as blue fiber.

Let’s calculate the amount of white fiber first: \[ \text{Amount of white fiber} = \frac{\text{Amount of blue fiber}}{2} = \frac{2}{2} = 1 \text{ bolt} \]

Now, let’s add the amounts of blue and white fibers together: \[ \text{Total bolts} = \text{Amount of blue fiber} + \text{Amount of white fiber} = 2 + 1 = 3 \]

Therefore, the total number of bolts needed is \(\boxed{3}\).

Base Model

To determine how many bolts of fabric are needed for one robe, we need to consider the amount of each type of fiber required and then sum them up.

  • We add the number of bolts of blue and white fibers together: 2 + 1 = 3

Thus, the total number of bolts needed for one robe is \(\boxed{3}\).

Key Difference: The trained model provides a more structured, step-by-step breakdown with clear mathematical notation, making it easier to follow the reasoning process. The base model gets the right answer but skips intermediate reasoning steps.

Click to see the curl commands to test both models

Testing the trained model:

ENDPOINT=$(sky status --endpoint 9090 verl-ppo)
curl http://$ENDPOINT/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
    "model": "/checkpoints/hf_model",
    "messages": [
      {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?  Let'\''s think step by step and output the final answer within \\boxed{}"
      }
    ]
  }' | jq -r '.choices[0]["message"]["content"]'

Testing the base model:

# First, serve the base model on the GPU server (stop the trained model's vLLM server first, since both use port 9090)
vllm serve Qwen/Qwen2.5-0.5B-Instruct \
  --host 0.0.0.0 \
  --port 9090

# Then query it
ENDPOINT=$(sky status --endpoint 9090 verl-ppo)
curl http://$ENDPOINT/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "messages": [
      {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?  Let'\''s think step by step and output the final answer within \\boxed{}"
      }
    ]
  }' | jq -r '.choices[0]["message"]["content"]'

Example: Training a Coding Agent to Generate Executable Solutions

Step 1: Preparing the Coding Dataset

For the coding agent, we’ll use the rStar-Coder dataset—a collection of competitive programming problems. These are much harder than the math problems, requiring algorithmic thinking and code generation.

Expand to see an example

Question

Chef has a cubic die with 6 faces kept on an infinite plane. Each face has a distinct integer in the range [1,6] written on it, but the exact arrangement of the numbers on the faces of the die is unknown to Chef. Curiosity gets the better of Chef and he wants to find out o(1), o(2), …, o(6), where o(i) is the number written opposite to the number i.

Chef performs the following N-1 steps to learn the exact arrangement of the numbers on the die. In the i-th step, Chef pushes the die in some direction (there are 4 possible directions), and the die rolls 90° in this direction. The picture below demonstrates a die and the result that it produced after rolling in each of the 4 directions respectively. For this die, we have o(1)=4, o(2)=5, o(3)=6, o(4)=1, o(5)=2, o(6)=3.

Chef records N numbers A1, A2, …, AN, where Ai is the number written on the top of the die before the i-th step. However, the information on the direction in which he pushes the die each time are lost. Can you help him determine the arrangement of 1-6? You need to output o(1), o(2), …, o(6). If there are multiple solutions, output any one of them.


Input

  • The first line of input contains a single integer T, the number of test cases.
  • Each test case consists of two lines.
  • The first line contains an integer N, and the second line contains N space-separated integers A1, A2, …, AN.

Output

For each test case, output one line. If a possible solution exists, then output 6 space-separated integers o(1), o(2), …, o(6); otherwise, output -1.


Constraints

  • 1 ≤ T, N ≤ 500
  • 1 ≤ Ai ≤ 6

Subtask #1 (17 points):

  • N ≤ 2

Subtask #2 (83 points):

  • original constraints

Example

Input:
4
9
1 3 4 6 1 3 4 6 1
10
1 2 1 3 1 4 1 5 1 6
6
1 2 3 4 5 6
4
1 3 3 1

Output:
2 1 6 5 4 3
-1
4 5 6 1 2 3
-1


Explanation

Example case 1: See below.

Example case 2: You can prove that there’s no solution.

Example case 3: The solution is not unique. 3 5 1 6 2 4 is another valid solution.

Response

Okay, so I need to solve this problem where Chef rolls a die multiple times and records the top faces each time. The goal is to figure out the opposite faces of the die. Hmm, let’s see how to approach this.

The key is to model the die’s possible configurations after each step, considering all possible directions the die could have been rolled. Since each roll is a 90-degree turn in one of four directions (like north, south, east, west?), but the direction is unknown, each step’s transition between the current top and next top depends on the die’s current orientation.

First, I need to figure out how the die’s top, front, and other faces change when it rolls in a direction. Wait, but maybe there’s a standard way to model the die’s state. Let’s think of the die as having a current top face, front face, etc. However, direction is ambiguous here. Alternatively, perhaps the key is to track the current top and bottom faces, along with the possible adjacent faces?

Alternatively, since when a die rolls in a direction, the new top will be one of the adjacent faces (the ones adjacent to the current top). Wait, perhaps the movement of rolling can be modeled such that each roll changes the top to one of the adjacent faces, but the exact opposite face remains fixed?

Wait, the problem says that the exact arrangement of the numbers on the faces is unknown. We need to determine for each number i, which number is opposite to it (o(i)) such that the sequence of top faces A_1 to A_N can be achieved through a series of rolls in some directions, starting from A_1’s initial position.

Hmm, the problem requires that given the sequence of top numbers, we can determine the opposite pairs. The key is to model the transitions between the top faces as the die rolls. Each time the die rolls, the new top is adjacent to the previous top but not the opposite. Wait, no. Wait when you roll the die, the new top is a face adjacent to the previous top. Because when a die rolls over an edge, the new face comes from the side. The opposite face never becomes adjacent. So the opposite is fixed.

Wait, perhaps the opposite pairs are fixed, so for example, if the opposite of 1 is 4, then whenever the die is on 1, the bottom is 4. So whenever the die is rolled, the next top must be one of the four adjacent faces (the ones not opposite).

Therefore, for each pair of consecutive numbers in the sequence A, the next top (A_{i+1}) must be adjacent to the previous top (A_i) on the die. Since the opposite can’t be adjacent. So the adjacent faces are those faces that are not the opposite.
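
Under the hood, the preprocessing script (preprocess_rstar_coder.py, mounted at /code in the YAML below) converts each problem into the same VERL record schema we used for GSM8K. A hypothetical record, built the same way as in make_map_fn above (the field values here are illustrative; the actual data_source, instruction, and reward details depend on the script, which isn’t reproduced in this post):

# Illustrative example record; the real preprocess_rstar_coder.py may use different field values.
record = {
    "data_source": "rstar-coder",  # illustrative label
    "prompt": [{
        "role": "user",
        # problem statement plus an answer-format instruction, mirroring the GSM8K example
        "content": "Capitalize the first letter of a word, keeping all other letters unchanged. "
                   "Let's think step by step and output the final answer after \"####\"",
    }],
    "ability": "code",
    "reward_model": {
        "style": "rule",
        "ground_truth": "Konjac",  # illustrative expected output for the sample input "konjac"
    },
    "extra_info": {"split": "train", "index": 0},
}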

Step 2: Training the Coding Agent

We’ll use GRPO (Group Relative Policy Optimization), a more memory-efficient RL algorithm that drops the separate critic model.
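
Unlike PPO, GRPO samples a group of G responses for each prompt and uses the group-normalized reward as each response’s advantage:

\[ \hat{A}_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)} \]

This is why the YAML below sets algorithm.adv_estimator=grpo and needs no critic.* settings.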

Click to see verl-grpo.yaml
# Usage:
#   sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
#
#   sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y


resources:
  accelerators: H100:1
  memory: 128+
  image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
  ports:
    - 8265
    - 9090

envs:
  TOTAL_EPOCHS: 1
  WANDB_PROJECT_NAME: skypilot-verl
  WANDB_EXPERIMENT_NAME: grpo-code
  CHECKPOINT_BUCKET_NAME: sky-verl-grpo-checkpoints
  HF_UPLOAD_MODEL_NAME: "maknee/verl-grpo-code"
  SAVE_FINAL_MODEL_HF_PATH: /checkpoints/hf_model

file_mounts:
  /checkpoints:
    store: nebius
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT_CACHED
  /code:
    name: code
    source: llm/verl/code
    mode: COPY

secrets:
  HF_TOKEN: null
  WANDB_API_KEY: null

setup: |  
  rm -f ~/.pip/pip.conf
  rm -f ~/.config/pip/pip.conf

  sudo apt install iproute2 -y

  uv venv --python 3.10 --seed
  source .venv/bin/activate

  rm -rf verl
  git clone https://github.com/volcengine/verl.git
  cd verl
  git checkout 83aebcc133663c12ac33ea3d5ba5c5c5b4687286
  uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
  uv pip install -v -e .
  uv pip install hf_transfer
  uv pip install flashinfer-python
  uv pip install "vllm==0.10.0" --torch-backend=auto
  uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"

  uv pip install datasets

  uv pip install "ray[train]" "click<8.2.0"
  uv pip install tqdm

  echo "Downloading code dataset..."
  mkdir -p ~/data/code
  python3 /code/preprocess_rstar_coder.py --local_dir ~/data/code
  echo "code dataset download completed"

run: |
  HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NUM_NODES=$SKYPILOT_NUM_NODES
  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE

  NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
  export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
  export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE

  export VLLM_USE_V1=1

  source .venv/bin/activate

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    echo "Starting Ray head node..."
    ps aux | grep ray | grep 6379 &> /dev/null ||  ray start --head --disable-usage-stats \
          --port=6379 \
          --dashboard-host=0.0.0.0 \
          --dashboard-port=8265

    # Wait for all worker nodes to join
    retry_count=0
    max_retries=30
    while [ $retry_count -lt $max_retries ]; do
      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || echo "0")
      echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"
      
      if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
        echo "All nodes connected to Ray cluster"
        break
      fi
      
      retry_count=$((retry_count+1))
      sleep 10
    done

    python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=grpo \
     data.train_files=$HOME/data/code/train.parquet \
     data.val_files=$HOME/data/code/test.parquet \
     data.train_batch_size=32 \
     data.max_prompt_length=256 \
     data.max_response_length=256 \
     data.filter_overlong_prompts=True \
     data.truncation='error' \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.actor.ppo_mini_batch_size=16 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.actor.ppo_epochs=1 \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.actor.entropy_coeff=0 \
     actor_rollout_ref.model.enable_gradient_checkpointing=True \
     actor_rollout_ref.actor.fsdp_config.param_offload=True \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
     actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
     actor_rollout_ref.rollout.n=1 \
     actor_rollout_ref.rollout.enable_chunked_prefill=True \
     actor_rollout_ref.rollout.max_num_batched_tokens=2048 \
     actor_rollout_ref.rollout.trace.backend=weave \
     actor_rollout_ref.rollout.trace.token2text=True \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.ref.fsdp_config.param_offload=True \
     algorithm.use_kl_in_reward=False \
     trainer.critic_warmup=0 \
     trainer.logger=[console,wandb] \
     trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
     trainer.nnodes=$NUM_NODES \
     trainer.save_freq=10 \
     trainer.test_freq=1 \
     trainer.total_epochs=${TOTAL_EPOCHS} \
     trainer.default_local_dir=/checkpoints \
     trainer.project_name=$WANDB_PROJECT_NAME \
     trainer.experiment_name=$WANDB_EXPERIMENT_NAME

    LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
    CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"

    if [ -z "$HF_TOKEN" ]; then
      python -m verl.model_merger merge \
        --backend fsdp \
        --tie-word-embedding \
        --local_dir ${CHECKPOINT_DIR} \
        --target_dir ${SAVE_FINAL_MODEL_HF_PATH} \
        --hf_upload_path ${HF_UPLOAD_MODEL_NAME}
    else
      python -m verl.model_merger merge \
        --backend fsdp \
        --tie-word-embedding \
        --local_dir ${CHECKPOINT_DIR} \
        --target_dir ${SAVE_FINAL_MODEL_HF_PATH}
    fi

    vllm serve /checkpoints/hf_model \
      --host 0.0.0.0 \
      --port 9090

  else
    sleep 15
    echo "Starting Ray worker node..."
    ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
    sleep 10
  fi

  echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."  

Launch training:

# Single node
sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --num-nodes 1 -y

# Multi-node
sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --num-nodes 2 -y

Training log:

log

Same output for vllm serving:

log

Let’s test it with a coding problem: Word Capitalization

Problem: Capitalize the first letter of a word, keeping all other letters unchanged.

codeforces

Let’s run the command to test the agent against the coding problem

ENDPOINT=$(sky status --endpoint 9090 verl-grpo)
curl http://$ENDPOINT/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
    "model": "/checkpoints/hf_model",
    "messages": [
      {
        "role": "system",
        "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "Problem Description:\\nCapitalization is writing a word with its first letter as a capital letter. Your task is to capitalize the given word.\\nNote, that during capitalization all the letters except the first one remains unchanged.\\n\\nInput:\\nA single line contains a non-empty word. This word consists of lowercase and uppercase English letters. The length of the word will not exceed 10³.\\n\\nOutput:\\nOutput the given word after capitalization.\\n\\nExamples\\n\\nInput:\\nApPLe\\n\\nOutput:\\nApPLe\\n\\nInput:\\nkonjac\\n\\nOutput:\\nKonjac\\n\\n Let'\''s think step by step and output the final answer after \"####\""
      }
    ]
  }' | jq -r '.choices[0]["message"]["content"]'

log

We get this response, which is correct:

To solve this problem, you can follow these steps:

1. Read the input word.
2. Capitalize the first letter of the word.
3. Keep the rest of the letters as they are.
4. Output the capitalized word.

Here's a simple Python function to achieve this:

```python
def capitalize_word(word):
    # Capitalize the first letter of the word
    if len(word) > 0:
        return word[0].upper() + word[1:]
    else:
        return word

# Example usage:
input_word = input().strip()
output_word = capitalize_word(input_word)
print(output_word)
```

### Explanation:
- `word[0].upper()` capitalizes the first character of the string.
- `word[1:]` keeps the rest of the string unchanged.
- If the word is empty, it simply returns an empty string.

### Example:
For the input "ApPLe":
- The first character 'A' is capitalized to 'A'.
- The rest of the string remains "pPLe".
- So, the output is "ApPLe".

For the input "konjac":
- The first character 'k' is capitalized to 'K'.
- The rest of the string remains "onjac".
- So, the output is "Konjac".

This solution works for any valid input as per the problem statement.

Step 3: Comparing against the base model

Let’s compare how the trained coding agent responds versus the base model on the same problem:

Test Problem: Word Capitalization from Codeforces

✓ Trained Model (RL-Enhanced)

To solve this problem, you can follow these steps:

  1. Read the input word.
  2. Capitalize the first letter of the word.
  3. Keep the rest of the letters as they are.
  4. Output the capitalized word.

Here’s a simple Python function to achieve this:

def capitalize_word(word):
    if len(word) > 0:
        return word[0].upper() + word[1:]
    else:
        return word

input_word = input().strip()
print(capitalize_word(input_word))

Explanation:

  • word[0].upper() capitalizes the first character

  • word[1:] keeps the rest unchanged

  • Handles edge cases (empty strings)

Base Model

Let’s break down the problem step by step:

  1. Read the Input: The input is a single non-empty word.
  2. Identify the First Letter: Check the first letter.
  3. Capitalize the First Letter: Convert to uppercase.
  4. Keep the Rest Unchanged: All other letters stay the same.
  5. Output the Result: Print the modified word.

Example: “ApPLe”

  1. Read the Input: “ApPLe”
  2. First letter is ‘A’
  3. Convert ‘A’ to ‘A’ (already uppercase)
  4. Keep ‘p’, ‘P’, ‘L’, ’e’ unchanged
  5. Output: “ApPLe”

Example: “konjac”

  1. Read: “konjac”
  2. First letter is ‘k’
  3. Convert ‘k’ to ‘K’
  4. Keep ‘o’, ’n’, ‘j’, ‘a’, ‘c’ unchanged
  5. Output: “Konjac”

Final Answer: For “ApPLe” → “ApPLe”. For “konjac” → “Konjac”.

Key Difference: The trained model immediately provides executable code with clear implementation details, while the base model gives a verbose step-by-step explanation but no actual code solution. For coding problems, the trained model’s approach is more practical and directly useful.

Reference log

Here is a dump of the SkyPilot logs for running the example.

Monitoring Training with WandB

Once training starts, you can monitor progress in real-time using Weights & Biases (WandB). The training logs automatically sync to your WandB dashboard.

What to watch:

  • Episode Reward: Is your model getting better? This should increase over time
  • KL Divergence: How much is the model changing from the original? Too high means it’s drifting
  • Loss curves: Should generally trend downward

Example WandB training curves showing reward improving over epochs

Tip: You can view example runs at wandb.ai/skypiloth-a/skypilot-verl to see what healthy training looks like.

Checkpoint Persistence

file_mounts:
  /checkpoints:
    store: s3  # or gcs, azure, nebius
    name: ${CHECKPOINT_BUCKET_NAME}
    mode: MOUNT_CACHED

Checkpoints are automatically saved to cloud storage every N training steps (trainer.save_freq=N). When spot instances are preempted, SkyPilot provisions new instances and VERL automatically resumes from the last checkpoint.

image

Conclusion

Training agents at scale requires breaking free from single-region constraints. By combining VERL’s production-ready distributed training with SkyPilot’s multi-region orchestration, you can:

  • Access more GPUs by searching across regions globally
  • Scale your training with more compute resources
  • Reduce costs by using spot instances with automatic recovery
  • Run anywhere: Same YAML works on any cloud or Kubernetes
  • Eliminate manual ops: Automatic provisioning, checkpointing, and recovery

Next Steps

Ready to scale your agent training globally? Try the examples above and join the SkyPilot Slack community for support.