April 11, 2024 | 6 minute read

Use Dagster and SkyPilot to Orchestrate Cost-Effective AI Training Jobs

Muhammad Jarir Kanji

AI is becoming increasingly essential for organizations, but the associated costs can be prohibitive. GPU scarcity and the added complexity of standing up infrastructure and managing AI training jobs are other limiting factors preventing smaller enterprises from fully participating in the AI revolution.

Another challenge is the awkward handoff between data engineering and ML teams, as both functions often work in silos and use different tooling. These issues can be addressed through better data platform design. To this end, we believe that a data platform should:

  • Be heterogeneous: A data platform should seamlessly accommodate a diversity of user types (data engineers, data scientists, business stakeholders, etc.) using a variety of data storage, processing, and cloud technologies.
  • Be monolithic: Teams generating and consuming data assets across the organization benefit from having a single pane of glass, leading to improved visibility, productivity, consistency, and alignment.
  • Enable declarative flows: Imperative approaches can often result in additional complexity and cognitive overhead. Standardizing processes using a domain-specific language (DSL) can help tame some of that complexity and facilitate collaboration across teams.

These are the principles we adopt at Dagster Labs. But these goals also align with the capabilities of SkyPilot, a powerful Sky Computing framework that enables resilient and cost-effective AI/ML training jobs across cloud environments and regions.

This article explores how Dagster and SkyPilot can be combined to orchestrate ML training jobs cost-effectively within a single data platform. The combination abstracts resource acquisition and job execution behind an intuitive declarative DSL.

Critically, this solution allows data engineering to invite ML teams to bring their existing ML training and inference pipelines into Dagster and orchestrate them with minimal code changes and without the need to learn Dagster internals.

Existing Dagster Cloud users can enhance their Serverless deployment with SkyPilot, allowing for seamless scalability and the ability to tap into additional computing resources (like GPUs) to accelerate machine learning workloads, without the additional overhead of spinning up their own infrastructure and migrating to a Hybrid deployment.

What is SkyPilot and Sky Computing?

SkyPilot is a framework that implements the Sky Computing paradigm, where workloads can be transparently executed on one or more clouds, abstracting the provision of resources and execution of arbitrary workloads across cloud vendors while automatically maximizing cost savings and availability for users.

In doing so, SkyPilot lowers the barrier to entry for AI, not only by finding the cheapest vendor and region for your resources but also by making it easy to run managed jobs on spot instances, with automatic recovery after preemption. Combined, these features can yield substantial cost savings for your business and significantly reduce the development time spent reworking existing pipelines to run on other cloud vendors.

All of this takes place via a user-friendly DSL and, much like Dagster, SkyPilot boasts a great local development experience.

Project Overview

For this project, we’ll show how you can fine-tune Gemma, Google’s newest family of open-source LLMs. We’ll use the Abirate/english_quotes dataset and fine-tune Gemma to mimic the quotability of literary geniuses like Oscar Wilde.

Setup

Before we get started, here are a few preliminaries you’ll need to satisfy:

  • Step 1: Fork the project repo into your GitHub account.
  • Step 2: Sign up for a Dagster Cloud Serverless account. While the demo can be run locally, this article will focus on deploying to Dagster Cloud.
  • Step 3: Go through the Getting Started flow under the Import a Dagster project option. Select the fork you created in Step 1 as your repository. (You may need to approve the Dagster integration in GitHub.)
  • Step 4: The above step will create some files under .github/workflows/ in the repo. Edit both deploy.yml and branch_deployments.yml to disable fast deployments by making the following change, then commit it. This will deploy the Dagster project using Docker. SkyPilot requires openssh and rsync as native dependencies, which are installed in the Docker image by the dagster_cloud_post_install.sh script.
env:
    ...
    ENABLE_FAST_DEPLOYS: 'false' # This was originally true

Once the GitHub Actions workflow has finished running, the project will be deployed to Dagster Cloud.

  • Step 5: While SkyPilot allows you to orchestrate jobs on a variety of cloud providers (and even choose between them for the best deals), we’ll be using AWS for this demo. To that end, please set up the minimal required permissions for SkyPilot by following the instructions here.
  • Step 6: Create an S3 bucket to use for this example.
  • Step 7: Create an account on Hugging Face, generate a read-only access token for your account, and agree to Google’s terms and conditions for using Gemma on the model page here.
  • Step 8: Go to the Deployment > Environment Variables tab in Dagster Cloud and set up the following variables:
    • SKYPILOT_BUCKET with the name of the bucket you created in Step 6.
    • HF_TOKEN with the token you got in Step 7.
    • AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY with your AWS credentials.
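
Before launching a run, it can help to confirm that all four variables are actually visible to the process. The following is a minimal sketch and not part of the project repo; missing_env_vars is a hypothetical helper name.

```python
import os
from typing import Iterable, List, Mapping, Optional

# Variable names taken from Step 8 above.
REQUIRED_VARS = (
    "SKYPILOT_BUCKET",
    "HF_TOKEN",
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
)


def missing_env_vars(
    required: Iterable[str] = REQUIRED_VARS,
    env: Optional[Mapping[str, str]] = None,
) -> List[str]:
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    if missing:
        print("Missing environment variables: " + ", ".join(missing))
    else:
        print("All required environment variables are set.")
```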

Defining Tasks in SkyPilot

Before spinning up a training run, let’s go over how you can define SkyPilot jobs. SkyPilot allows you to use either a YAML file or its Python API for this purpose. We’ll go over SkyPilot’s DSL for defining jobs and go through the different sections of the finetune.yaml file below.

One of SkyPilot’s greatest strengths lies in abstracting the provision of cloud resources across multiple vendors. In the snippet below, we declare that we want to run our job on AWS using any of the accelerators in the list. SkyPilot will then automatically search for instances with one of these accelerators, check their availability, and provision a cluster with the most cost-effective option that satisfies your requirements.

resources:
    cloud: aws
    accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB}
    disk_tier: best
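
Conceptually, this selection step resembles picking the cheapest candidate that is currently available. The sketch below is a toy model, not SkyPilot’s actual optimizer; the Offer type, the instance names, and every price in it are made up for illustration.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Set


@dataclass(frozen=True)
class Offer:
    """A hypothetical instance offering; all values are illustrative."""
    instance_type: str
    accelerator: str
    hourly_price: float
    available: bool


def cheapest_offer(offers: Iterable[Offer], accepted: Set[str]) -> Optional[Offer]:
    """Return the cheapest available offer whose accelerator is accepted."""
    candidates = [o for o in offers if o.available and o.accelerator in accepted]
    return min(candidates, key=lambda o: o.hourly_price, default=None)


# Made-up catalog entries, NOT real AWS prices.
offers = [
    Offer("instance-a", "A100", 4.00, True),
    Offer("instance-b", "L4", 0.80, False),  # cheapest, but unavailable
    Offer("instance-c", "A10G", 1.00, True),
    Offer("instance-d", "T4", 0.50, True),   # not in the accepted set
]
best = cheapest_offer(offers, accepted={"L4", "A10G", "A100"})
print(best.instance_type if best else "no match")
```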

We set the values of the environment variables we want to propagate to the remote cluster using the envs key. Some of the variables hold placeholder values, because the Dagster process that invokes SkyPilot overrides them at runtime.

envs:
    # The first three env vars are placeholders that are overwritten by the Dagster process.
    DAGSTER_RUN_ID: "no-run" # The ID of the Dagster run that triggered the job.
    HF_TOKEN: ""
    SKYPILOT_BUCKET: ""
    MAX_STEPS: 10
    TERM: "dumb"
    NO_COLOR: 1
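
The override behaves much like a dictionary update: defaults from the YAML file are kept unless the launching process supplies a value. Here is a minimal sketch of that idea, independent of SkyPilot itself. resolve_envs is a hypothetical name, and casting every value to a string is an assumption made here because environment variables are strings.

```python
from typing import Any, Dict, Mapping


def resolve_envs(defaults: Mapping[str, Any], overrides: Mapping[str, Any]) -> Dict[str, str]:
    """Merge YAML defaults with runtime overrides; overrides win."""
    merged = {**defaults, **overrides}
    return {key: str(value) for key, value in merged.items()}


yaml_defaults = {
    "DAGSTER_RUN_ID": "no-run",
    "HF_TOKEN": "",
    "SKYPILOT_BUCKET": "",
    "MAX_STEPS": 10,
}
# Illustrative runtime values, as the Dagster process might supply them.
runtime = {"DAGSTER_RUN_ID": "example-run-id", "MAX_STEPS": 20}
envs = resolve_envs(yaml_defaults, runtime)
print(envs["DAGSTER_RUN_ID"], envs["MAX_STEPS"])
```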

SkyPilot lets you easily sync local files to the remote cluster and mount cloud storage onto it. See the Syncing Code and Artifacts page in the SkyPilot docs for more information.

The workdir key can be used to define the working directory for the setup and run commands (covered below) and its contents will be synced to the remote cluster in the ~/sky_workdir directory. In this case, we want to sync the scripts subdirectory containing the lora.py script.

workdir: dagster_skypilot/scripts

Additionally, we are also mounting the S3 bucket we created earlier to the /artifacts path in the remote cluster. This is where we’ll save intermediate checkpoints and the trained model. If we use spot instances to run the task, SkyPilot can use these checkpoints to recover from preemptions/interruptions without losing progress.

file_mounts:
    /artifacts:
        source: ${SKYPILOT_BUCKET}
        mode: MOUNT
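
For recovery to work, the training script has to look for an existing checkpoint when it starts. The sketch below shows one way to locate the most recent one, assuming checkpoints are written as checkpoint-<step> subdirectories (the naming used by the Hugging Face Trainer). Whether lora.py does exactly this is not shown in this article, and latest_checkpoint is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: Path) -> Optional[Path]:
    """Return the checkpoint directory with the highest step, if any."""
    if not output_dir.is_dir():
        return None
    best_step, best_path = -1, None
    for child in output_dir.iterdir():
        match = CHECKPOINT_PATTERN.match(child.name)
        # Compare steps numerically, so checkpoint-12 beats checkpoint-8.
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path
```

A training script could pass the result to its resume logic and start from scratch when the result is None.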

The setup command defines a setup script to be run when you launch the cluster for the first time. This is useful for installing any dependencies and setting up your environment correctly, which is exactly what we use it for here.

setup: |
    conda activate gemma
    if [ $? -ne 0 ]; then
        conda create -q -y -n gemma python=3.10
        conda activate gemma
    fi
    echo "Installing Python dependencies."
    pip install -q -U bitsandbytes==0.42.0
    pip install -q -U peft==0.8.2
    pip install -q -U trl==0.7.10
    pip install -q -U accelerate==0.27.1
    pip install -q -U datasets==2.17.0
    pip install -q -U transformers==4.38.1
    pip install -q "torch<2.2" torchvision --index-url https://download.pytorch.org/whl/cu121

Finally, the run command defines the main program that’s run on every node in the remote cluster. We use it to run the lora.py training script, which contains the logic for training the Gemma model.

run: |
    conda activate gemma

    NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
    HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`

    # Turn off wandb. The variable must be exported to reach the torchrun process.
    export WANDB_MODE="offline"

    TERM=dumb NO_COLOR=1 torchrun \
        --nnodes=$NUM_NODES \
        --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
        --master_port=12375 \
        --master_addr=$HOST_ADDR \
        --node_rank=${SKYPILOT_NODE_RANK} \
        lora.py \
        --model_name_or_path google/gemma-7b \
        --save_steps 4 \
        --max_steps ${MAX_STEPS} \
        --output_dir /artifacts/${DAGSTER_RUN_ID}

Take note of the fact that the output_dir is set to change based on the DAGSTER_RUN_ID. This ensures that the outputs of each run are stored in separate sub-directories in the S3 bucket, instead of overwriting the files from the previous run.
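
The shell logic in the run command can be restated in Python to make the derivation explicit. This is only a sketch for illustration, since the job itself uses the shell version. build_torchrun_args is a hypothetical helper, and it assumes SKYPILOT_NODE_IPS holds one IP address per line, as the wc -l and head -n1 calls imply.

```python
from typing import List, Mapping


def build_torchrun_args(env: Mapping[str, str]) -> List[str]:
    """Mirror the run command: derive torchrun arguments from the environment."""
    node_ips = env["SKYPILOT_NODE_IPS"].split()
    return [
        "torchrun",
        f"--nnodes={len(node_ips)}",
        f"--nproc_per_node={env['SKYPILOT_NUM_GPUS_PER_NODE']}",
        "--master_port=12375",
        f"--master_addr={node_ips[0]}",
        f"--node_rank={env['SKYPILOT_NODE_RANK']}",
        "lora.py",
        "--model_name_or_path", "google/gemma-7b",
        "--save_steps", "4",
        "--max_steps", env["MAX_STEPS"],
        # One sub-directory per Dagster run, so runs never overwrite each other.
        "--output_dir", f"/artifacts/{env['DAGSTER_RUN_ID']}",
    ]


example_env = {
    "SKYPILOT_NODE_IPS": "10.0.0.1\n10.0.0.2",  # illustrative addresses
    "SKYPILOT_NUM_GPUS_PER_NODE": "1",
    "SKYPILOT_NODE_RANK": "0",
    "MAX_STEPS": "10",
    "DAGSTER_RUN_ID": "example-run-id",
}
args = build_torchrun_args(example_env)
print(" ".join(args))
```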

The details of the lora.py training script are beyond the scope of this article, though you’re welcome to go through the code yourself.

Defining the Dagster Asset

You can interactively call the task defined above by running the following command from the terminal:

# Run this from the root of the project
SKYPILOT_BUCKET="s3://<your-skypilot-bucket>" HF_TOKEN="<your-token>" \
    sky launch -c gemma dagster_skypilot/finetune.yaml \
    --env HF_TOKEN --env SKYPILOT_BUCKET

However, we want to orchestrate the job using Dagster. Thankfully, SkyPilot provides an easy-to-use Python API for triggering jobs, alongside the CLI. Running the task defined in the YAML file is as simple as creating a software-defined asset.

We first define a dagster.Config object with different parameters that we want to be configurable from within the Dagster UI. In this case, that’s just the number of training steps and whether the task is run as a Managed Spot Job by SkyPilot (more on that later).

from dagster import Config
from pydantic import Field


class SkyPilotConfig(Config):
    """A minimal set of configurations for SkyPilot. This is NOT intended as a complete or exhaustive representation of a Task YAML config."""

    max_steps: int = Field(
        default=10, description="Number of training steps to perform."
    )
    spot_launch: bool = Field(
        default=False, description="Should the task be run as a managed spot job?"
    )

We then use this Config object in our software-defined asset, which reads the task’s YAML definition and overrides the required environment variables based on their definitions in Dagster Cloud.

Note that because HF_TOKEN is a secret, we don’t want to have the real value defined in the YAML file itself. Instead, we set the value as an environment variable in Dagster Cloud and override the default values at runtime. We use the same trick to dynamically populate the Dagster Run ID, the name of the bucket to be used, and the number of training steps at runtime.

import os

import sky
import yaml
from dagster import AssetExecutionContext, asset
from upath import UPath


@asset(group_name="ai")
def skypilot_model(context: AssetExecutionContext, config: SkyPilotConfig) -> None:
    # SkyPilot doesn't support reading credentials from environment variables.
    # So, we need to populate the required keyfiles.
    populate_keyfiles()
    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")

    # The parent of the current script
    parent_dir = UPath(__file__).parent
    yaml_file = parent_dir / "finetune.yaml"
    with yaml_file.open("r", encoding="utf-8") as f:
        task_config = yaml.safe_load(f)

    task = sky.Task().from_yaml_config(
        config=task_config,
        env_overrides={
            "HF_TOKEN": os.getenv("HF_TOKEN", ""),
            "DAGSTER_RUN_ID": context.run_id,
            # The key must match the env var name declared in finetune.yaml.
            "SKYPILOT_BUCKET": skypilot_bucket,
            "MAX_STEPS": config.max_steps,
        },
    )
    task.workdir = str(parent_dir.absolute() / "scripts")
    ...
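
The body of populate_keyfiles is not shown in this article. As a rough idea of what such a helper might do for AWS, the sketch below writes the two credentials from the environment into a file in the standard AWS credentials INI format. write_aws_credentials and its parameters are hypothetical and are not the project’s actual implementation.

```python
import configparser
import os
from pathlib import Path
from typing import Mapping, Optional


def write_aws_credentials(
    credentials_path: Path,
    env: Optional[Mapping[str, str]] = None,
    profile: str = "default",
) -> Path:
    """Write AWS keys from environment variables to an INI credentials file."""
    env = os.environ if env is None else env
    # Disable interpolation so that a "%" in a secret is stored verbatim.
    parser = configparser.ConfigParser(interpolation=None)
    parser[profile] = {
        "aws_access_key_id": env["AWS_ACCESS_KEY_ID"],
        "aws_secret_access_key": env["AWS_SECRET_ACCESS_KEY"],
    }
    credentials_path.parent.mkdir(parents=True, exist_ok=True)
    with credentials_path.open("w", encoding="utf-8") as f:
        parser.write(f)
    # Restrict access, since the file contains secrets.
    credentials_path.chmod(0o600)
    return credentials_path
```

The AWS tooling reads ~/.aws/credentials by default, so a call might look like write_aws_credentials(Path.home() / ".aws" / "credentials").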

Having defined the task, we check the configuration to determine if the task should be run as a managed spot job and launch the task.

@asset(group_name="ai")
def skypilot_model(context: AssetExecutionContext, config: SkyPilotConfig) -> None:
    ...
    try:
        if config.spot_launch:
            ...
            sky.spot_launch(task, name="gemma")
        else:
            ...
            sky.launch(task, cluster_name="gemma")

        context.log.info("Task completed.")
        context.add_output_metadata(get_metrics(context, skypilot_bucket))

    finally:
        teardown_all_clusters(context.log)

It’s important to note here that in order to shut down any clusters it started, SkyPilot needs some metadata that it stores locally. If you’re on Dagster Cloud, however, the runner that an asset op executes on is ephemeral and will be shut down after the run is complete.

To ensure that an error in the op logic doesn’t immediately kill the runner machine and prevent us from shutting down the remote cluster, we wrap that logic in a finally clause. In order to be absolutely sure you don’t have idle resources that continue to cost you, though, it’s a good idea to check the EC2 Dashboard for any running machines.
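
The effect of the finally clause can be demonstrated without SkyPilot. In the sketch below, launch and teardown are stand-in callables (hypothetical, not SkyPilot functions): teardown still runs when the launch raises, and the original error is still propagated to the caller.

```python
from typing import Callable, List


def run_with_teardown(launch: Callable[[], None], teardown: Callable[[], None]) -> None:
    """Run launch(), and always run teardown() afterwards, even on failure."""
    try:
        launch()
    finally:
        teardown()


events: List[str] = []


def failing_launch() -> None:
    events.append("launch")
    raise RuntimeError("simulated training failure")


try:
    run_with_teardown(failing_launch, lambda: events.append("teardown"))
except RuntimeError as err:
    events.append(f"error: {err}")

print(events)
```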

Finally, because SkyPilot cannot directly return data from the remote cluster to the Python process, we save the final metrics to the S3 bucket (see the lora.py script for the details of the implementation). The get_metrics function reads the metrics file from S3 and logs them as metadata for the Dagster run. This will be used to visualize the training metrics across runs.
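
As an illustration of that last step, the sketch below turns the contents of a metrics file into a flat dictionary of numbers suitable for logging as run metadata. It assumes the metrics are stored as a JSON object, which this article does not confirm; the field names are made up, and parse_metrics is a hypothetical helper rather than the project’s get_metrics.

```python
import json
from typing import Dict


def parse_metrics(raw: str) -> Dict[str, float]:
    """Keep only the numeric fields from a JSON metrics document."""
    data = json.loads(raw)
    return {
        key: float(value)
        for key, value in data.items()
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }


# Made-up example payload.
raw = '{"train_loss": 1.25, "train_runtime": 300, "note": "demo", "done": true}'
print(parse_metrics(raw))
```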

Using Dagster to Orchestrate the SkyPilot Job

To run the job, go to your Dagster Cloud instance and select the skypilot_model asset in the Asset Graph. Then, hold the Shift key and click on the Materialize Selected button. This will open the Launchpad and allow us to populate the parameters we defined as part of SkyPilotConfig.

For now, leave the default values and click on Materialize. This will launch a new run and show you the logs from the process.

The logs from SkyPilot and the remote cluster are sent to stdout and are not visible under the Events pane. Click on the stdout pane to see them. Note that if you’re using Dagster Cloud, they will not be visible until after the run finishes, which should take 10-15 minutes. If you’re running the job locally, the stdout logs will start streaming immediately.

Once the run has completed, the training metrics are logged into the run metadata in Dagster. We can use the Plots tab in the Asset Catalog page to compare metrics across runs.

I re-ran the same job with max_steps set to 20, 30, 40, and 50, and as one might expect, training for a larger number of steps decreases the training loss.

You can use Dagster's robust metadata logging to similarly track other useful bits of information, such as which instance type SkyPilot chose for a given run or how much the run cost.

Cutting Costs with Managed Spot Jobs

One other SkyPilot feature that we’d like to call out here is managed spot jobs. Cloud providers like AWS offer spot instances at up to 90% less than on-demand prices. Given how expensive GPU instances can be, using spot instances can result in significant savings for your organization.

The catch, of course, is that the instance can be preempted (i.e., shut down) by AWS at any time in response to a surge in demand. SkyPilot helps you navigate this uncertainty by managing the job for you. Not only does it find the region with the best pricing (the same spot instance can be 2x cheaper in another region) but if an instance is preempted, SkyPilot will also immediately spin up a new cluster in another region or on another cloud and continue your training job with minimal loss of progress!

Let’s try doing another job and setting the spot_launch parameter in the Launchpad to true. This will run the same task (with zero code changes required) as a spot job.

Comparing the logs between the earlier runs (above) and the spot job (below), we can see that SkyPilot was able to find a spot instance with the same specs for almost 3x cheaper than before!
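
Price ratios and percentage savings are easy to mix up, so here is the arithmetic spelled out. The prices below are hypothetical and are not taken from the logs of these runs.

```python
def savings_percent(on_demand_hourly: float, spot_hourly: float) -> float:
    """Percentage saved by paying the spot price instead of on-demand."""
    if on_demand_hourly <= 0:
        raise ValueError("on_demand_hourly must be positive")
    return (1 - spot_hourly / on_demand_hourly) * 100


# Hypothetical hourly prices in USD.
on_demand, spot = 3.00, 1.00
print(f"{on_demand / spot:.1f}x cheaper = {savings_percent(on_demand, spot):.1f}% saved")
```

By the same formula, a 90% saving corresponds to a 10x price ratio.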

Conclusion

Orchestrating AI training jobs using Dagster and SkyPilot can help you:

  • Execute AI training jobs in a cost-effective fashion with plug-in support for spot instances and automatic recovery from preemption.
  • Benefit from declarative code and a simple DSL to quickly onboard ML teams and centralize orchestration with Dagster as your single pane of glass.
  • Overcome the limitations of a Serverless deployment without migrating to a Hybrid deployment or spinning up your own clusters.

Learn more about SkyPilot at their GitHub repo and docs. If you’re interested in exploring more, make sure to fork the project repo and spin up a few training runs on Dagster Cloud.

