# langchain-xai

An integration package connecting xAI and LangChain, providing access to xAI's chat completion models through LangChain's standardized interface. This package enables developers to integrate xAI's Grok models into LangChain-based applications for chat completions, tool calling, structured output, and conversational AI workflows.

## Installation

```bash
pip install langchain-xai
```

Or via the tessl CLI:

```bash
npx @tessl/cli install tessl/pypi-langchain-xai@0.2.0
```

## Quick Start

```python
from langchain_xai import ChatXAI

# Initialize the chat model
llm = ChatXAI(
    model="grok-4",
    temperature=0,
    api_key="your-xai-api-key",  # or set the XAI_API_KEY environment variable
)
# Simple chat completion
messages = [
    ("system", "You are a helpful assistant."),
    ("human", "What is the capital of France?"),
]
response = llm.invoke(messages)
print(response.content)
# Streaming chat completion
for chunk in llm.stream(messages):
    print(chunk.content, end="")
```

## Architecture

The langchain-xai package provides a single primary class, ChatXAI, which inherits from LangChain's BaseChatOpenAI. This design leverages the OpenAI-compatible API structure while providing the xAI-specific features described in the sections below.
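Because ChatXAI implements LangChain's Runnable interface, it composes directly with prompts and output parsers. A minimal sketch (the translation prompt is illustrative, not part of the package):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_xai import ChatXAI

# Prompt -> model -> plain-string output
prompt = ChatPromptTemplate.from_messages([
    ("system", "You translate {input_language} to {output_language}."),
    ("human", "{text}"),
])
chain = prompt | ChatXAI(model="grok-4", temperature=0) | StrOutputParser()

print(chain.invoke({
    "input_language": "English",
    "output_language": "French",
    "text": "I love programming.",
}))
```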
### Chat Completions

Core chat completion functionality supporting both synchronous and asynchronous operations with streaming capabilities.

```python
class ChatXAI:
    def __init__(
        self,
        model: str = "grok-4",
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        timeout: Optional[Union[float, Tuple[float, float]]] = None,
        max_retries: int = 2,
        api_key: Optional[str] = None,
        xai_api_key: Optional[str] = None,
        xai_api_base: str = "https://api.x.ai/v1/",
        search_parameters: Optional[Dict[str, Any]] = None,
        logprobs: Optional[bool] = None,
        **kwargs,
    ):
        """
        Initialize the ChatXAI model.

        Parameters:
        - model: Name of the xAI model to use (default: "grok-4")
        - temperature: Sampling temperature, 0-2 (default: 1.0)
        - max_tokens: Maximum tokens to generate
        - timeout: Request timeout in seconds
        - max_retries: Maximum retry attempts
        - api_key: xAI API key (alias for xai_api_key)
        - xai_api_key: xAI API key; read from the XAI_API_KEY env var if not provided
        - xai_api_base: Base URL for the xAI API
        - search_parameters: Parameters for live search functionality
        - logprobs: Whether to return log probabilities
        """

    def invoke(self, input: LanguageModelInput, **kwargs) -> BaseMessage:
        """Generate a chat completion for the input messages."""

    def stream(self, input: LanguageModelInput, **kwargs) -> Iterator[BaseMessageChunk]:
        """Stream chat completion chunks for the input messages."""

    async def ainvoke(self, input: LanguageModelInput, **kwargs) -> BaseMessage:
        """Asynchronously generate a chat completion for the input messages."""

    async def astream(self, input: LanguageModelInput, **kwargs) -> AsyncIterator[BaseMessageChunk]:
        """Asynchronously stream chat completion chunks for the input messages."""

    def batch(self, inputs: List[LanguageModelInput], **kwargs) -> List[BaseMessage]:
        """Generate chat completions for multiple inputs."""

    async def abatch(self, inputs: List[LanguageModelInput], **kwargs) -> List[BaseMessage]:
        """Asynchronously generate chat completions for multiple inputs."""
```

Usage Example:

```python
from langchain_xai import ChatXAI
# Basic setup
llm = ChatXAI(
    model="grok-4",
    temperature=0.7,
    max_tokens=1000,
)
# Invoke with message list
messages = [("human", "Explain quantum computing")]
response = llm.invoke(messages)
# Streaming
for chunk in llm.stream(messages):
    print(chunk.content, end="")
# Async operations
import asyncio
async def chat_async():
    response = await llm.ainvoke(messages)
    return response
response = asyncio.run(chat_async())
```
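The batch and abatch helpers work the same way; a short sketch (the prompts are illustrative):

```python
from langchain_xai import ChatXAI

llm = ChatXAI(model="grok-4")

# Each input is its own message list; results come back in the same order
prompts = [
    [("human", "Summarize special relativity in one sentence.")],
    [("human", "Name three uses of graph databases.")],
]
for response in llm.batch(prompts):
    print(response.content)
```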
### Tool Calling

Support for function/tool calling with parallel execution, enabling the model to call external functions and tools.

```python
def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """
    Bind tools/functions to the model for tool calling.

    Parameters:
    - tools: List of tools (Pydantic models, functions, or tool definitions)

    Returns:
    A Runnable that can invoke tools.
    """
```

Usage Example:

```python
from pydantic import BaseModel, Field
from langchain_xai import ChatXAI
class GetWeather(BaseModel):
    """Get current weather for a location."""

    location: str = Field(description="City and state, e.g. San Francisco, CA")

class GetPopulation(BaseModel):
    """Get population for a location."""

    location: str = Field(description="City and state, e.g. San Francisco, CA")
llm = ChatXAI(model="grok-4")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
# Model will decide which tools to call
response = llm_with_tools.invoke("Compare weather and population of LA vs NY")
print(response.tool_calls)
# Control tool choice via extra_body
llm_no_tools = ChatXAI(
    model="grok-4",
    extra_body={"tool_choice": "none"},
)
llm_required_tools = ChatXAI(
    model="grok-4",
    extra_body={"tool_choice": "required"},
)
llm_specific_tool = ChatXAI(
    model="grok-4",
    extra_body={
        "tool_choice": {
            "type": "function",
            "function": {"name": "GetWeather"},
        }
    },
)
```
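To complete a round trip, execute the requested tools yourself and return the results as ToolMessage objects. Continuing the example above, a minimal sketch (the get_weather stub is hypothetical):

```python
from langchain_core.messages import HumanMessage, ToolMessage

def get_weather(location: str) -> str:
    # Stand-in for a real weather lookup
    return f"72F and sunny in {location}"

messages = [HumanMessage("What's the weather in San Francisco, CA?")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)

# Run each requested tool and feed the result back to the model
for call in ai_msg.tool_calls:
    if call["name"] == "GetWeather":
        messages.append(ToolMessage(get_weather(**call["args"]), tool_call_id=call["id"]))

final = llm_with_tools.invoke(messages)
print(final.content)
```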
### Structured Output

Generate responses conforming to specified schemas using JSON schema, JSON mode, or function-calling approaches.

```python
def with_structured_output(
    self,
    schema: Optional[Union[Dict[str, Any], Type[BaseModel], Type]] = None,
    *,
    method: Literal["function_calling", "json_mode", "json_schema"] = "function_calling",
    include_raw: bool = False,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
    """
    Configure the model to return structured output matching a schema.

    Parameters:
    - schema: Output schema (Pydantic class, TypedDict, or OpenAI tool schema)
    - method: Steering method ("function_calling", "json_schema", or "json_mode")
    - include_raw: If True, return both raw and parsed responses
    - strict: Whether to enforce strict schema validation

    Returns:
    A Runnable that outputs structured data.
    """
```

Usage Example:

```python
from typing import Optional
from pydantic import BaseModel, Field
from langchain_xai import ChatXAI
class Joke(BaseModel):
    """A joke with setup and punchline."""

    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")
    rating: Optional[int] = Field(description="Funniness rating 1-10")
llm = ChatXAI(model="grok-4")
structured_llm = llm.with_structured_output(Joke)
joke = structured_llm.invoke("Tell me a joke about cats")
print(f"Setup: {joke.setup}")
print(f"Punchline: {joke.punchline}")
print(f"Rating: {joke.rating}")
# Using JSON schema method
json_llm = llm.with_structured_output(
    Joke,
    method="json_schema",
    strict=True,
)

# Including raw response
raw_llm = llm.with_structured_output(
    Joke,
    include_raw=True,
)
result = raw_llm.invoke("Tell me a joke")
# result = {"raw": BaseMessage, "parsed": Joke, "parsing_error": None}
```
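With include_raw=True you can guard against parse failures explicitly; continuing the example above:

```python
result = raw_llm.invoke("Tell me a joke about databases")
if result["parsing_error"] is None:
    joke = result["parsed"]
    print(joke.setup)
else:
    # Fall back to the raw model output when parsing fails
    print(result["raw"].content)
```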
### Live Search

Enable Grok models to ground responses in web search results, with configurable search parameters.

```python
# Configure via search_parameters in the constructor
search_parameters: Optional[Dict[str, Any]] = {
    "mode": str,                # Search mode, e.g. "auto"
    "max_search_results": int,  # Maximum search results to use
    "from_date": str,           # Start date for search (YYYY-MM-DD)
    "to_date": str,             # End date for search (YYYY-MM-DD)
}
```

Usage Example:

```python
from langchain_xai import ChatXAI
# Configure live search
llm = ChatXAI(
    model="grok-4",
    search_parameters={
        "mode": "auto",
        "max_search_results": 5,
        "from_date": "2025-01-01",
        "to_date": "2025-01-02",
    },
)
# Model will use web search to ground response
response = llm.invoke("What are the latest developments in AI research?")
print(response.content)
# Citations are available on Grok 3 models
if hasattr(response, 'additional_kwargs') and 'citations' in response.additional_kwargs:
    citations = response.additional_kwargs['citations']
    print("Sources:", citations)
```

### Reasoning Content

Access reasoning content from supported models (Grok 3) that expose a transparent reasoning process.
Usage Example:

```python
from langchain_xai import ChatXAI

# Configure reasoning effort for Grok 3 models
llm = ChatXAI(
    model="grok-3-mini",
    extra_body={"reasoning_effort": "high"},
)
response = llm.invoke("Solve this logic puzzle: If all cats are animals...")
print("Response:", response.content)
# Access reasoning content
if hasattr(response, 'additional_kwargs') and 'reasoning_content' in response.additional_kwargs:
    reasoning = response.additional_kwargs['reasoning_content']
    print("Reasoning:", reasoning)
# Reasoning also available in streaming
for chunk in llm.stream("Complex math problem"):
    if hasattr(chunk, 'additional_kwargs') and 'reasoning_content' in chunk.additional_kwargs:
        print("Reasoning chunk:", chunk.additional_kwargs['reasoning_content'])
    print(chunk.content, end="")
```

### Log Probabilities

Access token-level probability information for generated responses.
```python
# Enable via the logprobs parameter
logprobs: Optional[bool] = True
```

Usage Example:

```python
from langchain_xai import ChatXAI
# Enable logprobs
llm = ChatXAI(model="grok-4", logprobs=True)
# Or bind logprobs to existing model
logprobs_llm = llm.bind(logprobs=True)
response = logprobs_llm.invoke([("human", "Say Hello World!")])
# Access logprobs from response metadata
logprobs_data = response.response_metadata.get("logprobs")
if logprobs_data:
    print("Tokens:", logprobs_data["tokens"])
    print("Token IDs:", logprobs_data["token_ids"])
    print("Log probabilities:", logprobs_data["token_logprobs"])
```

### Response Metadata

All responses include comprehensive metadata about the generation process.
```python
response_metadata: Dict[str, Any] = {
    "token_usage": {
        "completion_tokens": int,
        "prompt_tokens": int,
        "total_tokens": int,
    },
    "model_name": str,
    "system_fingerprint": Optional[str],
    "finish_reason": str,  # "stop", "length", "tool_calls", etc.
    "logprobs": Optional[Dict],
}

usage_metadata: Dict[str, int] = {
    "input_tokens": int,
    "output_tokens": int,
    "total_tokens": int,
}
```

Usage Example:

```python
from langchain_xai import ChatXAI
llm = ChatXAI(model="grok-4")
response = llm.invoke([("human", "Hello!")])
# Access token usage
print("Usage:", response.usage_metadata)
# {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
# Access detailed metadata
print("Model:", response.response_metadata["model_name"])
print("Finish reason:", response.response_metadata["finish_reason"])
print("Token usage:", response.response_metadata["token_usage"])# Required
XAI_API_KEY = "your-xai-api-key"# Primary configuration options
model: str = "grok-4" # Model name
temperature: float = 1.0 # Sampling temperature (0-2)
max_tokens: Optional[int] = None # Maximum tokens
timeout: Optional[Union[float, Tuple[float, float]]] = None # Request timeout
max_retries: int = 2 # Retry attempts
xai_api_key: Optional[str] = None # API key
xai_api_base: str = "https://api.x.ai/v1/" # API base URL
# Advanced options
search_parameters: Optional[Dict[str, Any]] = None # Live search config
logprobs: Optional[bool] = None # Enable log probabilities
default_headers: Optional[Dict] = None # Default request headers
default_query: Optional[Dict] = None # Default query parameters
http_client: Optional[Any] = None # Custom HTTP client
http_async_client: Optional[Any] = None # Custom async HTTP clientfrom langchain_xai import ChatXAI
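A sketch combining several of these options (the values and header are illustrative):

```python
from langchain_xai import ChatXAI

llm = ChatXAI(
    model="grok-4",
    temperature=0.2,
    max_tokens=512,
    timeout=(5.0, 60.0),  # (connect, read) tuple, per the signature above
    max_retries=3,
    default_headers={"X-Request-Source": "my-app"},  # hypothetical header
)
```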
## Error Handling

```python
from langchain_xai import ChatXAI

try:
    llm = ChatXAI(model="grok-4")
    response = llm.invoke(messages)
except ValueError as e:
    # API key not set or invalid parameters
    print(f"Configuration error: {e}")
except Exception as e:
    # Network, API, or other errors
    print(f"Runtime error: {e}")
```

Common error scenarios:
- ValueError if XAI_API_KEY is not set and no api_key is provided
- ValueError for invalid parameter combinations (e.g., n > 1 with streaming)

## Type Definitions

```python
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Iterator, AsyncIterator, Literal
from pydantic import BaseModel, SecretStr
from langchain_core.language_models import LanguageModelInput
from langchain_core.tools import BaseTool
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.runnables import Runnable
# Message types for input
LanguageModelInput = Union[
    str,
    List[Union[str, Dict[str, Any]]],
    List[BaseMessage],
]
# Response types
BaseMessage # Complete response message
BaseMessageChunk # Streaming response chunk
# Tool definition types
ToolDefinition = Union[
    Dict[str, Any],   # OpenAI tool schema
    Type[BaseModel],  # Pydantic model
    Callable,         # Function
    BaseTool,         # LangChain tool
]
```