feat: smart integration detection with package classifier
- Add PackageClassifier with built-in dictionary (~200 popular packages) - Hardcode Python 3.10+ stdlib list to filter out standard library imports - Add PyPI API lookup for unknown packages (online mode, 3s timeout) - Cache PyPI results in .wtismycode/cache/pypi.json - Add --offline flag to skip PyPI lookups - Classify packages into: HTTP, Database, Queue, Storage, AI/ML, Auth, Testing, Logging, Internal, Third-party - User config integration_patterns override auto-detection - Update renderer to show integrations grouped by category - Update ARCHITECTURE.md template with new integration format
This commit is contained in:
@@ -16,3 +16,5 @@ rustpython-parser = "0.4"
|
||||
rustpython-ast = "0.4"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
tempfile = "3.10"
|
||||
ureq = "3"
|
||||
lazy_static = "1.4"
|
||||
|
||||
@@ -13,6 +13,7 @@ pub mod renderer;
|
||||
pub mod writer;
|
||||
pub mod cache;
|
||||
pub mod cycle_detector;
|
||||
pub mod package_classifier;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use errors::WTIsMyCodeError;
|
||||
|
||||
@@ -12,6 +12,9 @@ pub struct ProjectModel {
|
||||
pub files: HashMap<String, FileDoc>,
|
||||
pub symbols: HashMap<String, Symbol>,
|
||||
pub edges: Edges,
|
||||
/// Classified integrations by category (e.g. "HTTP" -> ["fastapi", "requests"])
|
||||
#[serde(default)]
|
||||
pub classified_integrations: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl ProjectModel {
|
||||
@@ -21,6 +24,7 @@ impl ProjectModel {
|
||||
files: HashMap::new(),
|
||||
symbols: HashMap::new(),
|
||||
edges: Edges::new(),
|
||||
classified_integrations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
452
wtismycode-core/src/package_classifier.rs
Normal file
452
wtismycode-core/src/package_classifier.rs
Normal file
@@ -0,0 +1,452 @@
|
||||
//! Package classifier for Python imports
|
||||
//!
|
||||
//! Classifies Python packages into categories using:
|
||||
//! 1. Python stdlib list (hardcoded)
|
||||
//! 2. Built-in dictionary (~200 popular packages)
|
||||
//! 3. PyPI API lookup (online mode)
|
||||
//! 4. Internal package detection (fallback)
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
pub enum PackageCategory {
|
||||
Stdlib,
|
||||
Http,
|
||||
Database,
|
||||
Queue,
|
||||
Storage,
|
||||
AiMl,
|
||||
Testing,
|
||||
Logging,
|
||||
Auth,
|
||||
Internal,
|
||||
ThirdParty,
|
||||
}
|
||||
|
||||
impl PackageCategory {
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Stdlib => "Stdlib",
|
||||
Self::Http => "HTTP",
|
||||
Self::Database => "Database",
|
||||
Self::Queue => "Queue",
|
||||
Self::Storage => "Storage",
|
||||
Self::AiMl => "AI/ML",
|
||||
Self::Testing => "Testing",
|
||||
Self::Logging => "Logging",
|
||||
Self::Auth => "Auth",
|
||||
Self::Internal => "Internal",
|
||||
Self::ThirdParty => "Third-party",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of classifying all imports in a project
|
||||
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ClassifiedIntegrations {
|
||||
/// category -> list of package names
|
||||
pub by_category: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
pub struct PackageClassifier {
|
||||
offline: bool,
|
||||
cache_dir: Option<String>,
|
||||
/// user overrides from config integration_patterns
|
||||
user_overrides: HashMap<String, PackageCategory>,
|
||||
/// PyPI cache: package_name -> Option<PackageCategory> (None = not found)
|
||||
pypi_cache: HashMap<String, Option<PackageCategory>>,
|
||||
}
|
||||
|
||||
impl PackageClassifier {
|
||||
pub fn new(offline: bool, cache_dir: Option<String>) -> Self {
|
||||
let mut classifier = Self {
|
||||
offline,
|
||||
cache_dir: cache_dir.clone(),
|
||||
user_overrides: HashMap::new(),
|
||||
pypi_cache: HashMap::new(),
|
||||
};
|
||||
// Load PyPI cache from disk
|
||||
if let Some(ref dir) = cache_dir {
|
||||
classifier.load_pypi_cache(dir);
|
||||
}
|
||||
classifier
|
||||
}
|
||||
|
||||
/// Add user overrides from config integration_patterns
|
||||
pub fn add_user_overrides(&mut self, patterns: &[(String, Vec<String>)]) {
|
||||
for (type_name, pkgs) in patterns {
|
||||
let cat = match type_name.as_str() {
|
||||
"http" => PackageCategory::Http,
|
||||
"db" => PackageCategory::Database,
|
||||
"queue" => PackageCategory::Queue,
|
||||
"storage" => PackageCategory::Storage,
|
||||
"ai" => PackageCategory::AiMl,
|
||||
"testing" => PackageCategory::Testing,
|
||||
"logging" => PackageCategory::Logging,
|
||||
"auth" => PackageCategory::Auth,
|
||||
_ => PackageCategory::ThirdParty,
|
||||
};
|
||||
for pkg in pkgs {
|
||||
self.user_overrides.insert(pkg.to_lowercase(), cat.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a single package name (top-level import)
|
||||
pub fn classify(&mut self, package_name: &str) -> PackageCategory {
|
||||
let normalized = normalize_package_name(package_name);
|
||||
|
||||
// 1. User overrides take priority
|
||||
if let Some(cat) = self.user_overrides.get(&normalized) {
|
||||
return cat.clone();
|
||||
}
|
||||
|
||||
// 2. Built-in dictionary (check BEFORE stdlib, so sqlite3 etc. are categorized properly)
|
||||
if let Some(cat) = builtin_lookup(&normalized) {
|
||||
return cat;
|
||||
}
|
||||
|
||||
// 3. Stdlib
|
||||
if is_stdlib(&normalized) {
|
||||
return PackageCategory::Stdlib;
|
||||
}
|
||||
|
||||
// 4. PyPI lookup (if online)
|
||||
if !self.offline {
|
||||
if let Some(cached) = self.pypi_cache.get(&normalized) {
|
||||
return cached.clone().unwrap_or(PackageCategory::Internal);
|
||||
}
|
||||
match self.pypi_lookup(&normalized) {
|
||||
Some(cat) => {
|
||||
self.pypi_cache.insert(normalized, Some(cat.clone()));
|
||||
return cat;
|
||||
}
|
||||
None => {
|
||||
self.pypi_cache.insert(normalized, None);
|
||||
return PackageCategory::Internal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Offline fallback: if not in stdlib or dictionary, assume internal
|
||||
PackageCategory::Internal
|
||||
}
|
||||
|
||||
/// Classify all imports and return grouped integrations
|
||||
pub fn classify_all(&mut self, import_names: &[String]) -> ClassifiedIntegrations {
|
||||
let mut result = ClassifiedIntegrations::default();
|
||||
let mut seen: HashMap<String, PackageCategory> = HashMap::new();
|
||||
|
||||
for import in import_names {
|
||||
let top_level = top_level_package(import);
|
||||
if seen.contains_key(&top_level) {
|
||||
continue;
|
||||
}
|
||||
let cat = self.classify(&top_level);
|
||||
seen.insert(top_level.clone(), cat.clone());
|
||||
|
||||
// Skip stdlib and third-party without category
|
||||
if cat == PackageCategory::Stdlib {
|
||||
continue;
|
||||
}
|
||||
|
||||
let category_name = cat.display_name().to_string();
|
||||
result.by_category
|
||||
.entry(category_name)
|
||||
.or_default()
|
||||
.push(top_level);
|
||||
}
|
||||
|
||||
// Deduplicate and sort each category
|
||||
for pkgs in result.by_category.values_mut() {
|
||||
pkgs.sort();
|
||||
pkgs.dedup();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Save PyPI cache to disk
|
||||
pub fn save_cache(&self) {
|
||||
if let Some(ref dir) = self.cache_dir {
|
||||
let cache_path = Path::new(dir).join("pypi.json");
|
||||
if let Ok(json) = serde_json::to_string_pretty(&self.pypi_cache) {
|
||||
let _ = std::fs::create_dir_all(dir);
|
||||
let _ = std::fs::write(&cache_path, json);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_pypi_cache(&mut self, dir: &str) {
|
||||
let cache_path = Path::new(dir).join("pypi.json");
|
||||
if let Ok(content) = std::fs::read_to_string(&cache_path) {
|
||||
if let Ok(cache) = serde_json::from_str::<HashMap<String, Option<PackageCategory>>>(&content) {
|
||||
self.pypi_cache = cache;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pypi_lookup(&self, package_name: &str) -> Option<PackageCategory> {
|
||||
let url = format!("https://pypi.org/pypi/{}/json", package_name);
|
||||
|
||||
let agent = ureq::Agent::new_with_config(
|
||||
ureq::config::Config::builder()
|
||||
.timeout_global(Some(std::time::Duration::from_secs(3)))
|
||||
.build()
|
||||
);
|
||||
|
||||
let response = agent.get(&url).call().ok()?;
|
||||
|
||||
if response.status() != 200 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let body_str = response.into_body().read_to_string().ok()?;
|
||||
let body: serde_json::Value = serde_json::from_str(&body_str).ok()?;
|
||||
let info = body.get("info")?;
|
||||
|
||||
// Check classifiers
|
||||
if let Some(classifiers) = info.get("classifiers").and_then(|c: &serde_json::Value| c.as_array()) {
|
||||
for classifier in classifiers {
|
||||
if let Some(s) = classifier.as_str() {
|
||||
if let Some(cat) = classify_from_pypi_classifier(s) {
|
||||
return Some(cat);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check summary and keywords for hints
|
||||
let summary = info.get("summary").and_then(|s: &serde_json::Value| s.as_str()).unwrap_or("");
|
||||
let keywords = info.get("keywords").and_then(|s: &serde_json::Value| s.as_str()).unwrap_or("");
|
||||
let combined = format!("{} {}", summary, keywords).to_lowercase();
|
||||
|
||||
if combined.contains("database") || combined.contains("sql") || combined.contains("orm") {
|
||||
return Some(PackageCategory::Database);
|
||||
}
|
||||
if combined.contains("http") || combined.contains("web framework") || combined.contains("rest api") {
|
||||
return Some(PackageCategory::Http);
|
||||
}
|
||||
if combined.contains("queue") || combined.contains("message broker") || combined.contains("amqp") || combined.contains("kafka") {
|
||||
return Some(PackageCategory::Queue);
|
||||
}
|
||||
if combined.contains("storage") || combined.contains("s3") || combined.contains("blob") {
|
||||
return Some(PackageCategory::Storage);
|
||||
}
|
||||
if combined.contains("machine learning") || combined.contains("deep learning") || combined.contains("neural") || combined.contains("artificial intelligence") {
|
||||
return Some(PackageCategory::AiMl);
|
||||
}
|
||||
if combined.contains("testing") || combined.contains("test framework") {
|
||||
return Some(PackageCategory::Testing);
|
||||
}
|
||||
if combined.contains("logging") || combined.contains("error tracking") {
|
||||
return Some(PackageCategory::Logging);
|
||||
}
|
||||
if combined.contains("authentication") || combined.contains("jwt") || combined.contains("oauth") {
|
||||
return Some(PackageCategory::Auth);
|
||||
}
|
||||
|
||||
// Found on PyPI but no category detected
|
||||
Some(PackageCategory::ThirdParty)
|
||||
}
|
||||
}
|
||||
|
||||
fn classify_from_pypi_classifier(classifier: &str) -> Option<PackageCategory> {
|
||||
let c = classifier.to_lowercase();
|
||||
if c.contains("framework :: django") || c.contains("framework :: flask") ||
|
||||
c.contains("framework :: fastapi") || c.contains("framework :: tornado") ||
|
||||
c.contains("framework :: aiohttp") || c.contains("topic :: internet :: www") {
|
||||
return Some(PackageCategory::Http);
|
||||
}
|
||||
if c.contains("topic :: database") {
|
||||
return Some(PackageCategory::Database);
|
||||
}
|
||||
if c.contains("topic :: scientific/engineering :: artificial intelligence") ||
|
||||
c.contains("topic :: scientific/engineering :: machine learning") {
|
||||
return Some(PackageCategory::AiMl);
|
||||
}
|
||||
if c.contains("topic :: software development :: testing") {
|
||||
return Some(PackageCategory::Testing);
|
||||
}
|
||||
if c.contains("topic :: system :: logging") {
|
||||
return Some(PackageCategory::Logging);
|
||||
}
|
||||
if c.contains("topic :: security") && (classifier.contains("auth") || classifier.contains("Auth")) {
|
||||
return Some(PackageCategory::Auth);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract top-level package name from an import string
|
||||
/// e.g. "sqlalchemy.orm.Session" -> "sqlalchemy"
|
||||
fn top_level_package(import: &str) -> String {
|
||||
import.split('.').next().unwrap_or(import).to_lowercase()
|
||||
}
|
||||
|
||||
/// Normalize package name for lookup (lowercase, replace hyphens with underscores)
|
||||
fn normalize_package_name(name: &str) -> String {
|
||||
name.to_lowercase().replace('-', "_")
|
||||
}
|
||||
|
||||
/// Check if a package is in the Python standard library
|
||||
fn is_stdlib(name: &str) -> bool {
|
||||
PYTHON_STDLIB.contains(&name)
|
||||
}
|
||||
|
||||
/// Look up a package in the built-in dictionary
|
||||
fn builtin_lookup(name: &str) -> Option<PackageCategory> {
|
||||
for (cat, pkgs) in BUILTIN_PACKAGES.iter() {
|
||||
if pkgs.contains(&name) {
|
||||
return Some(cat.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Python 3.10+ standard library modules
|
||||
const PYTHON_STDLIB: &[&str] = &[
|
||||
"__future__", "_thread", "abc", "aifc", "argparse", "array", "ast",
|
||||
"asynchat", "asyncio", "asyncore", "atexit", "audioop", "base64",
|
||||
"bdb", "binascii", "binhex", "bisect", "builtins", "bz2",
|
||||
"calendar", "cgi", "cgitb", "chunk", "cmath", "cmd", "code",
|
||||
"codecs", "codeop", "collections", "colorsys", "compileall",
|
||||
"concurrent", "configparser", "contextlib", "contextvars", "copy",
|
||||
"copyreg", "cprofile", "crypt", "csv", "ctypes", "curses",
|
||||
"dataclasses", "datetime", "dbm", "decimal", "difflib", "dis",
|
||||
"distutils", "doctest", "email", "encodings", "enum", "errno",
|
||||
"faulthandler", "fcntl", "filecmp", "fileinput", "fnmatch",
|
||||
"formatter", "fractions", "ftplib", "functools", "gc", "getopt",
|
||||
"getpass", "gettext", "glob", "grp", "gzip", "hashlib", "heapq",
|
||||
"hmac", "html", "http", "idlelib", "imaplib", "imghdr", "imp",
|
||||
"importlib", "inspect", "io", "ipaddress", "itertools", "json",
|
||||
"keyword", "lib2to3", "linecache", "locale", "logging", "lzma",
|
||||
"mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap",
|
||||
"modulefinder", "multiprocessing", "netrc", "nis", "nntplib",
|
||||
"numbers", "operator", "optparse", "os", "ossaudiodev", "parser",
|
||||
"pathlib", "pdb", "pickle", "pickletools", "pipes", "pkgutil",
|
||||
"platform", "plistlib", "poplib", "posix", "posixpath", "pprint",
|
||||
"profile", "pstats", "pty", "pwd", "py_compile", "pyclbr",
|
||||
"pydoc", "queue", "quopri", "random", "re", "readline", "reprlib",
|
||||
"resource", "rlcompleter", "runpy", "sched", "secrets", "select",
|
||||
"selectors", "shelve", "shlex", "shutil", "signal", "site",
|
||||
"smtpd", "smtplib", "sndhdr", "socket", "socketserver", "spwd",
|
||||
"sqlite3", "ssl", "stat", "statistics", "string", "stringprep",
|
||||
"struct", "subprocess", "sunau", "symtable", "sys", "sysconfig",
|
||||
"syslog", "tabnanny", "tarfile", "telnetlib", "tempfile", "termios",
|
||||
"test", "textwrap", "threading", "time", "timeit", "tkinter",
|
||||
"token", "tokenize", "tomllib", "trace", "traceback", "tracemalloc",
|
||||
"tty", "turtle", "turtledemo", "types", "typing", "unicodedata",
|
||||
"unittest", "urllib", "uu", "uuid", "venv", "warnings", "wave",
|
||||
"weakref", "webbrowser", "winreg", "winsound", "wsgiref", "xdrlib",
|
||||
"xml", "xmlrpc", "zipapp", "zipfile", "zipimport", "zlib",
|
||||
// Common sub-packages that appear as top-level imports
|
||||
"os.path", "collections.abc", "concurrent.futures", "typing_extensions",
|
||||
];
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref BUILTIN_PACKAGES: Vec<(PackageCategory, Vec<&'static str>)> = vec![
|
||||
(PackageCategory::Http, vec![
|
||||
"requests", "httpx", "aiohttp", "fastapi", "flask", "django",
|
||||
"starlette", "uvicorn", "gunicorn", "tornado", "sanic", "bottle",
|
||||
"falcon", "quart", "werkzeug", "httptools", "uvloop", "hypercorn",
|
||||
"grpcio", "grpc", "graphene", "strawberry", "ariadne",
|
||||
"pydantic", "marshmallow", "connexion", "responder", "hug",
|
||||
]),
|
||||
(PackageCategory::Database, vec![
|
||||
"sqlalchemy", "psycopg2", "psycopg", "asyncpg", "pymongo",
|
||||
"mongoengine", "peewee", "tortoise", "databases",
|
||||
"alembic", "pymysql", "opensearch", "elasticsearch",
|
||||
"motor", "beanie", "odmantic", "sqlmodel",
|
||||
"piccolo", "edgedb", "cassandra", "clickhouse_driver", "sqlite3",
|
||||
"neo4j", "arango", "influxdb", "timescaledb",
|
||||
]),
|
||||
(PackageCategory::Queue, vec![
|
||||
"celery", "pika", "aio_pika", "kafka", "confluent_kafka",
|
||||
"kombu", "dramatiq", "huey", "rq", "nats", "redis", "aioredis",
|
||||
"aiokafka", "taskiq", "arq",
|
||||
]),
|
||||
(PackageCategory::Storage, vec![
|
||||
"minio", "boto3", "botocore", "google.cloud.storage",
|
||||
"azure.storage.blob", "s3fs", "fsspec", "smart_open",
|
||||
]),
|
||||
(PackageCategory::AiMl, vec![
|
||||
"torch", "tensorflow", "transformers", "langchain",
|
||||
"langchain_core", "langchain_openai", "langchain_community",
|
||||
"openai", "anthropic", "scikit_learn", "sklearn",
|
||||
"numpy", "pandas", "scipy", "matplotlib", "keras",
|
||||
"whisper", "sentence_transformers", "qdrant_client",
|
||||
"chromadb", "pinecone", "faiss", "xgboost", "lightgbm",
|
||||
"catboost", "spacy", "nltk", "gensim", "huggingface_hub",
|
||||
"diffusers", "accelerate", "datasets", "tokenizers",
|
||||
"tiktoken", "llama_index", "autogen", "crewai",
|
||||
"seaborn", "plotly", "bokeh",
|
||||
]),
|
||||
(PackageCategory::Testing, vec![
|
||||
"pytest", "mock", "faker", "hypothesis",
|
||||
"factory_boy", "factory", "responses", "httpretty",
|
||||
"vcrpy", "freezegun", "time_machine", "pytest_asyncio",
|
||||
"pytest_mock", "pytest_cov", "coverage", "tox", "nox",
|
||||
"behave", "robot", "selenium", "playwright", "locust",
|
||||
]),
|
||||
(PackageCategory::Auth, vec![
|
||||
"pyjwt", "jwt", "python_jose", "jose", "passlib",
|
||||
"authlib", "oauthlib", "itsdangerous", "bcrypt",
|
||||
"cryptography", "paramiko",
|
||||
]),
|
||||
(PackageCategory::Logging, vec![
|
||||
"loguru", "structlog", "sentry_sdk", "watchtower",
|
||||
"python_json_logger", "colorlog", "rich",
|
||||
]),
|
||||
];
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stdlib_detection() {
|
||||
assert!(is_stdlib("os"));
|
||||
assert!(is_stdlib("sys"));
|
||||
assert!(is_stdlib("json"));
|
||||
assert!(is_stdlib("asyncio"));
|
||||
assert!(!is_stdlib("requests"));
|
||||
assert!(!is_stdlib("fastapi"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_lookup() {
|
||||
assert_eq!(builtin_lookup("requests"), Some(PackageCategory::Http));
|
||||
assert_eq!(builtin_lookup("sqlalchemy"), Some(PackageCategory::Database));
|
||||
assert_eq!(builtin_lookup("celery"), Some(PackageCategory::Queue));
|
||||
assert_eq!(builtin_lookup("minio"), Some(PackageCategory::Storage));
|
||||
assert_eq!(builtin_lookup("torch"), Some(PackageCategory::AiMl));
|
||||
assert_eq!(builtin_lookup("pytest"), Some(PackageCategory::Testing));
|
||||
assert_eq!(builtin_lookup("loguru"), Some(PackageCategory::Logging));
|
||||
assert_eq!(builtin_lookup("pyjwt"), Some(PackageCategory::Auth));
|
||||
assert_eq!(builtin_lookup("nonexistent_pkg"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_level_package() {
|
||||
assert_eq!(top_level_package("sqlalchemy.orm.Session"), "sqlalchemy");
|
||||
assert_eq!(top_level_package("os.path"), "os");
|
||||
assert_eq!(top_level_package("requests"), "requests");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_package_name() {
|
||||
assert_eq!(normalize_package_name("aio-pika"), "aio_pika");
|
||||
assert_eq!(normalize_package_name("scikit-learn"), "scikit_learn");
|
||||
assert_eq!(normalize_package_name("FastAPI"), "fastapi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classify_offline() {
|
||||
let mut classifier = PackageClassifier::new(true, None);
|
||||
assert_eq!(classifier.classify("os"), PackageCategory::Stdlib);
|
||||
assert_eq!(classifier.classify("requests"), PackageCategory::Http);
|
||||
assert_eq!(classifier.classify("my_internal_pkg"), PackageCategory::Internal);
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,18 @@ use rustpython_ast::{Stmt, Expr, Ranged};
|
||||
pub struct PythonAnalyzer {
|
||||
config: Config,
|
||||
cache_manager: CacheManager,
|
||||
offline: bool,
|
||||
}
|
||||
|
||||
impl PythonAnalyzer {
|
||||
pub fn new(config: Config) -> Self {
|
||||
let cache_manager = CacheManager::new(config.clone());
|
||||
Self { config, cache_manager }
|
||||
Self { config, cache_manager, offline: false }
|
||||
}
|
||||
|
||||
pub fn new_with_options(config: Config, offline: bool) -> Self {
|
||||
let cache_manager = CacheManager::new(config.clone());
|
||||
Self { config, cache_manager, offline }
|
||||
}
|
||||
|
||||
pub fn parse_module(&self, file_path: &Path) -> Result<ParsedModule, WTIsMyCodeError> {
|
||||
@@ -678,7 +684,72 @@ impl PythonAnalyzer {
|
||||
self.build_dependency_graphs(&mut project_model, modules)?;
|
||||
self.resolve_call_types(&mut project_model, modules, &import_aliases);
|
||||
self.compute_metrics(&mut project_model)?;
|
||||
|
||||
|
||||
// Classify all imports using PackageClassifier
|
||||
let all_imports: Vec<String> = modules.iter()
|
||||
.flat_map(|m| m.imports.iter().map(|i| i.module_name.clone()))
|
||||
.collect();
|
||||
|
||||
let cache_dir = if self.config.caching.enabled {
|
||||
Some(self.config.caching.cache_dir.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut classifier = crate::package_classifier::PackageClassifier::new(self.offline, cache_dir);
|
||||
|
||||
// Add user overrides from config integration_patterns
|
||||
if !self.config.analysis.integration_patterns.is_empty() {
|
||||
let overrides: Vec<(String, Vec<String>)> = self.config.analysis.integration_patterns.iter()
|
||||
.map(|p| (p.type_.clone(), p.patterns.clone()))
|
||||
.collect();
|
||||
classifier.add_user_overrides(&overrides);
|
||||
}
|
||||
|
||||
let classified = classifier.classify_all(&all_imports);
|
||||
classifier.save_cache();
|
||||
|
||||
project_model.classified_integrations = classified.by_category;
|
||||
|
||||
// Also update per-symbol integration flags based on classification
|
||||
for parsed_module in modules {
|
||||
let module_id = self.compute_module_path(&parsed_module.path);
|
||||
let import_names: Vec<String> = parsed_module.imports.iter()
|
||||
.map(|i| i.module_name.clone())
|
||||
.collect();
|
||||
|
||||
let mut flags = crate::model::IntegrationFlags {
|
||||
http: false, db: false, queue: false, storage: false, ai: false,
|
||||
};
|
||||
|
||||
for import in &import_names {
|
||||
let top = import.split('.').next().unwrap_or(import).to_lowercase().replace('-', "_");
|
||||
{
|
||||
let cat = crate::package_classifier::PackageClassifier::new(true, None).classify(&top);
|
||||
match cat {
|
||||
crate::package_classifier::PackageCategory::Http => flags.http = true,
|
||||
crate::package_classifier::PackageCategory::Database => flags.db = true,
|
||||
crate::package_classifier::PackageCategory::Queue => flags.queue = true,
|
||||
crate::package_classifier::PackageCategory::Storage => flags.storage = true,
|
||||
crate::package_classifier::PackageCategory::AiMl => flags.ai = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply to all symbols in this module
|
||||
if let Some(module) = project_model.modules.get(&module_id) {
|
||||
for sym_id in &module.symbols {
|
||||
if let Some(sym) = project_model.symbols.get_mut(sym_id) {
|
||||
sym.integrations_flags.http |= flags.http;
|
||||
sym.integrations_flags.db |= flags.db;
|
||||
sym.integrations_flags.queue |= flags.queue;
|
||||
sym.integrations_flags.storage |= flags.storage;
|
||||
sym.integrations_flags.ai |= flags.ai;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(project_model)
|
||||
}
|
||||
|
||||
|
||||
@@ -73,29 +73,12 @@ impl Renderer {
|
||||
<!-- ARCHDOC:BEGIN section=integrations -->
|
||||
> Generated. Do not edit inside this block.
|
||||
|
||||
### Database Integrations
|
||||
{{#each db_integrations}}
|
||||
{{#each integration_sections}}
|
||||
### {{{category}}}
|
||||
{{#each packages}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### HTTP/API Integrations
|
||||
{{#each http_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### Queue Integrations
|
||||
{{#each queue_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### Storage Integrations
|
||||
{{#each storage_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### AI/ML Integrations
|
||||
{{#each ai_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
<!-- ARCHDOC:END section=integrations -->
|
||||
|
||||
@@ -258,28 +241,17 @@ impl Renderer {
|
||||
}
|
||||
|
||||
pub fn render_architecture_md(&self, model: &ProjectModel, config: Option<&Config>) -> Result<String, anyhow::Error> {
|
||||
// Collect integration information
|
||||
let mut db_integrations = Vec::new();
|
||||
let mut http_integrations = Vec::new();
|
||||
let mut queue_integrations = Vec::new();
|
||||
let mut storage_integrations = Vec::new();
|
||||
let mut ai_integrations = Vec::new();
|
||||
|
||||
for (symbol_id, symbol) in &model.symbols {
|
||||
if symbol.integrations_flags.db {
|
||||
db_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.http {
|
||||
http_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.queue {
|
||||
queue_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.storage {
|
||||
storage_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.ai {
|
||||
ai_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
// Build integration sections from classified_integrations
|
||||
let category_order = ["HTTP", "Database", "Queue", "Storage", "AI/ML", "Auth", "Testing", "Logging", "Internal", "Third-party"];
|
||||
let mut integration_sections: Vec<serde_json::Value> = Vec::new();
|
||||
for cat_name in &category_order {
|
||||
if let Some(pkgs) = model.classified_integrations.get(*cat_name) {
|
||||
if !pkgs.is_empty() {
|
||||
integration_sections.push(serde_json::json!({
|
||||
"category": cat_name,
|
||||
"packages": pkgs,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,11 +391,7 @@ impl Renderer {
|
||||
"key_decisions": ["<FILL_MANUALLY>"],
|
||||
"non_goals": ["<FILL_MANUALLY>"],
|
||||
"change_notes": ["<FILL_MANUALLY>"],
|
||||
"db_integrations": db_integrations,
|
||||
"http_integrations": http_integrations,
|
||||
"queue_integrations": queue_integrations,
|
||||
"storage_integrations": storage_integrations,
|
||||
"ai_integrations": ai_integrations,
|
||||
"integration_sections": integration_sections,
|
||||
"rails_summary": "\n\nNo tooling information available.\n",
|
||||
"layout_items": layout_items,
|
||||
"modules": modules_list,
|
||||
@@ -579,66 +547,31 @@ impl Renderer {
|
||||
}
|
||||
|
||||
pub fn render_integrations_section(&self, model: &ProjectModel) -> Result<String, anyhow::Error> {
|
||||
// Collect integration information
|
||||
let mut db_integrations = Vec::new();
|
||||
let mut http_integrations = Vec::new();
|
||||
let mut queue_integrations = Vec::new();
|
||||
let mut storage_integrations = Vec::new();
|
||||
let mut ai_integrations = Vec::new();
|
||||
|
||||
for (symbol_id, symbol) in &model.symbols {
|
||||
if symbol.integrations_flags.db {
|
||||
db_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.http {
|
||||
http_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.queue {
|
||||
queue_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.storage {
|
||||
storage_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
}
|
||||
if symbol.integrations_flags.ai {
|
||||
ai_integrations.push(format!("{} in {}", symbol_id, symbol.file_id));
|
||||
let category_order = ["HTTP", "Database", "Queue", "Storage", "AI/ML", "Auth", "Testing", "Logging", "Internal", "Third-party"];
|
||||
let mut integration_sections: Vec<serde_json::Value> = Vec::new();
|
||||
for cat_name in &category_order {
|
||||
if let Some(pkgs) = model.classified_integrations.get(*cat_name) {
|
||||
if !pkgs.is_empty() {
|
||||
integration_sections.push(serde_json::json!({
|
||||
"category": cat_name,
|
||||
"packages": pkgs,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare data for integrations section
|
||||
let data = serde_json::json!({
|
||||
"db_integrations": db_integrations,
|
||||
"http_integrations": http_integrations,
|
||||
"queue_integrations": queue_integrations,
|
||||
"storage_integrations": storage_integrations,
|
||||
"ai_integrations": ai_integrations,
|
||||
"integration_sections": integration_sections,
|
||||
});
|
||||
|
||||
// Create a smaller template just for the integrations section
|
||||
let integrations_template = r#"
|
||||
|
||||
### Database Integrations
|
||||
{{#each db_integrations}}
|
||||
{{#each integration_sections}}
|
||||
### {{{category}}}
|
||||
{{#each packages}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### HTTP/API Integrations
|
||||
{{#each http_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### Queue Integrations
|
||||
{{#each queue_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### Storage Integrations
|
||||
{{#each storage_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
|
||||
### AI/ML Integrations
|
||||
{{#each ai_integrations}}
|
||||
- {{{this}}}
|
||||
{{/each}}
|
||||
"#;
|
||||
|
||||
|
||||
@@ -33,9 +33,11 @@ fn test_project_analysis() {
|
||||
// Check that we found calls
|
||||
assert!(!core_module.calls.is_empty());
|
||||
|
||||
// Check that integrations are detected
|
||||
let db_integration_found = core_module.symbols.iter().any(|s| s.integrations_flags.db);
|
||||
let http_integration_found = core_module.symbols.iter().any(|s| s.integrations_flags.http);
|
||||
// Integration flags are now set during resolve_symbols, not parse_module
|
||||
// So we resolve and check there
|
||||
let project_model = analyzer.resolve_symbols(&[core_module.clone()]).unwrap();
|
||||
let db_integration_found = project_model.symbols.values().any(|s| s.integrations_flags.db);
|
||||
let http_integration_found = project_model.symbols.values().any(|s| s.integrations_flags.http);
|
||||
|
||||
assert!(db_integration_found, "Database integration should be detected");
|
||||
assert!(http_integration_found, "HTTP integration should be detected");
|
||||
|
||||
@@ -1,89 +1,36 @@
|
||||
//! Tests for the renderer functionality
|
||||
|
||||
use wtismycode_core::{
|
||||
model::{ProjectModel, Symbol, SymbolKind, IntegrationFlags, SymbolMetrics},
|
||||
model::ProjectModel,
|
||||
renderer::Renderer,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_render_with_integrations() {
|
||||
// Create a mock project model with integration information
|
||||
let mut project_model = ProjectModel::new();
|
||||
|
||||
// Add a symbol with database integration
|
||||
let db_symbol = Symbol {
|
||||
id: "DatabaseManager".to_string(),
|
||||
kind: SymbolKind::Class,
|
||||
module_id: "test_module".to_string(),
|
||||
file_id: "test_file.py".to_string(),
|
||||
qualname: "DatabaseManager".to_string(),
|
||||
signature: "class DatabaseManager".to_string(),
|
||||
annotations: None,
|
||||
docstring_first_line: None,
|
||||
purpose: "test".to_string(),
|
||||
outbound_calls: vec![],
|
||||
inbound_calls: vec![],
|
||||
integrations_flags: IntegrationFlags {
|
||||
db: true,
|
||||
http: false,
|
||||
queue: false,
|
||||
storage: false,
|
||||
ai: false,
|
||||
},
|
||||
metrics: SymbolMetrics {
|
||||
fan_in: 0,
|
||||
fan_out: 0,
|
||||
is_critical: false,
|
||||
cycle_participant: false,
|
||||
},
|
||||
};
|
||||
|
||||
// Add a symbol with HTTP integration
|
||||
let http_symbol = Symbol {
|
||||
id: "fetch_data".to_string(),
|
||||
kind: SymbolKind::Function,
|
||||
module_id: "test_module".to_string(),
|
||||
file_id: "test_file.py".to_string(),
|
||||
qualname: "fetch_data".to_string(),
|
||||
signature: "def fetch_data()".to_string(),
|
||||
annotations: None,
|
||||
docstring_first_line: None,
|
||||
purpose: "test".to_string(),
|
||||
outbound_calls: vec![],
|
||||
inbound_calls: vec![],
|
||||
integrations_flags: IntegrationFlags {
|
||||
db: false,
|
||||
http: true,
|
||||
queue: false,
|
||||
storage: false,
|
||||
ai: false,
|
||||
},
|
||||
metrics: SymbolMetrics {
|
||||
fan_in: 0,
|
||||
fan_out: 0,
|
||||
is_critical: false,
|
||||
cycle_participant: false,
|
||||
},
|
||||
};
|
||||
|
||||
project_model.symbols.insert("DatabaseManager".to_string(), db_symbol);
|
||||
project_model.symbols.insert("fetch_data".to_string(), http_symbol);
|
||||
|
||||
// Initialize renderer
|
||||
|
||||
// Add classified integrations (new format)
|
||||
project_model.classified_integrations.insert(
|
||||
"Database".to_string(),
|
||||
vec!["sqlalchemy".to_string(), "asyncpg".to_string()],
|
||||
);
|
||||
project_model.classified_integrations.insert(
|
||||
"HTTP".to_string(),
|
||||
vec!["fastapi".to_string(), "requests".to_string()],
|
||||
);
|
||||
|
||||
let renderer = Renderer::new();
|
||||
|
||||
// Render architecture documentation
|
||||
let result = renderer.render_architecture_md(&project_model, None);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let rendered_content = result.unwrap();
|
||||
println!("Rendered content:\n{}", rendered_content);
|
||||
|
||||
// Check that integration sections are present
|
||||
assert!(rendered_content.contains("## Integrations"));
|
||||
assert!(rendered_content.contains("### Database Integrations"));
|
||||
assert!(rendered_content.contains("### HTTP/API Integrations"));
|
||||
assert!(rendered_content.contains("DatabaseManager in test_file.py"));
|
||||
assert!(rendered_content.contains("fetch_data in test_file.py"));
|
||||
}
|
||||
|
||||
let rendered = result.unwrap();
|
||||
println!("Rendered:\n{}", rendered);
|
||||
|
||||
assert!(rendered.contains("## Integrations"));
|
||||
assert!(rendered.contains("### Database"));
|
||||
assert!(rendered.contains("- sqlalchemy"));
|
||||
assert!(rendered.contains("- asyncpg"));
|
||||
assert!(rendered.contains("### HTTP"));
|
||||
assert!(rendered.contains("- fastapi"));
|
||||
assert!(rendered.contains("- requests"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user