Distributed AI inference for PySpark DataFrames.
spark-ai brings model-powered text processing directly into Spark transformations using a simple API.
It is designed for portability and works anywhere Spark runs: local development, EMR, Dataproc, Kubernetes, or on-prem clusters.
- Spark-native sentiment, summarization, and zero-shot text classification APIs
- Vectorized execution with `pandas_udf` for better throughput than row-wise Python UDFs
- Hugging Face Transformers backend
- Null-safe text handling for production pipelines
- Clean package structure for extension with additional AI tasks/backends
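To illustrate the null-safe handling mentioned above, here is a minimal sketch of the kind of null filtering a production inference UDF needs; `infer_null_safe` and `predict_batch` are hypothetical names, not part of the spark-ai API:

```python
from typing import Callable, List, Optional

def infer_null_safe(
    texts: List[Optional[str]],
    predict_batch: Callable[[List[str]], List[str]],
) -> List[Optional[str]]:
    """Run batched inference while passing null/empty texts through as None."""
    # Remember which positions actually contain text.
    valid_idx = [i for i, t in enumerate(texts) if t]
    results: List[Optional[str]] = [None] * len(texts)
    if valid_idx:
        # Only non-empty texts reach the model; nulls never hit the backend.
        preds = predict_batch([texts[i] for i in valid_idx])
        for i, p in zip(valid_idx, preds):
            results[i] = p
    return results
```

For example, `infer_null_safe(["Great!", None, ""], lambda batch: ["POSITIVE"] * len(batch))` returns `["POSITIVE", None, None]`: nulls and empty strings stay null instead of crashing the model call.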
```shell
pip install spark-infer-ai
```

Requirements:

- Python 3.10+
- Apache Spark 3.5+
- Java runtime compatible with your Spark distribution
Core dependencies are installed automatically:
`pyspark`, `pandas`, `pyarrow`, `transformers`, `torch`
```python
from pyspark.sql import SparkSession

from spark_ai import AI

spark = (
    SparkSession.builder
    .appName("spark-ai-demo")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

df = spark.createDataFrame(
    [
        ("I love this product!",),
        ("This is the worst experience ever.",),
    ],
    ["review"],
)

ai = AI()
result = df.withColumn("sentiment", ai.sentiment("review"))
result.show(truncate=False)
```

Expected sentiment labels are typically POSITIVE / NEGATIVE (model-dependent).
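The labels come from the underlying Hugging Face sentiment pipeline, which typically emits a list of `{"label": ..., "score": ...}` dicts per batch. A small sketch of reducing that output shape to bare label strings (`extract_labels` is an illustrative helper, not part of spark-ai):

```python
def extract_labels(pipeline_output):
    """Reduce Hugging Face pipeline output such as
    [{"label": "POSITIVE", "score": 0.998}, ...] to bare label strings."""
    return [item["label"] for item in pipeline_output]

labels = extract_labels([
    {"label": "POSITIVE", "score": 0.998},
    {"label": "NEGATIVE", "score": 0.994},
])
# labels == ["POSITIVE", "NEGATIVE"]
```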
Classify text into custom labels with zero-shot inference:

```python
topic_result = df.withColumn(
    "topic",
    ai.classify("review", labels=["urgent", "complaint", "praise"]),
)
topic_result.show(truncate=False)
```

Summarize long-form text:
```python
summary_result = df.withColumn("summary", ai.summarize("review"))
summary_result.show(truncate=False)
```

`AI` is the primary interface for DataFrame AI transformations.
Applies sentiment analysis to a text column and returns a Spark Column.
```python
result = df.withColumn("sentiment", ai.sentiment("review"))
```

Categorizes free text into your custom labels with zero-shot classification.
```python
result = df.withColumn(
    "topic",
    ai.classify("message", labels=["urgent", "spam", "normal"]),
)
```

Generates a concise summary for long-form text and returns a Spark Column.
```python
result = df.withColumn("summary", ai.summarize("article_text"))
```

spark-ai uses a vectorized pandas UDF and batched Hugging Face inference internally.
For best performance in production:
- Enable Arrow: `spark.sql.execution.arrow.pyspark.enabled=true`
- Tune Spark partitions to match your cluster resources
- Tune `batch_size` for your hardware, or enable `auto_tune_batch_size=True`
- Run benchmarks on representative text lengths and data sizes
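To make the batching tip concrete, here is a simplified sketch of what the body of a vectorized UDF can look like when it feeds the model fixed-size batches; this is not spark-ai's actual implementation, and `predict_batch` is a hypothetical backend call:

```python
import pandas as pd

def sentiment_batches(texts: pd.Series, predict_batch, batch_size: int = 32) -> pd.Series:
    """Vectorized-UDF-style body: chunk the partition's texts into
    fixed-size batches instead of calling the model once per row."""
    values = texts.tolist()
    labels = []
    for start in range(0, len(values), batch_size):
        # Each slice is one model call; larger batch_size means fewer calls.
        labels.extend(predict_batch(values[start:start + batch_size]))
    # Preserve the original index so Spark can align results with input rows.
    return pd.Series(labels, index=texts.index)
```

Wrapped in `pandas_udf` with a string return type, a function shaped like this becomes a Spark column expression; `batch_size` then trades per-call overhead against memory, which is why tuning it (or auto-tuning) matters.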
Model-loading behavior:
- Spark may run multiple Python workers per executor
- Each Python worker keeps its own singleton model instance
- That means model reuse is per worker process, not globally shared across all workers
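The per-worker singleton described above is commonly implemented as module-level lazy initialization, sketched here with a hypothetical `load_pipeline` loader (not spark-ai's actual code):

```python
from typing import Callable

_MODEL = None  # one instance per Python worker process

def get_model(load_pipeline: Callable[[], object]):
    """Load the model once per worker process, then reuse it for every batch."""
    global _MODEL
    if _MODEL is None:
        # Expensive step (reading weights) runs only on the first call
        # in this process; other worker processes each load their own copy.
        _MODEL = load_pipeline()
    return _MODEL
```

All batches handled by the same worker share one instance, but because the cache lives in process memory, workers on different executors (or multiple Python workers on one executor) cannot share it.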
You can use the included benchmark script:
```shell
python examples/benchmark_sentiment.py
```

Example benchmark output:

```
rows=20000
partitions=8
elapsed_seconds=6.383
rows_per_second=3133.1
```
Common Spark startup warnings such as:

- `NativeCodeLoader: Unable to load native-hadoop library...`
- JDK incubator module notices

are typically informational in local environments and do not indicate a failure.
Clone and install in editable mode:
```shell
pip install -e ".[dev]"
```

Run tests:

```shell
pytest -q
```

Project layout:

```
src/spark_ai/
    ai.py        # Public API
    config.py    # Central configuration
    udf/         # Vectorized Spark UDF
    backends/    # Inference backend implementations
tests/unit/      # Unit tests
examples/        # Usage and benchmark scripts
```