Skip to content

Pooling models

Source: examples/online_serving/pooling.

Cohere rerank usage

python examples/online_serving/pooling/cohere_rerank_client.py

Jina AI rerank usage

python examples/online_serving/pooling/jinaai_rerank_client.py

OpenAI chat embedding for multimodal usage

python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py

OpenAI classification usage

python examples/online_serving/pooling/openai_classification_client.py

OpenAI embedding usage

python examples/online_serving/pooling/openai_embedding_client.py

OpenAI embedding Matryoshka dimensions usage

python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py

OpenAI pooling usage

python examples/online_serving/pooling/openai_pooling_client.py

Example materials

cohere_rerank_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
Note that `pip install cohere` is needed to run this example.

run: vllm serve BAAI/bge-reranker-base
"""

from typing import Any, Union

import cohere
from cohere import Client, ClientV2

model = "BAAI/bge-reranker-base"

query = "What is the capital of France?"

documents = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving",
]


def cohere_rerank(
    client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
) -> Any:
    """Rerank ``documents`` against ``query`` through a Cohere SDK client.

    The same call is used for both the v1 ``Client`` and ``ClientV2``; both
    expose ``rerank(model=..., query=..., documents=...)``.

    Args:
        client: A Cohere v1 or v2 client whose ``base_url`` points at the
            vLLM server.
        model: Name of the reranker model served by vLLM.
        query: The search query.
        documents: Candidate documents to score against ``query``.

    Returns:
        Whatever ``client.rerank`` returns, i.e. the Cohere SDK's rerank
        response object (not a plain ``dict``).
    """
    return client.rerank(model=model, query=query, documents=documents)


def main():
    """Send the same rerank request through the Cohere v1 and v2 clients."""
    separator = "-" * 50
    server_url = "http://localhost:8000"

    # Cohere v1 client (placeholder API key).
    client_v1 = cohere.Client(base_url=server_url, api_key="sk-fake-key")
    result_v1 = cohere_rerank(client_v1, model, query, documents)
    print(separator)
    print("rerank_v1_result:\n", result_v1)
    print(separator)

    # Cohere v2 client; here the API key is the first positional argument.
    client_v2 = cohere.ClientV2("sk-fake-key", base_url=server_url)
    result_v2 = cohere_rerank(client_v2, model, query, documents)
    print("rerank_v2_result:\n", result_v2)
    print(separator)


if __name__ == "__main__":
    main()
jinaai_rerank_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker

run: vllm serve BAAI/bge-reranker-base
"""

import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}


def main():
    """POST the rerank request and print either the JSON result or the error."""
    response = requests.post(url, headers=headers, json=data)

    # Failure path first: report the status code and raw body, then stop.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("Request successful!")
    print(json.dumps(response.json(), indent=2))


if __name__ == "__main__":
    main()
openai_chat_embedding_client_for_multimodal.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Example Python client for multimodal embedding API using vLLM API server
NOTE:
    start a supported multimodal embeddings model server with `vllm serve`, e.g.
    vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling --trust_remote_code --max_model_len=1024
"""

import argparse
import base64
import io

import requests
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


def vlm2vec():
    """Embed ``image_url`` with TIGER-Lab/VLM2Vec-Full and print the vector.

    The image and an instruction text are sent together as one chat-style
    user message to the ``/v1/embeddings`` endpoint. Raises
    ``requests.HTTPError`` on an error status.
    """
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "Represent the given image."},
        ],
    }
    payload = {
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [user_message],
        "encoding_format": "float",
    }

    response = requests.post("http://localhost:8000/v1/embeddings", json=payload)
    response.raise_for_status()

    embedding = response.json()["data"][0]["embedding"]
    print("Embedding output:", embedding)


def dse_qwen2_vl(inp: dict):
    """Embed an image or a text query with MrLight/dse-qwen2-2b-mrl-v1.

    Args:
        inp: ``{"type": "image", "image_url": <url>}`` embeds the image at
            that URL; any other ``type`` embeds ``inp["content"]`` as a text
            query.

    The embedding from the ``/v1/embeddings`` response is printed. Raises
    ``requests.HTTPError`` if the server returns an error status.
    """
    if inp["type"] == "image":
        # Embedding an image.
        content = [
            {"type": "image_url", "image_url": {"url": inp["image_url"]}},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    else:
        # Embedding a text query. MrLight/dse-qwen2-2b-mrl-v1 requires a
        # placeholder image of the minimum input size alongside the text.
        buffer = io.BytesIO()
        Image.new("RGB", (56, 56)).save(buffer, "png")
        placeholder_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        # The placeholder bytes are PNG, so the data URL must declare
        # image/png for the MIME type to match the encoded content.
        placeholder_url = f"data:image/png;base64,{placeholder_b64}"
        content = [
            {"type": "image_url", "image_url": {"url": placeholder_url}},
            {"type": "text", "text": f"Query: {inp['content']}"},
        ]

    messages = [{"role": "user", "content": content}]

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": messages,
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    response_json = response.json()

    print("Embedding output:", response_json["data"][0]["embedding"])


def parse_args():
    """Parse the command line to choose which multimodal embedding model to call.

    Returns:
        argparse.Namespace whose ``model`` attribute is ``"vlm2vec"`` or
        ``"dse_qwen2_vl"``.
    """
    parser = argparse.ArgumentParser(
        # Passed as ``description``: ArgumentParser's first positional
        # parameter is ``prog``, which would make this sentence the program
        # name in the usage line instead of the help description.
        description=(
            "Script to call a specified VLM through the API. Make sure to serve "
            "the model with `--runner pooling` before running this."
        )
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["vlm2vec", "dse_qwen2_vl"],
        required=True,
        help="Which model to call.",
    )
    return parser.parse_args()


def main(args):
    """Run the embedding example for the model selected via ``--model``."""
    if args.model == "vlm2vec":
        vlm2vec()
        return

    if args.model == "dse_qwen2_vl":
        # Embed one image, then one text query, with the same model.
        image_input = {"type": "image", "image_url": image_url}
        text_input = {"type": "text", "content": "What is the weather like today?"}
        dse_qwen2_vl(image_input)
        dse_qwen2_vl(text_input)


if __name__ == "__main__":
    main(parse_args())
openai_classification_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for classification API using vLLM API server
NOTE:
    start a supported classification model server with `vllm serve`, e.g.
    vllm serve jason9693/Qwen2.5-1.5B-apeach
"""

import argparse
import pprint

import requests


def post_http_request(payload: dict, api_url: str) -> requests.Response:
    """POST ``payload`` as JSON to ``api_url`` and return the raw response.

    No status check happens here; the caller inspects the response.
    """
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=payload,
    )


def parse_args():
    """Parse the server host/port and the classification model name."""
    option_specs = (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "jason9693/Qwen2.5-1.5B-apeach"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()


def main(args):
    """Classify a few sample prompts via ``/classify`` and pretty-print the reply."""
    api_url = f"http://{args.host}:{args.port}/classify"
    payload = {
        "model": args.model,
        "input": [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ],
    }

    response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(response.json())


if __name__ == "__main__":
    main(parse_args())
openai_embedding_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
    start a supported embeddings model server with `vllm serve`, e.g.
    vllm serve intfloat/e5-small
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    """Embed two sample sentences with the first model the server lists.

    Prints one embedding vector (a list of floats) per input string.
    """
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    responses = client.embeddings.create(
        # ruff: noqa: E501
        input=[
            "Hello my name is",
            "The best thing about vLLM is that it supports many different models",
        ],
        model=model,
    )

    for data in responses.data:
        # List of floats whose length is the served model's embedding
        # dimension (e.g. 384 for intfloat/e5-small); it varies by model.
        print(data.embedding)


if __name__ == "__main__":
    main()
openai_embedding_matryoshka_fy.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API dimensions using vLLM API server
NOTE:
    start a supported Matryoshka Embeddings model server with `vllm serve`, e.g.
    vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    """Request a 32-dimensional embedding for one sentence and print it."""
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    # Use whichever model the server lists first.
    model = client.models.list().data[0].id

    response = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=model,
        dimensions=32,
    )

    # Each vector is a list of 32 floats, per the requested `dimensions`.
    for item in response.data:
        print(item.embedding)


if __name__ == "__main__":
    main()
openai_pooling_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as the JSON body to ``api_url`` and return the response.

    The response is returned unchecked; the caller handles status and decoding.
    """
    request_headers = {"User-Agent": "Test Client"}
    return requests.post(api_url, json=prompt, headers=request_headers)


def parse_args():
    """Parse the server host/port and pooling model name from the command line."""
    parser = argparse.ArgumentParser()
    flag_defaults = {
        "--host": (str, "localhost"),
        "--port": (int, 8000),
        "--model": (str, "internlm/internlm2-1_8b-reward"),
    }
    for flag, (value_type, default) in flag_defaults.items():
        parser.add_argument(flag, type=value_type, default=default)

    return parser.parse_args()


def main(args):
    """Call ``/pooling`` twice: with a plain input, then with chat messages."""
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model
    separator = "-" * 50

    completion_style = {"model": model_name, "input": "vLLM is great!"}
    chat_style = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "vLLM is great!"}],
            }
        ],
    }

    # Input like Completions API
    pooling_response = post_http_request(prompt=completion_style, api_url=api_url)
    print(separator)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print(separator)

    # Input like Chat API
    pooling_response = post_http_request(prompt=chat_style, api_url=api_url)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print(separator)


if __name__ == "__main__":
    main(parse_args())