Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 24 additions & 213 deletions src/vallm/validators/imports/python_imports.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
"""Python import validation."""
def _is_import_error_handler(handler: ast.ExceptHandler) -> bool:
    """Return True when *handler* would catch a failed import.

    A handler qualifies when it is a bare ``except:``, or when the exception
    it names (directly, or as any element of a tuple) is one of
    ``_IMPORT_ERROR_NAMES`` or a built-in ancestor of ``ImportError``.

    Args:
        handler: One ``except`` clause of a ``try`` statement.

    Returns:
        True if the clause catches import failures, False otherwise. Only
        plain names are recognised; dotted or computed expressions are
        treated as not catching.
    """
    # ``except Exception`` / ``except BaseException`` catch ImportError just
    # like a bare ``except:`` does, so they guard an import equally well.
    catch_all_names = ("Exception", "BaseException")

    def catches_import_failure(expr: ast.expr) -> bool:
        if not isinstance(expr, ast.Name):
            return False
        return expr.id in _IMPORT_ERROR_NAMES or expr.id in catch_all_names

    handler_type = handler.type
    if handler_type is None:
        # Bare ``except:`` catches everything.
        return True

    if isinstance(handler_type, ast.Tuple):
        return any(catches_import_failure(element) for element in handler_type.elts)

    return catches_import_failure(handler_type)

import ast
import importlib.util
from pathlib import Path
from typing import List, Dict, Any, Set, Optional
from vallm.core.proposal import Proposal
from vallm.scoring import Issue, Severity, ValidationResult
from .base import BaseImportValidator

# Exception class names that _is_import_error_handler() accepts as catching a
# failed import when they appear in an ``except`` clause.
_IMPORT_ERROR_NAMES = frozenset(("ImportError", "ModuleNotFoundError"))
def _has_import_error_handler(handlers: list[ast.ExceptHandler]) -> bool:
    """Report whether at least one of *handlers* catches a failed import."""
    for candidate in handlers:
        if _is_import_error_handler(candidate):
            return True
    return False


def _collect_guarded_lines(tree: ast.AST) -> Set[int]:
Expand All @@ -17,209 +25,12 @@ def _collect_guarded_lines(tree: ast.AST) -> Set[int]:
for node in ast.walk(tree):
if not isinstance(node, ast.Try):
continue
catches_import_error = any(
h.type is None
or (isinstance(h.type, ast.Name) and h.type.id in _IMPORT_ERROR_NAMES)
or (
isinstance(h.type, ast.Tuple)
and any(
isinstance(e, ast.Name) and e.id in _IMPORT_ERROR_NAMES
for e in h.type.elts
)
)
for h in node.handlers
)
if catches_import_error:
for stmt in node.body:
for n in ast.walk(stmt):
if isinstance(n, (ast.Import, ast.ImportFrom)):
guarded.add(n.lineno)
return guarded

# Common stdlib/builtin modules that importlib.util.find_spec may not find
# module_exists() reports any top-level name in this set as present without
# consulting importlib or the result cache.
_KNOWN_PYTHON_MODULES = {
"sys", "os", "re", "json", "math", "random", "datetime", "collections",
"functools", "itertools", "pathlib", "typing", "dataclasses", "enum",
"abc", "io", "string", "textwrap", "copy", "pprint", "warnings",
"logging", "unittest", "contextlib", "operator", "hashlib", "hmac",
"secrets", "struct", "time", "calendar", "locale", "decimal", "fractions",
"statistics", "array", "bisect", "heapq", "queue", "types", "weakref",
"inspect", "dis", "traceback", "gc", "argparse", "configparser", "csv",
"sqlite3", "urllib", "http", "email", "html", "xml", "socket", "ssl",
"select", "signal", "subprocess", "threading", "multiprocessing",
"concurrent", "asyncio", "shutil", "tempfile", "glob", "fnmatch",
"pickle", "shelve", "marshal", "dbm", "gzip", "bz2", "lzma", "zipfile",
"tarfile", "zlib", "base64", "binascii", "codecs", "unicodedata",
"difflib", "pdb", "profile", "timeit", "trace", "ast", "token",
"tokenize", "importlib", "pkgutil", "platform", "errno", "ctypes",
}


# Results of module_exists(), keyed by top-level module name.
_module_exists_cache: dict[str, bool] = {}
# Names found by _get_local_modules(); None until its first call scans the cwd.
_local_modules: frozenset[str] | None = None


def _get_local_modules() -> frozenset[str]:
    """Return the names of packages and modules directly inside the cwd.

    The directory is scanned on the first call only; the result is stored in
    the module-level ``_local_modules`` and returned unchanged afterwards.
    A directory counts when it contains ``__init__.py``; a file counts when
    it ends in ``.py`` and is not itself ``__init__.py``.
    """
    global _local_modules
    if _local_modules is not None:
        return _local_modules

    names: set[str] = set()
    for entry in Path.cwd().iterdir():
        if entry.is_dir() and (entry / "__init__.py").exists():
            names.add(entry.name)
            continue
        is_module_file = (
            entry.is_file() and entry.suffix == ".py" and entry.stem != "__init__"
        )
        if is_module_file:
            names.add(entry.stem)

    _local_modules = frozenset(names)
    return _local_modules

if not _has_import_error_handler(node.handlers):
continue

class PythonImportValidator(BaseImportValidator):
    """Python-specific import validator.

    Imports are read from the proposal's AST. Absolute imports are resolved
    by their top-level name via ``module_exists``; relative imports are
    resolved on disk against the proposal's filename. Imports on lines that
    ``_collect_guarded_lines`` reports as guarded are not checked.
    """

    def validate(self, proposal: Proposal, context: dict) -> ValidationResult:
        """Validate Python imports using AST.

        Args:
            proposal: Proposal whose ``code`` is parsed and whose ``filename``
                anchors relative imports.
            context: Not used by this validator.

        Returns:
            The result built by ``create_validation_result`` with one ERROR
            issue per unresolvable import, or a zero-score result carrying a
            single ``python.syntax`` issue when the code cannot be parsed.
        """
        try:
            # Parse only to surface parse failures here: extract_imports()
            # swallows them and would otherwise report "no imports".
            ast.parse(proposal.code)
        except (SyntaxError, ValueError) as e:
            # ValueError is raised instead of SyntaxError for source that
            # contains null bytes on interpreters before 3.12; it carries no
            # line number.
            return ValidationResult(
                validator="imports.python",
                score=0.0,
                weight=self.weight,
                issues=[Issue(
                    message=f"Syntax error: {e}",
                    severity=Severity.ERROR,
                    line=getattr(e, "lineno", None),
                    rule="python.syntax",
                )],
                details={"error": str(e), "language": "python"},
            )

        imports = self.extract_imports(proposal.code)
        issues = []
        for import_info in imports:
            module_name = import_info["module"]
            line = import_info["line"]
            # Plain ``import x`` entries carry no "level" key.
            level = import_info.get("level", 0)

            if level > 0:
                # Relative import - resolve against source file
                if not self._relative_import_exists(module_name, level, proposal.filename):
                    issues.append(Issue(
                        message=f"Relative import '{module_name}' not found",
                        severity=Severity.ERROR,
                        line=line,
                        rule="python.import.relative.resolvable",
                    ))
            elif not self.module_exists(module_name):
                issues.append(Issue(
                    message=self._get_error_message(module_name),
                    severity=Severity.ERROR,
                    line=line,
                    rule=self._get_rule_name(),
                ))

        return self.create_validation_result(
            issues, len(imports), len(imports) - len(issues), "python"
        )

    def _relative_import_exists(self, module_name: str, level: int, filename: Optional[str]) -> bool:
        """Check if a relative import resolves to an existing module.

        Args:
            module_name: Dotted module path written after the leading dots.
            level: Number of leading dots in the import statement.
            filename: Path of the file containing the import, if known.

        Returns:
            True when ``<name>.py`` or ``<name>/__init__.py`` exists at the
            resolved location, or when *filename* is empty (nothing to
            resolve against); False otherwise.
        """
        if not filename:
            # Without filename context, we can't validate relative imports
            return True

        base_path = Path(filename).resolve().parent
        # One dot means the file's own package; each extra dot climbs one level.
        for _ in range(level - 1):
            base_path = base_path.parent

        # Everything before the last component is a chain of sub-packages.
        *package_parts, final_name = module_name.split(".")
        target_path = base_path.joinpath(*package_parts)

        as_module = target_path / f"{final_name}.py"
        as_package = target_path / final_name / "__init__.py"
        return as_module.exists() or as_package.exists()

    def extract_imports(self, code: str) -> List[Dict[str, Any]]:
        """Extract import statements from Python code using AST.

        Args:
            code: Python source text.

        Returns:
            One dict per imported module with keys ``"module"`` and
            ``"line"``; ``from`` imports also carry ``"level"``. Imports on
            guarded lines and ``from . import x`` forms (no module name) are
            omitted. Unparseable code yields an empty list.
        """
        imports: List[Dict[str, Any]] = []
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # Best effort: unparseable source simply has no extractable
            # imports. ValueError covers null bytes on interpreters < 3.12.
            return imports

        guarded = _collect_guarded_lines(tree)
        for node in ast.walk(tree):
            if not isinstance(node, (ast.Import, ast.ImportFrom)):
                continue
            if node.lineno in guarded:
                continue

            if isinstance(node, ast.Import):
                for alias in node.names:
                    imports.append({
                        "module": alias.name,
                        "line": node.lineno,
                    })
            elif node.module:
                # Handle relative imports by storing the level
                imports.append({
                    "module": node.module,
                    "line": node.lineno,
                    "level": node.level,
                })

        return imports

    def module_exists(self, module_name: str) -> bool:
        """Check if a Python module exists in current environment (cached).

        Only the top-level component of *module_name* is checked, in order:
        the known-module set, the result cache, ``importlib.util.find_spec``,
        and finally the local modules found in the working directory.
        """
        top_level = module_name.split(".")[0]
        if top_level in _KNOWN_PYTHON_MODULES:
            return True

        cached = _module_exists_cache.get(top_level)
        if cached is not None:
            return cached

        try:
            found = importlib.util.find_spec(top_level) is not None
        except (ImportError, ValueError):
            # Lookup failures are treated the same as "no spec found".
            found = False

        if not found:
            found = top_level in _get_local_modules()

        _module_exists_cache[top_level] = found
        return found

    def get_language(self) -> str:
        """Get the language identifier."""
        return "python"

    def _get_error_message(self, module_name: str) -> str:
        """Get error message for missing module."""
        return f"Module '{module_name}' not found"

    def _get_rule_name(self) -> str:
        """Get rule name for validation errors."""
        return "python.import.resolvable"
for stmt in node.body:
for n in ast.walk(stmt):
if isinstance(n, (ast.Import, ast.ImportFrom)):
guarded.add(n.lineno)
return guarded
Loading