From b3eb591809708d38114ff7daaf851b17b38366f2 Mon Sep 17 00:00:00 2001 From: Arkasha Date: Sun, 15 Feb 2026 12:45:56 +0300 Subject: [PATCH] feat: smart integration detection with package classifier - Add PackageClassifier with built-in dictionary (~200 popular packages) - Hardcode Python 3.10+ stdlib list to filter out standard library imports - Add PyPI API lookup for unknown packages (online mode, 3s timeout) - Cache PyPI results in .wtismycode/cache/pypi.json - Add --offline flag to skip PyPI lookups - Classify packages into: HTTP, Database, Queue, Storage, AI/ML, Auth, Testing, Logging, Internal, Third-party - User config integration_patterns override auto-detection - Update renderer to show integrations grouped by category - Update ARCHITECTURE.md template with new integration format --- .gitignore | 1 + Cargo.lock | 191 +++++++++ wtismycode-cli/src/commands/generate.rs | 6 +- wtismycode-cli/src/main.rs | 7 +- wtismycode-cli/src/output.rs | 19 +- wtismycode-core/Cargo.toml | 2 + wtismycode-core/src/lib.rs | 1 + wtismycode-core/src/model.rs | 4 + wtismycode-core/src/package_classifier.rs | 452 ++++++++++++++++++++++ wtismycode-core/src/python_analyzer.rs | 75 +++- wtismycode-core/src/renderer.rs | 125 ++---- wtismycode-core/tests/project_analysis.rs | 8 +- wtismycode-core/tests/renderer_tests.rs | 101 ++--- 13 files changed, 800 insertions(+), 192 deletions(-) create mode 100644 wtismycode-core/src/package_classifier.rs diff --git a/.gitignore b/.gitignore index 5c07c35..1be3bad 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ .roo/ PLANS/ target/ +.wtismycode/ diff --git a/Cargo.lock b/Cargo.lock index d10a168..4a909a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + [[package]] name = "ahash" version = "0.8.12" @@ -85,6 +91,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.11.0" @@ -226,6 +238,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crunchy" version = "0.2.4" @@ -379,6 +400,16 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -480,6 +511,22 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "iana-time-zone" version = "0.1.65" @@ -699,6 +746,16 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.1" @@ -799,6 +856,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "pest" version = "2.8.6" @@ -974,6 +1037,20 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -993,6 +1070,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustpython-ast" version = "0.4.0" @@ -1180,6 +1292,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "siphasher" version = "1.0.2" @@ -1214,6 +1332,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.115" @@ -1565,6 +1689,47 @@ dependencies = [ "rand", ] +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" +dependencies = [ + "base64", + "flate2", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf-8", + "webpki-roots", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" @@ -1706,6 +1871,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi-util" version = "0.1.11" @@ -1774,6 +1948,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" @@ -2052,6 +2235,7 @@ dependencies = [ "anyhow", "chrono", "handlebars", + "lazy_static", "rustpython-ast", "rustpython-parser", "serde", @@ -2060,6 +2244,7 @@ dependencies = [ "thiserror 2.0.18", "toml 0.9.12+spec-1.1.0", "tracing", + "ureq", "walkdir", ] @@ -2083,6 +2268,12 @@ dependencies = [ "syn", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zmij" version = "1.0.21" diff --git a/wtismycode-cli/src/commands/generate.rs b/wtismycode-cli/src/commands/generate.rs index 9ce8fc9..b066d7d 100644 --- a/wtismycode-cli/src/commands/generate.rs +++ b/wtismycode-cli/src/commands/generate.rs @@ -12,6 +12,10 @@ pub fn load_config(config_path: &str) -> Result { } pub fn analyze_project(root: &str, config: &Config) -> Result { + analyze_project_with_options(root, config, false) +} + +pub fn analyze_project_with_options(root: &str, config: &Config, offline: bool) -> Result { println!("{}", "Scanning project...".cyan()); let scanner = FileScanner::new(config.clone()); @@ -19,7 +23,7 @@ pub fn analyze_project(root: &str, config: &Config) -> Result { println!(" Found {} Python files", python_files.len().to_string().yellow()); - let analyzer = PythonAnalyzer::new(config.clone()); + let analyzer = PythonAnalyzer::new_with_options(config.clone(), offline); let pb = ProgressBar::new(python_files.len() as u64); pb.set_style(ProgressStyle::default_bar() diff --git a/wtismycode-cli/src/main.rs b/wtismycode-cli/src/main.rs index 10611d6..95bdc9b 100644 --- a/wtismycode-cli/src/main.rs +++ b/wtismycode-cli/src/main.rs @@ -37,6 +37,9 @@ enum Commands { /// Show what would be generated without writing files #[arg(long)] dry_run: bool, + /// Skip PyPI API lookups, use only built-in dictionary + #[arg(long)] + offline: bool, }, /// Check if documentation is up to date Check { @@ -61,9 +64,9 @@ fn main() -> Result<()> { Commands::Init { root, out } => { commands::init::init_project(root, out)?; } - Commands::Generate { root, out, config, dry_run } => { + Commands::Generate { root, out, config, dry_run, offline } => { let config = commands::generate::load_config(config)?; - let model = commands::generate::analyze_project(root, &config)?; + let model = commands::generate::analyze_project_with_options(root, &config, *offline)?; if *dry_run { commands::generate::dry_run_docs(&model, out, &config)?; } else { diff --git a/wtismycode-cli/src/output.rs b/wtismycode-cli/src/output.rs index df8e4ea..d5358b6 100644 --- a/wtismycode-cli/src/output.rs +++ b/wtismycode-cli/src/output.rs @@ -19,17 +19,14 @@ pub fn print_generate_summary(model: &ProjectModel) { println!(" {} {}", "Edges:".bold(), model.edges.module_import_edges.len() + model.edges.symbol_call_edges.len()); - let integrations: Vec<&str> = { - let mut v = Vec::new(); - if model.symbols.values().any(|s| s.integrations_flags.http) { v.push("HTTP"); } - if model.symbols.values().any(|s| s.integrations_flags.db) { v.push("DB"); } - if model.symbols.values().any(|s| s.integrations_flags.queue) { v.push("Queue"); } - if model.symbols.values().any(|s| s.integrations_flags.storage) { v.push("Storage"); } - if model.symbols.values().any(|s| s.integrations_flags.ai) { v.push("AI/ML"); } - v - }; - if !integrations.is_empty() { - println!(" {} {}", "Integrations:".bold(), integrations.join(", ").yellow()); + if !model.classified_integrations.is_empty() { + let cats: Vec = model.classified_integrations.iter() + .filter(|(_, pkgs)| !pkgs.is_empty()) + .map(|(cat, pkgs)| format!("{} ({})", cat, pkgs.join(", "))) + .collect(); + if !cats.is_empty() { + println!(" {} {}", "Integrations:".bold(), cats.join(" | ").yellow()); + } } println!("{}", "─────────────────────────────────────".dimmed()); } diff --git a/wtismycode-core/Cargo.toml b/wtismycode-core/Cargo.toml index 8bb5bb8..7912a2e 100644 --- a/wtismycode-core/Cargo.toml +++ b/wtismycode-core/Cargo.toml @@ -16,3 +16,5 @@ rustpython-parser = "0.4" rustpython-ast = "0.4" chrono = { version = "0.4", features = ["serde"] } tempfile = "3.10" +ureq = "3" +lazy_static = "1.4" diff --git a/wtismycode-core/src/lib.rs b/wtismycode-core/src/lib.rs index 6e1eaa8..0200508 100644 --- a/wtismycode-core/src/lib.rs +++ b/wtismycode-core/src/lib.rs @@ -13,6 +13,7 @@ pub mod renderer; pub mod writer; pub mod cache; pub mod cycle_detector; +pub mod package_classifier; // Re-export commonly used types pub use errors::WTIsMyCodeError; diff --git a/wtismycode-core/src/model.rs b/wtismycode-core/src/model.rs index fd009c2..d1a6efa 100644 --- a/wtismycode-core/src/model.rs +++ b/wtismycode-core/src/model.rs @@ -12,6 +12,9 @@ pub struct ProjectModel { pub files: HashMap, pub symbols: HashMap, pub edges: Edges, + /// Classified integrations by category (e.g. "HTTP" -> ["fastapi", "requests"]) + #[serde(default)] + pub classified_integrations: HashMap>, } impl ProjectModel { @@ -21,6 +24,7 @@ impl ProjectModel { files: HashMap::new(), symbols: HashMap::new(), edges: Edges::new(), + classified_integrations: HashMap::new(), } } } diff --git a/wtismycode-core/src/package_classifier.rs b/wtismycode-core/src/package_classifier.rs new file mode 100644 index 0000000..ae4883c --- /dev/null +++ b/wtismycode-core/src/package_classifier.rs @@ -0,0 +1,452 @@ +//! Package classifier for Python imports +//! +//! Classifies Python packages into categories using: +//! 1. Python stdlib list (hardcoded) +//! 2. Built-in dictionary (~200 popular packages) +//! 3. PyPI API lookup (online mode) +//! 4. Internal package detection (fallback) + +use std::collections::HashMap; +use std::path::Path; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub enum PackageCategory { + Stdlib, + Http, + Database, + Queue, + Storage, + AiMl, + Testing, + Logging, + Auth, + Internal, + ThirdParty, +} + +impl PackageCategory { + pub fn display_name(&self) -> &'static str { + match self { + Self::Stdlib => "Stdlib", + Self::Http => "HTTP", + Self::Database => "Database", + Self::Queue => "Queue", + Self::Storage => "Storage", + Self::AiMl => "AI/ML", + Self::Testing => "Testing", + Self::Logging => "Logging", + Self::Auth => "Auth", + Self::Internal => "Internal", + Self::ThirdParty => "Third-party", + } + } +} + +/// Result of classifying all imports in a project +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct ClassifiedIntegrations { + /// category -> list of package names + pub by_category: HashMap>, +} + +pub struct PackageClassifier { + offline: bool, + cache_dir: Option, + /// user overrides from config integration_patterns + user_overrides: HashMap, + /// PyPI cache: package_name -> Option (None = not found) + pypi_cache: HashMap>, +} + +impl PackageClassifier { + pub fn new(offline: bool, cache_dir: Option) -> Self { + let mut classifier = Self { + offline, + cache_dir: cache_dir.clone(), + user_overrides: HashMap::new(), + pypi_cache: HashMap::new(), + }; + // Load PyPI cache from disk + if let Some(ref dir) = cache_dir { + classifier.load_pypi_cache(dir); + } + classifier + } + + /// Add user overrides from config integration_patterns + pub fn add_user_overrides(&mut self, patterns: &[(String, Vec)]) { + for (type_name, pkgs) in patterns { + let cat = match type_name.as_str() { + "http" => PackageCategory::Http, + "db" => PackageCategory::Database, + "queue" => PackageCategory::Queue, + "storage" => PackageCategory::Storage, + "ai" => PackageCategory::AiMl, + "testing" => PackageCategory::Testing, + "logging" => PackageCategory::Logging, + "auth" => PackageCategory::Auth, + _ => PackageCategory::ThirdParty, + }; + for pkg in pkgs { + self.user_overrides.insert(pkg.to_lowercase(), cat.clone()); + } + } + } + + /// Classify a single package name (top-level import) + pub fn classify(&mut self, package_name: &str) -> PackageCategory { + let normalized = normalize_package_name(package_name); + + // 1. User overrides take priority + if let Some(cat) = self.user_overrides.get(&normalized) { + return cat.clone(); + } + + // 2. Built-in dictionary (check BEFORE stdlib, so sqlite3 etc. are categorized properly) + if let Some(cat) = builtin_lookup(&normalized) { + return cat; + } + + // 3. Stdlib + if is_stdlib(&normalized) { + return PackageCategory::Stdlib; + } + + // 4. PyPI lookup (if online) + if !self.offline { + if let Some(cached) = self.pypi_cache.get(&normalized) { + return cached.clone().unwrap_or(PackageCategory::Internal); + } + match self.pypi_lookup(&normalized) { + Some(cat) => { + self.pypi_cache.insert(normalized, Some(cat.clone())); + return cat; + } + None => { + self.pypi_cache.insert(normalized, None); + return PackageCategory::Internal; + } + } + } + + // 5. Offline fallback: if not in stdlib or dictionary, assume internal + PackageCategory::Internal + } + + /// Classify all imports and return grouped integrations + pub fn classify_all(&mut self, import_names: &[String]) -> ClassifiedIntegrations { + let mut result = ClassifiedIntegrations::default(); + let mut seen: HashMap = HashMap::new(); + + for import in import_names { + let top_level = top_level_package(import); + if seen.contains_key(&top_level) { + continue; + } + let cat = self.classify(&top_level); + seen.insert(top_level.clone(), cat.clone()); + + // Skip stdlib and third-party without category + if cat == PackageCategory::Stdlib { + continue; + } + + let category_name = cat.display_name().to_string(); + result.by_category + .entry(category_name) + .or_default() + .push(top_level); + } + + // Deduplicate and sort each category + for pkgs in result.by_category.values_mut() { + pkgs.sort(); + pkgs.dedup(); + } + + result + } + + /// Save PyPI cache to disk + pub fn save_cache(&self) { + if let Some(ref dir) = self.cache_dir { + let cache_path = Path::new(dir).join("pypi.json"); + if let Ok(json) = serde_json::to_string_pretty(&self.pypi_cache) { + let _ = std::fs::create_dir_all(dir); + let _ = std::fs::write(&cache_path, json); + } + } + } + + fn load_pypi_cache(&mut self, dir: &str) { + let cache_path = Path::new(dir).join("pypi.json"); + if let Ok(content) = std::fs::read_to_string(&cache_path) { + if let Ok(cache) = serde_json::from_str::>>(&content) { + self.pypi_cache = cache; + } + } + } + + fn pypi_lookup(&self, package_name: &str) -> Option { + let url = format!("https://pypi.org/pypi/{}/json", package_name); + + let agent = ureq::Agent::new_with_config( + ureq::config::Config::builder() + .timeout_global(Some(std::time::Duration::from_secs(3))) + .build() + ); + + let response = agent.get(&url).call().ok()?; + + if response.status() != 200 { + return None; + } + + let body_str = response.into_body().read_to_string().ok()?; + let body: serde_json::Value = serde_json::from_str(&body_str).ok()?; + let info = body.get("info")?; + + // Check classifiers + if let Some(classifiers) = info.get("classifiers").and_then(|c: &serde_json::Value| c.as_array()) { + for classifier in classifiers { + if let Some(s) = classifier.as_str() { + if let Some(cat) = classify_from_pypi_classifier(s) { + return Some(cat); + } + } + } + } + + // Check summary and keywords for hints + let summary = info.get("summary").and_then(|s: &serde_json::Value| s.as_str()).unwrap_or(""); + let keywords = info.get("keywords").and_then(|s: &serde_json::Value| s.as_str()).unwrap_or(""); + let combined = format!("{} {}", summary, keywords).to_lowercase(); + + if combined.contains("database") || combined.contains("sql") || combined.contains("orm") { + return Some(PackageCategory::Database); + } + if combined.contains("http") || combined.contains("web framework") || combined.contains("rest api") { + return Some(PackageCategory::Http); + } + if combined.contains("queue") || combined.contains("message broker") || combined.contains("amqp") || combined.contains("kafka") { + return Some(PackageCategory::Queue); + } + if combined.contains("storage") || combined.contains("s3") || combined.contains("blob") { + return Some(PackageCategory::Storage); + } + if combined.contains("machine learning") || combined.contains("deep learning") || combined.contains("neural") || combined.contains("artificial intelligence") { + return Some(PackageCategory::AiMl); + } + if combined.contains("testing") || combined.contains("test framework") { + return Some(PackageCategory::Testing); + } + if combined.contains("logging") || combined.contains("error tracking") { + return Some(PackageCategory::Logging); + } + if combined.contains("authentication") || combined.contains("jwt") || combined.contains("oauth") { + return Some(PackageCategory::Auth); + } + + // Found on PyPI but no category detected + Some(PackageCategory::ThirdParty) + } +} + +fn classify_from_pypi_classifier(classifier: &str) -> Option { + let c = classifier.to_lowercase(); + if c.contains("framework :: django") || c.contains("framework :: flask") || + c.contains("framework :: fastapi") || c.contains("framework :: tornado") || + c.contains("framework :: aiohttp") || c.contains("topic :: internet :: www") { + return Some(PackageCategory::Http); + } + if c.contains("topic :: database") { + return Some(PackageCategory::Database); + } + if c.contains("topic :: scientific/engineering :: artificial intelligence") || + c.contains("topic :: scientific/engineering :: machine learning") { + return Some(PackageCategory::AiMl); + } + if c.contains("topic :: software development :: testing") { + return Some(PackageCategory::Testing); + } + if c.contains("topic :: system :: logging") { + return Some(PackageCategory::Logging); + } + if c.contains("topic :: security") && (classifier.contains("auth") || classifier.contains("Auth")) { + return Some(PackageCategory::Auth); + } + None +} + +/// Extract top-level package name from an import string +/// e.g. "sqlalchemy.orm.Session" -> "sqlalchemy" +fn top_level_package(import: &str) -> String { + import.split('.').next().unwrap_or(import).to_lowercase() +} + +/// Normalize package name for lookup (lowercase, replace hyphens with underscores) +fn normalize_package_name(name: &str) -> String { + name.to_lowercase().replace('-', "_") +} + +/// Check if a package is in the Python standard library +fn is_stdlib(name: &str) -> bool { + PYTHON_STDLIB.contains(&name) +} + +/// Look up a package in the built-in dictionary +fn builtin_lookup(name: &str) -> Option { + for (cat, pkgs) in BUILTIN_PACKAGES.iter() { + if pkgs.contains(&name) { + return Some(cat.clone()); + } + } + None +} + +// Python 3.10+ standard library modules +const PYTHON_STDLIB: &[&str] = &[ + "__future__", "_thread", "abc", "aifc", "argparse", "array", "ast", + "asynchat", "asyncio", "asyncore", "atexit", "audioop", "base64", + "bdb", "binascii", "binhex", "bisect", "builtins", "bz2", + "calendar", "cgi", "cgitb", "chunk", "cmath", "cmd", "code", + "codecs", "codeop", "collections", "colorsys", "compileall", + "concurrent", "configparser", "contextlib", "contextvars", "copy", + "copyreg", "cprofile", "crypt", "csv", "ctypes", "curses", + "dataclasses", "datetime", "dbm", "decimal", "difflib", "dis", + "distutils", "doctest", "email", "encodings", "enum", "errno", + "faulthandler", "fcntl", "filecmp", "fileinput", "fnmatch", + "formatter", "fractions", "ftplib", "functools", "gc", "getopt", + "getpass", "gettext", "glob", "grp", "gzip", "hashlib", "heapq", + "hmac", "html", "http", "idlelib", "imaplib", "imghdr", "imp", + "importlib", "inspect", "io", "ipaddress", "itertools", "json", + "keyword", "lib2to3", "linecache", "locale", "logging", "lzma", + "mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap", + "modulefinder", "multiprocessing", "netrc", "nis", "nntplib", + "numbers", "operator", "optparse", "os", "ossaudiodev", "parser", + "pathlib", "pdb", "pickle", "pickletools", "pipes", "pkgutil", + "platform", "plistlib", "poplib", "posix", "posixpath", "pprint", + "profile", "pstats", "pty", "pwd", "py_compile", "pyclbr", + "pydoc", "queue", "quopri", "random", "re", "readline", "reprlib", + "resource", "rlcompleter", "runpy", "sched", "secrets", "select", + "selectors", "shelve", "shlex", "shutil", "signal", "site", + "smtpd", "smtplib", "sndhdr", "socket", "socketserver", "spwd", + "sqlite3", "ssl", "stat", "statistics", "string", "stringprep", + "struct", "subprocess", "sunau", "symtable", "sys", "sysconfig", + "syslog", "tabnanny", "tarfile", "telnetlib", "tempfile", "termios", + "test", "textwrap", "threading", "time", "timeit", "tkinter", + "token", "tokenize", "tomllib", "trace", "traceback", "tracemalloc", + "tty", "turtle", "turtledemo", "types", "typing", "unicodedata", + "unittest", "urllib", "uu", "uuid", "venv", "warnings", "wave", + "weakref", "webbrowser", "winreg", "winsound", "wsgiref", "xdrlib", + "xml", "xmlrpc", "zipapp", "zipfile", "zipimport", "zlib", + // Common sub-packages that appear as top-level imports + "os.path", "collections.abc", "concurrent.futures", "typing_extensions", +]; + +lazy_static::lazy_static! { + static ref BUILTIN_PACKAGES: Vec<(PackageCategory, Vec<&'static str>)> = vec![ + (PackageCategory::Http, vec![ + "requests", "httpx", "aiohttp", "fastapi", "flask", "django", + "starlette", "uvicorn", "gunicorn", "tornado", "sanic", "bottle", + "falcon", "quart", "werkzeug", "httptools", "uvloop", "hypercorn", + "grpcio", "grpc", "graphene", "strawberry", "ariadne", + "pydantic", "marshmallow", "connexion", "responder", "hug", + ]), + (PackageCategory::Database, vec![ + "sqlalchemy", "psycopg2", "psycopg", "asyncpg", "pymongo", + "mongoengine", "peewee", "tortoise", "databases", + "alembic", "pymysql", "opensearch", "elasticsearch", + "motor", "beanie", "odmantic", "sqlmodel", + "piccolo", "edgedb", "cassandra", "clickhouse_driver", "sqlite3", + "neo4j", "arango", "influxdb", "timescaledb", + ]), + (PackageCategory::Queue, vec![ + "celery", "pika", "aio_pika", "kafka", "confluent_kafka", + "kombu", "dramatiq", "huey", "rq", "nats", "redis", "aioredis", + "aiokafka", "taskiq", "arq", + ]), + (PackageCategory::Storage, vec![ + "minio", "boto3", "botocore", "google.cloud.storage", + "azure.storage.blob", "s3fs", "fsspec", "smart_open", + ]), + (PackageCategory::AiMl, vec![ + "torch", "tensorflow", "transformers", "langchain", + "langchain_core", "langchain_openai", "langchain_community", + "openai", "anthropic", "scikit_learn", "sklearn", + "numpy", "pandas", "scipy", "matplotlib", "keras", + "whisper", "sentence_transformers", "qdrant_client", + "chromadb", "pinecone", "faiss", "xgboost", "lightgbm", + "catboost", "spacy", "nltk", "gensim", "huggingface_hub", + "diffusers", "accelerate", "datasets", "tokenizers", + "tiktoken", "llama_index", "autogen", "crewai", + "seaborn", "plotly", "bokeh", + ]), + (PackageCategory::Testing, vec![ + "pytest", "mock", "faker", "hypothesis", + "factory_boy", "factory", "responses", "httpretty", + "vcrpy", "freezegun", "time_machine", "pytest_asyncio", + "pytest_mock", "pytest_cov", "coverage", "tox", "nox", + "behave", "robot", "selenium", "playwright", "locust", + ]), + (PackageCategory::Auth, vec![ + "pyjwt", "jwt", "python_jose", "jose", "passlib", + "authlib", "oauthlib", "itsdangerous", "bcrypt", + "cryptography", "paramiko", + ]), + (PackageCategory::Logging, vec![ + "loguru", "structlog", "sentry_sdk", "watchtower", + "python_json_logger", "colorlog", "rich", + ]), + ]; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_stdlib_detection() { + assert!(is_stdlib("os")); + assert!(is_stdlib("sys")); + assert!(is_stdlib("json")); + assert!(is_stdlib("asyncio")); + assert!(!is_stdlib("requests")); + assert!(!is_stdlib("fastapi")); + } + + #[test] + fn test_builtin_lookup() { + assert_eq!(builtin_lookup("requests"), Some(PackageCategory::Http)); + assert_eq!(builtin_lookup("sqlalchemy"), Some(PackageCategory::Database)); + assert_eq!(builtin_lookup("celery"), Some(PackageCategory::Queue)); + assert_eq!(builtin_lookup("minio"), Some(PackageCategory::Storage)); + assert_eq!(builtin_lookup("torch"), Some(PackageCategory::AiMl)); + assert_eq!(builtin_lookup("pytest"), Some(PackageCategory::Testing)); + assert_eq!(builtin_lookup("loguru"), Some(PackageCategory::Logging)); + assert_eq!(builtin_lookup("pyjwt"), Some(PackageCategory::Auth)); + assert_eq!(builtin_lookup("nonexistent_pkg"), None); + } + + #[test] + fn test_top_level_package() { + assert_eq!(top_level_package("sqlalchemy.orm.Session"), "sqlalchemy"); + assert_eq!(top_level_package("os.path"), "os"); + assert_eq!(top_level_package("requests"), "requests"); + } + + #[test] + fn test_normalize_package_name() { + assert_eq!(normalize_package_name("aio-pika"), "aio_pika"); + assert_eq!(normalize_package_name("scikit-learn"), "scikit_learn"); + assert_eq!(normalize_package_name("FastAPI"), "fastapi"); + } + + #[test] + fn test_classify_offline() { + let mut classifier = PackageClassifier::new(true, None); + assert_eq!(classifier.classify("os"), PackageCategory::Stdlib); + assert_eq!(classifier.classify("requests"), PackageCategory::Http); + assert_eq!(classifier.classify("my_internal_pkg"), PackageCategory::Internal); + } +} diff --git a/wtismycode-core/src/python_analyzer.rs b/wtismycode-core/src/python_analyzer.rs index a3edc28..2bef3ca 100644 --- a/wtismycode-core/src/python_analyzer.rs +++ b/wtismycode-core/src/python_analyzer.rs @@ -15,12 +15,18 @@ use rustpython_ast::{Stmt, Expr, Ranged}; pub struct PythonAnalyzer { config: Config, cache_manager: CacheManager, + offline: bool, } impl PythonAnalyzer { pub fn new(config: Config) -> Self { let cache_manager = CacheManager::new(config.clone()); - Self { config, cache_manager } + Self { config, cache_manager, offline: false } + } + + pub fn new_with_options(config: Config, offline: bool) -> Self { + let cache_manager = CacheManager::new(config.clone()); + Self { config, cache_manager, offline } } pub fn parse_module(&self, file_path: &Path) -> Result { @@ -678,7 +684,72 @@ impl PythonAnalyzer { self.build_dependency_graphs(&mut project_model, modules)?; self.resolve_call_types(&mut project_model, modules, &import_aliases); self.compute_metrics(&mut project_model)?; - + + // Classify all imports using PackageClassifier + let all_imports: Vec = modules.iter() + .flat_map(|m| m.imports.iter().map(|i| i.module_name.clone())) + .collect(); + + let cache_dir = if self.config.caching.enabled { + Some(self.config.caching.cache_dir.clone()) + } else { + None + }; + let mut classifier = crate::package_classifier::PackageClassifier::new(self.offline, cache_dir); + + // Add user overrides from config integration_patterns + if !self.config.analysis.integration_patterns.is_empty() { + let overrides: Vec<(String, Vec)> = self.config.analysis.integration_patterns.iter() + .map(|p| (p.type_.clone(), p.patterns.clone())) + .collect(); + classifier.add_user_overrides(&overrides); + } + + let classified = classifier.classify_all(&all_imports); + classifier.save_cache(); + + project_model.classified_integrations = classified.by_category; + + // Also update per-symbol integration flags based on classification + for parsed_module in modules { + let module_id = self.compute_module_path(&parsed_module.path); + let import_names: Vec = parsed_module.imports.iter() + .map(|i| i.module_name.clone()) + .collect(); + + let mut flags = crate::model::IntegrationFlags { + http: false, db: false, queue: false, storage: false, ai: false, + }; + + for import in &import_names { + let top = import.split('.').next().unwrap_or(import).to_lowercase().replace('-', "_"); + { + let cat = crate::package_classifier::PackageClassifier::new(true, None).classify(&top); + match cat { + crate::package_classifier::PackageCategory::Http => flags.http = true, + crate::package_classifier::PackageCategory::Database => flags.db = true, + crate::package_classifier::PackageCategory::Queue => flags.queue = true, + crate::package_classifier::PackageCategory::Storage => flags.storage = true, + crate::package_classifier::PackageCategory::AiMl => flags.ai = true, + _ => {} + } + } + } + + // Apply to all symbols in this module + if let Some(module) = project_model.modules.get(&module_id) { + for sym_id in &module.symbols { + if let Some(sym) = project_model.symbols.get_mut(sym_id) { + sym.integrations_flags.http |= flags.http; + sym.integrations_flags.db |= flags.db; + sym.integrations_flags.queue |= flags.queue; + sym.integrations_flags.storage |= flags.storage; + sym.integrations_flags.ai |= flags.ai; + } + } + } + } + Ok(project_model) } diff --git a/wtismycode-core/src/renderer.rs b/wtismycode-core/src/renderer.rs index 0ec0600..408094c 100644 --- a/wtismycode-core/src/renderer.rs +++ b/wtismycode-core/src/renderer.rs @@ -73,29 +73,12 @@ impl Renderer { > Generated. Do not edit inside this block. -### Database Integrations -{{#each db_integrations}} +{{#each integration_sections}} +### {{{category}}} +{{#each packages}} - {{{this}}} {{/each}} -### HTTP/API Integrations -{{#each http_integrations}} -- {{{this}}} -{{/each}} - -### Queue Integrations -{{#each queue_integrations}} -- {{{this}}} -{{/each}} - -### Storage Integrations -{{#each storage_integrations}} -- {{{this}}} -{{/each}} - -### AI/ML Integrations -{{#each ai_integrations}} -- {{{this}}} {{/each}} @@ -258,28 +241,17 @@ impl Renderer { } pub fn render_architecture_md(&self, model: &ProjectModel, config: Option<&Config>) -> Result { - // Collect integration information - let mut db_integrations = Vec::new(); - let mut http_integrations = Vec::new(); - let mut queue_integrations = Vec::new(); - let mut storage_integrations = Vec::new(); - let mut ai_integrations = Vec::new(); - - for (symbol_id, symbol) in &model.symbols { - if symbol.integrations_flags.db { - db_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.http { - http_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.queue { - queue_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.storage { - storage_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.ai { - ai_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); + // Build integration sections from classified_integrations + let category_order = ["HTTP", "Database", "Queue", "Storage", "AI/ML", "Auth", "Testing", "Logging", "Internal", "Third-party"]; + let mut integration_sections: Vec = Vec::new(); + for cat_name in &category_order { + if let Some(pkgs) = model.classified_integrations.get(*cat_name) { + if !pkgs.is_empty() { + integration_sections.push(serde_json::json!({ + "category": cat_name, + "packages": pkgs, + })); + } } } @@ -419,11 +391,7 @@ impl Renderer { "key_decisions": [""], "non_goals": [""], "change_notes": [""], - "db_integrations": db_integrations, - "http_integrations": http_integrations, - "queue_integrations": queue_integrations, - "storage_integrations": storage_integrations, - "ai_integrations": ai_integrations, + "integration_sections": integration_sections, "rails_summary": "\n\nNo tooling information available.\n", "layout_items": layout_items, "modules": modules_list, @@ -579,66 +547,31 @@ impl Renderer { } pub fn render_integrations_section(&self, model: &ProjectModel) -> Result { - // Collect integration information - let mut db_integrations = Vec::new(); - let mut http_integrations = Vec::new(); - let mut queue_integrations = Vec::new(); - let mut storage_integrations = Vec::new(); - let mut ai_integrations = Vec::new(); - - for (symbol_id, symbol) in &model.symbols { - if symbol.integrations_flags.db { - db_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.http { - http_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.queue { - queue_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.storage { - storage_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); - } - if symbol.integrations_flags.ai { - ai_integrations.push(format!("{} in {}", symbol_id, symbol.file_id)); + let category_order = ["HTTP", "Database", "Queue", "Storage", "AI/ML", "Auth", "Testing", "Logging", "Internal", "Third-party"]; + let mut integration_sections: Vec = Vec::new(); + for cat_name in &category_order { + if let Some(pkgs) = model.classified_integrations.get(*cat_name) { + if !pkgs.is_empty() { + integration_sections.push(serde_json::json!({ + "category": cat_name, + "packages": pkgs, + })); + } } } - // Prepare data for integrations section let data = serde_json::json!({ - "db_integrations": db_integrations, - "http_integrations": http_integrations, - "queue_integrations": queue_integrations, - "storage_integrations": storage_integrations, - "ai_integrations": ai_integrations, + "integration_sections": integration_sections, }); - // Create a smaller template just for the integrations section let integrations_template = r#" -### Database Integrations -{{#each db_integrations}} +{{#each integration_sections}} +### {{{category}}} +{{#each packages}} - {{{this}}} {{/each}} -### HTTP/API Integrations -{{#each http_integrations}} -- {{{this}}} -{{/each}} - -### Queue Integrations -{{#each queue_integrations}} -- {{{this}}} -{{/each}} - -### Storage Integrations -{{#each storage_integrations}} -- {{{this}}} -{{/each}} - -### AI/ML Integrations -{{#each ai_integrations}} -- {{{this}}} {{/each}} "#; diff --git a/wtismycode-core/tests/project_analysis.rs b/wtismycode-core/tests/project_analysis.rs index dddfc2b..66536ab 100644 --- a/wtismycode-core/tests/project_analysis.rs +++ b/wtismycode-core/tests/project_analysis.rs @@ -33,9 +33,11 @@ fn test_project_analysis() { // Check that we found calls assert!(!core_module.calls.is_empty()); - // Check that integrations are detected - let db_integration_found = core_module.symbols.iter().any(|s| s.integrations_flags.db); - let http_integration_found = core_module.symbols.iter().any(|s| s.integrations_flags.http); + // Integration flags are now set during resolve_symbols, not parse_module + // So we resolve and check there + let project_model = analyzer.resolve_symbols(&[core_module.clone()]).unwrap(); + let db_integration_found = project_model.symbols.values().any(|s| s.integrations_flags.db); + let http_integration_found = project_model.symbols.values().any(|s| s.integrations_flags.http); assert!(db_integration_found, "Database integration should be detected"); assert!(http_integration_found, "HTTP integration should be detected"); diff --git a/wtismycode-core/tests/renderer_tests.rs b/wtismycode-core/tests/renderer_tests.rs index 7c64c38..5ba2aaa 100644 --- a/wtismycode-core/tests/renderer_tests.rs +++ b/wtismycode-core/tests/renderer_tests.rs @@ -1,89 +1,36 @@ //! Tests for the renderer functionality use wtismycode_core::{ - model::{ProjectModel, Symbol, SymbolKind, IntegrationFlags, SymbolMetrics}, + model::ProjectModel, renderer::Renderer, }; -use std::collections::HashMap; #[test] fn test_render_with_integrations() { - // Create a mock project model with integration information let mut project_model = ProjectModel::new(); - - // Add a symbol with database integration - let db_symbol = Symbol { - id: "DatabaseManager".to_string(), - kind: SymbolKind::Class, - module_id: "test_module".to_string(), - file_id: "test_file.py".to_string(), - qualname: "DatabaseManager".to_string(), - signature: "class DatabaseManager".to_string(), - annotations: None, - docstring_first_line: None, - purpose: "test".to_string(), - outbound_calls: vec![], - inbound_calls: vec![], - integrations_flags: IntegrationFlags { - db: true, - http: false, - queue: false, - storage: false, - ai: false, - }, - metrics: SymbolMetrics { - fan_in: 0, - fan_out: 0, - is_critical: false, - cycle_participant: false, - }, - }; - - // Add a symbol with HTTP integration - let http_symbol = Symbol { - id: "fetch_data".to_string(), - kind: SymbolKind::Function, - module_id: "test_module".to_string(), - file_id: "test_file.py".to_string(), - qualname: "fetch_data".to_string(), - signature: "def fetch_data()".to_string(), - annotations: None, - docstring_first_line: None, - purpose: "test".to_string(), - outbound_calls: vec![], - inbound_calls: vec![], - integrations_flags: IntegrationFlags { - db: false, - http: true, - queue: false, - storage: false, - ai: false, - }, - metrics: SymbolMetrics { - fan_in: 0, - fan_out: 0, - is_critical: false, - cycle_participant: false, - }, - }; - - project_model.symbols.insert("DatabaseManager".to_string(), db_symbol); - project_model.symbols.insert("fetch_data".to_string(), http_symbol); - - // Initialize renderer + + // Add classified integrations (new format) + project_model.classified_integrations.insert( + "Database".to_string(), + vec!["sqlalchemy".to_string(), "asyncpg".to_string()], + ); + project_model.classified_integrations.insert( + "HTTP".to_string(), + vec!["fastapi".to_string(), "requests".to_string()], + ); + let renderer = Renderer::new(); - - // Render architecture documentation let result = renderer.render_architecture_md(&project_model, None); assert!(result.is_ok()); - - let rendered_content = result.unwrap(); - println!("Rendered content:\n{}", rendered_content); - - // Check that integration sections are present - assert!(rendered_content.contains("## Integrations")); - assert!(rendered_content.contains("### Database Integrations")); - assert!(rendered_content.contains("### HTTP/API Integrations")); - assert!(rendered_content.contains("DatabaseManager in test_file.py")); - assert!(rendered_content.contains("fetch_data in test_file.py")); -} \ No newline at end of file + + let rendered = result.unwrap(); + println!("Rendered:\n{}", rendered); + + assert!(rendered.contains("## Integrations")); + assert!(rendered.contains("### Database")); + assert!(rendered.contains("- sqlalchemy")); + assert!(rendered.contains("- asyncpg")); + assert!(rendered.contains("### HTTP")); + assert!(rendered.contains("- fastapi")); + assert!(rendered.contains("- requests")); +}