diff --git a/package-lock.json b/package-lock.json index 8f48150..7e406cf 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,7 +14,7 @@ "@google/genai": "^1.8.0", "@google/generative-ai": "^0.24.1", "axios": "^1.10.0", - "better-sqlite3": "^9.4.3", + "better-sqlite3": "^9.6.0", "cors": "^2.8.5", "dotenv": "^17.0.0", "electron-squirrel-startup": "^1.0.1", @@ -27,6 +27,7 @@ "keytar": "^7.9.0", "node-fetch": "^2.7.0", "openai": "^4.70.0", + "portkey-ai": "^1.10.1", "react-hot-toast": "^2.5.2", "sharp": "^0.34.2", "validator": "^13.11.0", @@ -2283,6 +2284,8 @@ }, "node_modules/better-sqlite3": { "version": "9.6.0", + "resolved": "https://registry.npmjs.org/better-sqlite3/-/better-sqlite3-9.6.0.tgz", + "integrity": "sha512-yR5HATnqeYNVnkaUTf4bOP2dJSnyhP4puJN/QPRyx4YkBEEUxib422n2XzPqDEHjQQqazoYoADdAm5vE15+dAQ==", "hasInstallScript": true, "license": "MIT", "dependencies": { @@ -5922,6 +5925,30 @@ "node": ">=10.4.0" } }, + "node_modules/portkey-ai": { + "version": "1.10.1", + "resolved": "https://registry.npmjs.org/portkey-ai/-/portkey-ai-1.10.1.tgz", + "integrity": "sha512-mRGDxm4xBMexYlk/bS8i+G5C/Ww+KaXcKlHtzzsmh0X4Awd1bPBGq5dlUmCrHGgN/umLpphxcOcLHsDa9NbjrQ==", + "license": "MIT", + "dependencies": { + "agentkeepalive": "^4.6.0", + "dotenv": "^16.5.0", + "openai": "4.104.0", + "ws": "^8.18.2" + } + }, + "node_modules/portkey-ai/node_modules/dotenv": { + "version": "16.6.1", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz", + "integrity": "sha512-uBq4egWHTcTt33a72vpSG0z3HnPuIl6NqYcTrKEg2azoEyl2hpW0zqlxysq2pK9HlDIHyHyakeYaYnSAwd8bow==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } + }, "node_modules/postject": { "version": "1.0.0-alpha.6", "dev": true, diff --git a/package.json b/package.json index b8f3d35..4ed8d95 100644 --- a/package.json +++ b/package.json @@ -36,7 +36,7 @@ "@google/genai": "^1.8.0", "@google/generative-ai": "^0.24.1", "axios": "^1.10.0", - "better-sqlite3": "^9.4.3", + "better-sqlite3": "^9.6.0", "cors": "^2.8.5", "dotenv": "^17.0.0", "electron-squirrel-startup": "^1.0.1", @@ -49,6 +49,7 @@ "keytar": "^7.9.0", "node-fetch": "^2.7.0", "openai": "^4.70.0", + "portkey-ai": "^1.10.1", "react-hot-toast": "^2.5.2", "sharp": "^0.34.2", "validator": "^13.11.0", diff --git a/pickleglass_web/backend_node/routes/user.js b/pickleglass_web/backend_node/routes/user.js index d31a9a8..d80bf27 100644 --- a/pickleglass_web/backend_node/routes/user.js +++ b/pickleglass_web/backend_node/routes/user.js @@ -46,7 +46,8 @@ router.post('/find-or-create', async (req, res) => { router.post('/api-key', async (req, res) => { try { - await ipcRequest(req, 'save-api-key', req.body.apiKey); + const { apiKey, provider = 'openai' } = req.body; + await ipcRequest(req, 'save-api-key', { apiKey, provider }); res.json({ message: 'API key saved successfully' }); } catch (error) { console.error('Failed to save API key via IPC:', error); diff --git a/src/common/ai/factory.js b/src/common/ai/factory.js index 520ba06..6afe2a8 100644 --- a/src/common/ai/factory.js +++ b/src/common/ai/factory.js @@ -66,15 +66,14 @@ const PROVIDERS = { 'whisper': { name: 'Whisper (Local)', handler: () => { - // Only load in main process + // This needs to remain a function due to its conditional logic for renderer/main process if (typeof window === 'undefined') { return require("./providers/whisper"); } - // Return dummy for renderer + // Return a dummy object for the renderer process return { + validateApiKey: async () => ({ success: true }), // Mock validate for renderer createSTT: () => { throw new Error('Whisper STT is only available in main process'); }, - createLLM: () => { throw new Error('Whisper does not support LLM'); }, - createStreamingLLM: () => { throw new Error('Whisper does not support LLM'); } }; }, llmModels: [], @@ -130,6 +129,32 @@ function createStreamingLLM(provider, opts) { return handler.createStreamingLLM(opts); } +function getProviderClass(providerId) { + const providerConfig = PROVIDERS[providerId]; + if (!providerConfig) return null; + + // Handle special cases for glass providers + let actualProviderId = providerId; + if (providerId === 'openai-glass') { + actualProviderId = 'openai'; + } + + // The handler function returns the module, from which we get the class. + const module = providerConfig.handler(); + + // Map provider IDs to their actual exported class names + const classNameMap = { + 'openai': 'OpenAIProvider', + 'anthropic': 'AnthropicProvider', + 'gemini': 'GeminiProvider', + 'ollama': 'OllamaProvider', + 'whisper': 'WhisperProvider' + }; + + const className = classNameMap[actualProviderId]; + return className ? module[className] : null; +} + function getAvailableProviders() { const stt = []; const llm = []; @@ -145,5 +170,6 @@ module.exports = { createSTT, createLLM, createStreamingLLM, + getProviderClass, getAvailableProviders, }; \ No newline at end of file diff --git a/src/common/ai/providers/anthropic.js b/src/common/ai/providers/anthropic.js index f86a363..315e7cb 100644 --- a/src/common/ai/providers/anthropic.js +++ b/src/common/ai/providers/anthropic.js @@ -1,4 +1,38 @@ -const Anthropic = require("@anthropic-ai/sdk") +const { Anthropic } = require("@anthropic-ai/sdk") + +class AnthropicProvider { + static async validateApiKey(key) { + if (!key || typeof key !== 'string' || !key.startsWith('sk-ant-')) { + return { success: false, error: 'Invalid Anthropic API key format.' }; + } + + try { + const response = await fetch("https://api.anthropic.com/v1/messages", { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-api-key": key, + "anthropic-version": "2023-06-01", + }, + body: JSON.stringify({ + model: "claude-3-haiku-20240307", + max_tokens: 1, + messages: [{ role: "user", content: "Hi" }], + }), + }); + + if (response.ok || response.status === 400) { // 400 is a valid response for a bad request, not a bad key + return { success: true }; + } else { + const errorData = await response.json().catch(() => ({})); + return { success: false, error: errorData.error?.message || `Validation failed with status: ${response.status}` }; + } + } catch (error) { + console.error(`[AnthropicProvider] Network error during key validation:`, error); + return { success: false, error: 'A network error occurred during validation.' }; + } + } +} /** * Creates an Anthropic STT session @@ -286,7 +320,8 @@ function createStreamingLLM({ } module.exports = { - createSTT, - createLLM, - createStreamingLLM, -} + AnthropicProvider, + createSTT, + createLLM, + createStreamingLLM +}; diff --git a/src/common/ai/providers/gemini.js b/src/common/ai/providers/gemini.js index 31f7e33..c09065f 100644 --- a/src/common/ai/providers/gemini.js +++ b/src/common/ai/providers/gemini.js @@ -1,6 +1,31 @@ const { GoogleGenerativeAI } = require("@google/generative-ai") const { GoogleGenAI } = require("@google/genai") +class GeminiProvider { + static async validateApiKey(key) { + if (!key || typeof key !== 'string') { + return { success: false, error: 'Invalid Gemini API key format.' }; + } + + try { + const validationUrl = `https://generativelanguage.googleapis.com/v1beta/models?key=${key}`; + const response = await fetch(validationUrl); + + if (response.ok) { + return { success: true }; + } else { + const errorData = await response.json().catch(() => ({})); + const message = errorData.error?.message || `Validation failed with status: ${response.status}`; + return { success: false, error: message }; + } + } catch (error) { + console.error(`[GeminiProvider] Network error during key validation:`, error); + return { success: false, error: 'A network error occurred during validation.' }; + } + } +} + + /** * Creates a Gemini STT session * @param {object} opts - Configuration options @@ -296,7 +321,8 @@ function createStreamingLLM({ apiKey, model = "gemini-2.5-flash", temperature = } module.exports = { - createSTT, - createLLM, - createStreamingLLM, -} + GeminiProvider, + createSTT, + createLLM, + createStreamingLLM +}; diff --git a/src/common/ai/providers/ollama.js b/src/common/ai/providers/ollama.js index c25e10c..a521ec1 100644 --- a/src/common/ai/providers/ollama.js +++ b/src/common/ai/providers/ollama.js @@ -1,6 +1,22 @@ const http = require('http'); const fetch = require('node-fetch'); +class OllamaProvider { + static async validateApiKey() { + try { + const response = await fetch('http://localhost:11434/api/tags'); + if (response.ok) { + return { success: true }; + } else { + return { success: false, error: 'Ollama service is not running. Please start Ollama first.' }; + } + } catch (error) { + return { success: false, error: 'Cannot connect to Ollama. Please ensure Ollama is installed and running.' }; + } + } +} + + function convertMessagesToOllamaFormat(messages) { return messages.map(msg => { if (Array.isArray(msg.content)) { @@ -237,6 +253,8 @@ function createStreamingLLM({ } module.exports = { + OllamaProvider, createLLM, - createStreamingLLM + createStreamingLLM, + convertMessagesToOllamaFormat }; \ No newline at end of file diff --git a/src/common/ai/providers/openai.js b/src/common/ai/providers/openai.js index 79b1ef3..fe37c61 100644 --- a/src/common/ai/providers/openai.js +++ b/src/common/ai/providers/openai.js @@ -1,5 +1,35 @@ const OpenAI = require('openai'); const WebSocket = require('ws'); +const { Portkey } = require('portkey-ai'); +const { Readable } = require('stream'); +const { getProviderForModel } = require('../factory.js'); + + +class OpenAIProvider { + static async validateApiKey(key) { + if (!key || typeof key !== 'string' || !key.startsWith('sk-')) { + return { success: false, error: 'Invalid OpenAI API key format.' }; + } + + try { + const response = await fetch('https://api.openai.com/v1/models', { + headers: { 'Authorization': `Bearer ${key}` } + }); + + if (response.ok) { + return { success: true }; + } else { + const errorData = await response.json().catch(() => ({})); + const message = errorData.error?.message || `Validation failed with status: ${response.status}`; + return { success: false, error: message }; + } + } catch (error) { + console.error(`[OpenAIProvider] Network error during key validation:`, error); + return { success: false, error: 'A network error occurred during validation.' }; + } + } +} + /** * Creates an OpenAI STT session @@ -206,7 +236,7 @@ function createLLM({ apiKey, model = 'gpt-4.1', temperature = 0.7, maxTokens = 2 }; } -/** +/** * Creates an OpenAI streaming LLM instance * @param {object} opts - Configuration options * @param {string} opts.apiKey - OpenAI API key @@ -257,7 +287,8 @@ function createStreamingLLM({ apiKey, model = 'gpt-4.1', temperature = 0.7, maxT } module.exports = { - createSTT, - createLLM, - createStreamingLLM + OpenAIProvider, + createSTT, + createLLM, + createStreamingLLM }; \ No newline at end of file diff --git a/src/common/ai/providers/whisper.js b/src/common/ai/providers/whisper.js index abeaa06..1190977 100644 --- a/src/common/ai/providers/whisper.js +++ b/src/common/ai/providers/whisper.js @@ -173,6 +173,11 @@ class WhisperSTTSession extends EventEmitter { } class WhisperProvider { + static async validateApiKey() { + // Whisper is a local service, no API key validation needed. + return { success: true }; + } + constructor() { this.whisperService = null; } @@ -224,8 +229,12 @@ class WhisperProvider { } async createStreamingLLM() { - throw new Error('Whisper provider does not support streaming LLM functionality'); + console.warn('[WhisperProvider] Streaming LLM is not supported by Whisper.'); + throw new Error('Whisper does not support LLM.'); } } -module.exports = new WhisperProvider(); \ No newline at end of file +module.exports = { + WhisperProvider, + WhisperSTTSession +}; \ No newline at end of file diff --git a/src/common/config/schema.js b/src/common/config/schema.js index 4dfc4ba..b1cfcd4 100644 --- a/src/common/config/schema.js +++ b/src/common/config/schema.js @@ -5,8 +5,6 @@ const LATEST_SCHEMA = { { name: 'display_name', type: 'TEXT NOT NULL' }, { name: 'email', type: 'TEXT NOT NULL' }, { name: 'created_at', type: 'INTEGER' }, - { name: 'api_key', type: 'TEXT' }, - { name: 'provider', type: 'TEXT DEFAULT \'openai\'' }, { name: 'auto_update_enabled', type: 'INTEGER DEFAULT 1' }, { name: 'has_migrated_to_firebase', type: 'INTEGER DEFAULT 0' } ] @@ -90,6 +88,28 @@ const LATEST_SCHEMA = { { name: 'installed', type: 'INTEGER DEFAULT 0' }, { name: 'installing', type: 'INTEGER DEFAULT 0' } ] + }, + provider_settings: { + columns: [ + { name: 'uid', type: 'TEXT NOT NULL' }, + { name: 'provider', type: 'TEXT NOT NULL' }, + { name: 'api_key', type: 'TEXT' }, + { name: 'selected_llm_model', type: 'TEXT' }, + { name: 'selected_stt_model', type: 'TEXT' }, + { name: 'created_at', type: 'INTEGER' }, + { name: 'updated_at', type: 'INTEGER' } + ], + constraints: ['PRIMARY KEY (uid, provider)'] + }, + user_model_selections: { + columns: [ + { name: 'uid', type: 'TEXT PRIMARY KEY' }, + { name: 'selected_llm_provider', type: 'TEXT' }, + { name: 'selected_llm_model', type: 'TEXT' }, + { name: 'selected_stt_provider', type: 'TEXT' }, + { name: 'selected_stt_model', type: 'TEXT' }, + { name: 'updated_at', type: 'INTEGER' } + ] } }; diff --git a/src/common/repositories/providerSettings/firebase.repository.js b/src/common/repositories/providerSettings/firebase.repository.js new file mode 100644 index 0000000..f7fed8f --- /dev/null +++ b/src/common/repositories/providerSettings/firebase.repository.js @@ -0,0 +1,83 @@ +const { collection, doc, getDoc, getDocs, setDoc, deleteDoc, query, where } = require('firebase/firestore'); +const { getFirestoreInstance: getFirestore } = require('../../services/firebaseClient'); +const { createEncryptedConverter } = require('../firestoreConverter'); + +// Create encrypted converter for provider settings +const providerSettingsConverter = createEncryptedConverter([ + 'api_key', // Encrypt API keys + 'selected_llm_model', // Encrypt model selections for privacy + 'selected_stt_model' +]); + +function providerSettingsCol() { + const db = getFirestore(); + return collection(db, 'provider_settings').withConverter(providerSettingsConverter); +} + +async function getByProvider(uid, provider) { + try { + const docRef = doc(providerSettingsCol(), `${uid}_${provider}`); + const docSnap = await getDoc(docRef); + return docSnap.exists() ? { id: docSnap.id, ...docSnap.data() } : null; + } catch (error) { + console.error('[ProviderSettings Firebase] Error getting provider settings:', error); + return null; + } +} + +async function getAllByUid(uid) { + try { + const q = query(providerSettingsCol(), where('uid', '==', uid)); + const querySnapshot = await getDocs(q); + return querySnapshot.docs.map(doc => ({ id: doc.id, ...doc.data() })); + } catch (error) { + console.error('[ProviderSettings Firebase] Error getting all provider settings:', error); + return []; + } +} + +async function upsert(uid, provider, settings) { + try { + const docRef = doc(providerSettingsCol(), `${uid}_${provider}`); + await setDoc(docRef, settings, { merge: true }); + return { changes: 1 }; + } catch (error) { + console.error('[ProviderSettings Firebase] Error upserting provider settings:', error); + throw error; + } +} + +async function remove(uid, provider) { + try { + const docRef = doc(providerSettingsCol(), `${uid}_${provider}`); + await deleteDoc(docRef); + return { changes: 1 }; + } catch (error) { + console.error('[ProviderSettings Firebase] Error removing provider settings:', error); + throw error; + } +} + +async function removeAllByUid(uid) { + try { + const settings = await getAllByUid(uid); + const deletePromises = settings.map(setting => { + const docRef = doc(providerSettingsCol(), setting.id); + return deleteDoc(docRef); + }); + + await Promise.all(deletePromises); + return { changes: settings.length }; + } catch (error) { + console.error('[ProviderSettings Firebase] Error removing all provider settings:', error); + throw error; + } +} + +module.exports = { + getByProvider, + getAllByUid, + upsert, + remove, + removeAllByUid +}; \ No newline at end of file diff --git a/src/common/repositories/providerSettings/index.js b/src/common/repositories/providerSettings/index.js new file mode 100644 index 0000000..d4fb384 --- /dev/null +++ b/src/common/repositories/providerSettings/index.js @@ -0,0 +1,65 @@ +const firebaseRepository = require('./firebase.repository'); +const sqliteRepository = require('./sqlite.repository'); + +let authService = null; + +function setAuthService(service) { + authService = service; +} + +function getBaseRepository() { + if (!authService) { + throw new Error('AuthService not set for providerSettings repository'); + } + + const user = authService.getCurrentUser(); + return user.isLoggedIn ? firebaseRepository : sqliteRepository; +} + +const providerSettingsRepositoryAdapter = { + // Core CRUD operations + async getByProvider(provider) { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.getByProvider(uid, provider); + }, + + async getAllByUid() { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.getAllByUid(uid); + }, + + async upsert(provider, settings) { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + const now = Date.now(); + + const settingsWithMeta = { + ...settings, + uid, + provider, + updated_at: now, + created_at: settings.created_at || now + }; + + return await repo.upsert(uid, provider, settingsWithMeta); + }, + + async remove(provider) { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.remove(uid, provider); + }, + + async removeAllByUid() { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.removeAllByUid(uid); + } +}; + +module.exports = { + ...providerSettingsRepositoryAdapter, + setAuthService +}; \ No newline at end of file diff --git a/src/common/repositories/providerSettings/sqlite.repository.js b/src/common/repositories/providerSettings/sqlite.repository.js new file mode 100644 index 0000000..bddd5b0 --- /dev/null +++ b/src/common/repositories/providerSettings/sqlite.repository.js @@ -0,0 +1,62 @@ +const sqliteClient = require('../../services/sqliteClient'); + +function getByProvider(uid, provider) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('SELECT * FROM provider_settings WHERE uid = ? AND provider = ?'); + return stmt.get(uid, provider) || null; +} + +function getAllByUid(uid) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('SELECT * FROM provider_settings WHERE uid = ? ORDER BY provider'); + return stmt.all(uid); +} + +function upsert(uid, provider, settings) { + const db = sqliteClient.getDb(); + + // Use SQLite's UPSERT syntax (INSERT ... ON CONFLICT ... DO UPDATE) + const stmt = db.prepare(` + INSERT INTO provider_settings (uid, provider, api_key, selected_llm_model, selected_stt_model, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(uid, provider) DO UPDATE SET + api_key = excluded.api_key, + selected_llm_model = excluded.selected_llm_model, + selected_stt_model = excluded.selected_stt_model, + updated_at = excluded.updated_at + `); + + const result = stmt.run( + uid, + provider, + settings.api_key || null, + settings.selected_llm_model || null, + settings.selected_stt_model || null, + settings.created_at || Date.now(), + settings.updated_at + ); + + return { changes: result.changes }; +} + +function remove(uid, provider) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('DELETE FROM provider_settings WHERE uid = ? AND provider = ?'); + const result = stmt.run(uid, provider); + return { changes: result.changes }; +} + +function removeAllByUid(uid) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('DELETE FROM provider_settings WHERE uid = ?'); + const result = stmt.run(uid); + return { changes: result.changes }; +} + +module.exports = { + getByProvider, + getAllByUid, + upsert, + remove, + removeAllByUid +}; \ No newline at end of file diff --git a/src/common/repositories/user/firebase.repository.js b/src/common/repositories/user/firebase.repository.js index 3867108..11c1c50 100644 --- a/src/common/repositories/user/firebase.repository.js +++ b/src/common/repositories/user/firebase.repository.js @@ -3,7 +3,7 @@ const { getFirestoreInstance } = require('../../services/firebaseClient'); const { createEncryptedConverter } = require('../firestoreConverter'); const encryptionService = require('../../services/encryptionService'); -const userConverter = createEncryptedConverter(['api_key']); +const userConverter = createEncryptedConverter([]); function usersCol() { const db = getFirestoreInstance(); @@ -38,11 +38,7 @@ async function getById(uid) { return docSnap.exists() ? docSnap.data() : null; } -async function saveApiKey(uid, apiKey, provider = 'openai') { - const docRef = doc(usersCol(), uid); - await setDoc(docRef, { api_key: apiKey, provider }, { merge: true }); - return { changes: 1 }; -} + async function update({ uid, displayName }) { const docRef = doc(usersCol(), uid); @@ -85,7 +81,6 @@ async function deleteById(uid) { module.exports = { findOrCreate, getById, - saveApiKey, update, deleteById, }; \ No newline at end of file diff --git a/src/common/repositories/user/index.js b/src/common/repositories/user/index.js index 6c90a87..adc6ac0 100644 --- a/src/common/repositories/user/index.js +++ b/src/common/repositories/user/index.js @@ -1,8 +1,16 @@ const sqliteRepository = require('./sqlite.repository'); const firebaseRepository = require('./firebase.repository'); -const authService = require('../../services/authService'); + +let authService = null; + +function setAuthService(service) { + authService = service; +} function getBaseRepository() { + if (!authService) { + throw new Error('AuthService has not been set for the user repository.'); + } const user = authService.getCurrentUser(); if (user && user.isLoggedIn) { return firebaseRepository; @@ -21,10 +29,7 @@ const userRepositoryAdapter = { return getBaseRepository().getById(uid); }, - saveApiKey: (apiKey, provider) => { - const uid = authService.getCurrentUserId(); - return getBaseRepository().saveApiKey(uid, apiKey, provider); - }, + update: (updateData) => { const uid = authService.getCurrentUserId(); @@ -37,4 +42,7 @@ const userRepositoryAdapter = { } }; -module.exports = userRepositoryAdapter; \ No newline at end of file +module.exports = { + ...userRepositoryAdapter, + setAuthService +}; \ No newline at end of file diff --git a/src/common/repositories/user/sqlite.repository.js b/src/common/repositories/user/sqlite.repository.js index d443f47..ce138d5 100644 --- a/src/common/repositories/user/sqlite.repository.js +++ b/src/common/repositories/user/sqlite.repository.js @@ -40,17 +40,7 @@ function getById(uid) { return db.prepare('SELECT * FROM users WHERE uid = ?').get(uid); } -function saveApiKey(uid, apiKey, provider = 'openai') { - const db = sqliteClient.getDb(); - try { - const result = db.prepare('UPDATE users SET api_key = ?, provider = ? WHERE uid = ?').run(apiKey, provider, uid); - console.log(`SQLite: API key saved for user ${uid} with provider ${provider}.`); - return { changes: result.changes }; - } catch (err) { - console.error('SQLite: Failed to save API key:', err); - throw err; - } -} + function update({ uid, displayName }) { const db = sqliteClient.getDb(); @@ -96,7 +86,6 @@ function deleteById(uid) { module.exports = { findOrCreate, getById, - saveApiKey, update, setMigrationComplete, deleteById diff --git a/src/common/repositories/userModelSelections/firebase.repository.js b/src/common/repositories/userModelSelections/firebase.repository.js new file mode 100644 index 0000000..58f879b --- /dev/null +++ b/src/common/repositories/userModelSelections/firebase.repository.js @@ -0,0 +1,55 @@ +const { collection, doc, getDoc, setDoc, deleteDoc } = require('firebase/firestore'); +const { getFirestoreInstance: getFirestore } = require('../../services/firebaseClient'); +const { createEncryptedConverter } = require('../firestoreConverter'); + +// Create encrypted converter for user model selections +const userModelSelectionsConverter = createEncryptedConverter([ + 'selected_llm_provider', + 'selected_llm_model', + 'selected_stt_provider', + 'selected_stt_model' +]); + +function userModelSelectionsCol() { + const db = getFirestore(); + return collection(db, 'user_model_selections').withConverter(userModelSelectionsConverter); +} + +async function get(uid) { + try { + const docRef = doc(userModelSelectionsCol(), uid); + const docSnap = await getDoc(docRef); + return docSnap.exists() ? { id: docSnap.id, ...docSnap.data() } : null; + } catch (error) { + console.error('[UserModelSelections Firebase] Error getting user model selections:', error); + return null; + } +} + +async function upsert(uid, selections) { + try { + const docRef = doc(userModelSelectionsCol(), uid); + await setDoc(docRef, selections, { merge: true }); + return { changes: 1 }; + } catch (error) { + console.error('[UserModelSelections Firebase] Error upserting user model selections:', error); + throw error; + } +} + +async function remove(uid) { + try { + const docRef = doc(userModelSelectionsCol(), uid); + await deleteDoc(docRef); + return { changes: 1 }; + } catch (error) { + console.error('[UserModelSelections Firebase] Error removing user model selections:', error); + throw error; + } +} + +module.exports = { + get, + upsert, + remove +}; \ No newline at end of file diff --git a/src/common/repositories/userModelSelections/index.js b/src/common/repositories/userModelSelections/index.js new file mode 100644 index 0000000..e886af0 --- /dev/null +++ b/src/common/repositories/userModelSelections/index.js @@ -0,0 +1,50 @@ +const firebaseRepository = require('./firebase.repository'); +const sqliteRepository = require('./sqlite.repository'); + +let authService = null; + +function setAuthService(service) { + authService = service; +} + +function getBaseRepository() { + if (!authService) { + throw new Error('AuthService not set for userModelSelections repository'); + } + + const user = authService.getCurrentUser(); + return user.isLoggedIn ? firebaseRepository : sqliteRepository; +} + +const userModelSelectionsRepositoryAdapter = { + async get() { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.get(uid); + }, + + async upsert(selections) { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + const now = Date.now(); + + const selectionsWithMeta = { + ...selections, + uid, + updated_at: now + }; + + return await repo.upsert(uid, selectionsWithMeta); + }, + + async remove() { + const repo = getBaseRepository(); + const uid = authService.getCurrentUserId(); + return await repo.remove(uid); + } +}; + +module.exports = { + ...userModelSelectionsRepositoryAdapter, + setAuthService +}; \ No newline at end of file diff --git a/src/common/repositories/userModelSelections/sqlite.repository.js b/src/common/repositories/userModelSelections/sqlite.repository.js new file mode 100644 index 0000000..abd38df --- /dev/null +++ b/src/common/repositories/userModelSelections/sqlite.repository.js @@ -0,0 +1,48 @@ +const sqliteClient = require('../../services/sqliteClient'); + +function get(uid) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('SELECT * FROM user_model_selections WHERE uid = ?'); + return stmt.get(uid) || null; +} + +function upsert(uid, selections) { + const db = sqliteClient.getDb(); + + // Use SQLite's UPSERT syntax (INSERT ... ON CONFLICT ... DO UPDATE) + const stmt = db.prepare(` + INSERT INTO user_model_selections (uid, selected_llm_provider, selected_llm_model, + selected_stt_provider, selected_stt_model, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(uid) DO UPDATE SET + selected_llm_provider = excluded.selected_llm_provider, + selected_llm_model = excluded.selected_llm_model, + selected_stt_provider = excluded.selected_stt_provider, + selected_stt_model = excluded.selected_stt_model, + updated_at = excluded.updated_at + `); + + const result = stmt.run( + uid, + selections.selected_llm_provider || null, + selections.selected_llm_model || null, + selections.selected_stt_provider || null, + selections.selected_stt_model || null, + selections.updated_at + ); + + return { changes: result.changes }; +} + +function remove(uid) { + const db = sqliteClient.getDb(); + const stmt = db.prepare('DELETE FROM user_model_selections WHERE uid = ?'); + const result = stmt.run(uid); + return { changes: result.changes }; +} + +module.exports = { + get, + upsert, + remove +}; \ No newline at end of file diff --git a/src/common/services/authService.js b/src/common/services/authService.js index 55862ff..335c769 100644 --- a/src/common/services/authService.js +++ b/src/common/services/authService.js @@ -5,6 +5,8 @@ const fetch = require('node-fetch'); const encryptionService = require('./encryptionService'); const migrationService = require('./migrationService'); const sessionRepository = require('../repositories/session'); +const providerSettingsRepository = require('../repositories/providerSettings'); +const userModelSelectionsRepository = require('../repositories/userModelSelections'); async function getVirtualKeyByEmail(email, idToken) { if (!idToken) { @@ -40,10 +42,13 @@ class AuthService { this.currentUser = null; this.isInitialized = false; - // Initialize immediately for the default local user on startup. // This ensures the key is ready before any login/logout state change. encryptionService.initializeKey(this.currentUserId); this.initializationPromise = null; + + sessionRepository.setAuthService(this); + providerSettingsRepository.setAuthService(this); + userModelSelectionsRepository.setAuthService(this); } initialize() { diff --git a/src/common/services/modelStateService.js b/src/common/services/modelStateService.js index a83b4ac..b5bed8b 100644 --- a/src/common/services/modelStateService.js +++ b/src/common/services/modelStateService.js @@ -1,21 +1,28 @@ const Store = require('electron-store'); const fetch = require('node-fetch'); const { ipcMain, webContents } = require('electron'); -const { PROVIDERS } = require('../ai/factory'); +const { PROVIDERS, getProviderClass } = require('../ai/factory'); const encryptionService = require('./encryptionService'); +const providerSettingsRepository = require('../repositories/providerSettings'); +const userModelSelectionsRepository = require('../repositories/userModelSelections'); class ModelStateService { constructor(authService) { this.authService = authService; this.store = new Store({ name: 'pickle-glass-model-state' }); this.state = {}; + this.hasMigrated = false; + + // Set auth service for repositories + providerSettingsRepository.setAuthService(authService); + userModelSelectionsRepository.setAuthService(authService); } async initialize() { + console.log('[ModelStateService] Initializing...'); await this._loadStateForCurrentUser(); - this.setupIpcHandlers(); - console.log('[ModelStateService] Initialized.'); + console.log('[ModelStateService] Initialization complete'); } _logCurrentSelection() { @@ -64,44 +71,186 @@ class ModelStateService { }); } + async _migrateFromElectronStore() { + console.log('[ModelStateService] Starting migration from electron-store to database...'); + const userId = this.authService.getCurrentUserId(); + + try { + // Get data from electron-store + const legacyData = this.store.get(`users.${userId}`, null); + + if (!legacyData) { + console.log('[ModelStateService] No legacy data to migrate'); + return; + } + + console.log('[ModelStateService] Found legacy data, migrating...'); + + // Migrate provider settings (API keys and selected models per provider) + const { apiKeys = {}, selectedModels = {} } = legacyData; + + for (const [provider, apiKey] of Object.entries(apiKeys)) { + if (apiKey && PROVIDERS[provider]) { + // For encrypted keys, they are already decrypted in _loadStateForCurrentUser + await providerSettingsRepository.upsert(provider, { + api_key: apiKey + }); + console.log(`[ModelStateService] Migrated API key for ${provider}`); + } + } + + // Migrate global model selections + if (selectedModels.llm || selectedModels.stt) { + const llmProvider = selectedModels.llm ? this.getProviderForModel('llm', selectedModels.llm) : null; + const sttProvider = selectedModels.stt ? this.getProviderForModel('stt', selectedModels.stt) : null; + + await userModelSelectionsRepository.upsert({ + selected_llm_provider: llmProvider, + selected_llm_model: selectedModels.llm, + selected_stt_provider: sttProvider, + selected_stt_model: selectedModels.stt + }); + console.log('[ModelStateService] Migrated global model selections'); + } + + // Mark migration as complete by removing legacy data + this.store.delete(`users.${userId}`); + console.log('[ModelStateService] Migration completed and legacy data cleaned up'); + + } catch (error) { + console.error('[ModelStateService] Migration failed:', error); + // Don't throw - continue with normal operation + } + } + + async _loadStateFromDatabase() { + console.log('[ModelStateService] Loading state from database...'); + const userId = this.authService.getCurrentUserId(); + + try { + // Load provider settings + const providerSettings = await providerSettingsRepository.getAllByUid(); + const apiKeys = {}; + + // Reconstruct apiKeys object + Object.keys(PROVIDERS).forEach(provider => { + apiKeys[provider] = null; + }); + + for (const setting of providerSettings) { + if (setting.api_key) { + // API keys are stored encrypted in database, decrypt them + if (setting.provider !== 'ollama' && setting.provider !== 'whisper') { + try { + apiKeys[setting.provider] = encryptionService.decrypt(setting.api_key); + } catch (error) { + console.error(`[ModelStateService] Failed to decrypt API key for ${setting.provider}, resetting`); + apiKeys[setting.provider] = null; + } + } else { + apiKeys[setting.provider] = setting.api_key; + } + } + } + + // Load global model selections + const modelSelections = await userModelSelectionsRepository.get(); + const selectedModels = { + llm: modelSelections?.selected_llm_model || null, + stt: modelSelections?.selected_stt_model || null + }; + + this.state = { + apiKeys, + selectedModels + }; + + console.log(`[ModelStateService] State loaded from database for user: ${userId}`); + + } catch (error) { + console.error('[ModelStateService] Failed to load state from database:', error); + // Fall back to default state + const initialApiKeys = Object.keys(PROVIDERS).reduce((acc, key) => { + acc[key] = null; + return acc; + }, {}); + + this.state = { + apiKeys: initialApiKeys, + selectedModels: { llm: null, stt: null }, + }; + } + } + async _loadStateForCurrentUser() { const userId = this.authService.getCurrentUserId(); // Initialize encryption service for current user await encryptionService.initializeKey(userId); - const initialApiKeys = Object.keys(PROVIDERS).reduce((acc, key) => { - acc[key] = null; - return acc; - }, {}); - - const defaultState = { - apiKeys: initialApiKeys, - selectedModels: { llm: null, stt: null }, - }; - this.state = this.store.get(`users.${userId}`, defaultState); - console.log(`[ModelStateService] State loaded for user: ${userId}`); + // Try to load from database first + await this._loadStateFromDatabase(); - for (const p of Object.keys(PROVIDERS)) { - if (!(p in this.state.apiKeys)) { - this.state.apiKeys[p] = null; - } else if (this.state.apiKeys[p] && p !== 'ollama' && p !== 'whisper') { - try { - this.state.apiKeys[p] = encryptionService.decrypt(this.state.apiKeys[p]); - } catch (error) { - console.error(`[ModelStateService] Failed to decrypt API key for ${p}, resetting`); - this.state.apiKeys[p] = null; - } - } + // Check if we need to migrate from electron-store + const legacyData = this.store.get(`users.${userId}`, null); + if (legacyData && !this.hasMigrated) { + await this._migrateFromElectronStore(); + // Reload state after migration + await this._loadStateFromDatabase(); + this.hasMigrated = true; } this._autoSelectAvailableModels(); - this._saveState(); + await this._saveState(); this._logCurrentSelection(); } + async _saveState() { + console.log('[ModelStateService] Saving state to database...'); + const userId = this.authService.getCurrentUserId(); + + try { + // Save provider settings (API keys) + for (const [provider, apiKey] of Object.entries(this.state.apiKeys)) { + if (apiKey) { + const encryptedKey = (provider !== 'ollama' && provider !== 'whisper') + ? encryptionService.encrypt(apiKey) + : apiKey; + + await providerSettingsRepository.upsert(provider, { + api_key: encryptedKey + }); + } else { + // Remove empty API keys + await providerSettingsRepository.remove(provider); + } + } + + // Save global model selections + const llmProvider = this.state.selectedModels.llm ? this.getProviderForModel('llm', this.state.selectedModels.llm) : null; + const sttProvider = this.state.selectedModels.stt ? this.getProviderForModel('stt', this.state.selectedModels.stt) : null; + + if (llmProvider || sttProvider || this.state.selectedModels.llm || this.state.selectedModels.stt) { + await userModelSelectionsRepository.upsert({ + selected_llm_provider: llmProvider, + selected_llm_model: this.state.selectedModels.llm, + selected_stt_provider: sttProvider, + selected_stt_model: this.state.selectedModels.stt + }); + } + + console.log(`[ModelStateService] State saved to database for user: ${userId}`); + this._logCurrentSelection(); + + } catch (error) { + console.error('[ModelStateService] Failed to save state to database:', error); + // Fall back to electron-store for now + this._saveStateToElectronStore(); + } + } - _saveState() { + _saveStateToElectronStore() { + console.log('[ModelStateService] Falling back to electron-store...'); const userId = this.authService.getCurrentUserId(); const stateToSave = { ...this.state, @@ -120,93 +269,34 @@ class ModelStateService { } this.store.set(`users.${userId}`, stateToSave); - console.log(`[ModelStateService] State saved for user: ${userId}`); + console.log(`[ModelStateService] State saved to electron-store for user: ${userId}`); this._logCurrentSelection(); } async validateApiKey(provider, key) { - if (!key || key.trim() === '') { + if (!key || (key.trim() === '' && provider !== 'ollama' && provider !== 'whisper')) { return { success: false, error: 'API key cannot be empty.' }; } - let validationUrl, headers; - const body = undefined; + const ProviderClass = getProviderClass(provider); - switch (provider) { - case 'ollama': - // Ollama doesn't need API key validation - // Just check if the service is running - try { - const response = await fetch('http://localhost:11434/api/tags'); - if (response.ok) { - console.log(`[ModelStateService] Ollama service is accessible.`); - this.setApiKey(provider, 'local'); // Use 'local' as a placeholder - return { success: true }; - } else { - return { success: false, error: 'Ollama service is not running. Please start Ollama first.' }; - } - } catch (error) { - return { success: false, error: 'Cannot connect to Ollama. Please ensure Ollama is installed and running.' }; - } - case 'whisper': - // Whisper is a local service, no API key validation needed - console.log(`[ModelStateService] Whisper is a local service.`); - this.setApiKey(provider, 'local'); // Use 'local' as a placeholder - return { success: true }; - case 'openai': - validationUrl = 'https://api.openai.com/v1/models'; - headers = { 'Authorization': `Bearer ${key}` }; - break; - case 'gemini': - validationUrl = `https://generativelanguage.googleapis.com/v1beta/models?key=${key}`; - headers = {}; - break; - case 'anthropic': { - if (!key.startsWith('sk-ant-')) { - throw new Error('Invalid Anthropic key format.'); - } - const response = await fetch("https://api.anthropic.com/v1/messages", { - method: "POST", - headers: { - "Content-Type": "application/json", - "x-api-key": key, - "anthropic-version": "2023-06-01", - }, - body: JSON.stringify({ - model: "claude-3-haiku-20240307", - max_tokens: 1, - messages: [{ role: "user", content: "Hi" }], - }), - }); - - if (!response.ok && response.status !== 400) { - const errorData = await response.json().catch(() => ({})); - return { success: false, error: errorData.error?.message || `Validation failed with status: ${response.status}` }; - } - - console.log(`[ModelStateService] API key for ${provider} is valid.`); - this.setApiKey(provider, key); + if (!ProviderClass || typeof ProviderClass.validateApiKey !== 'function') { + // Default to success if no specific validator is found + console.warn(`[ModelStateService] No validateApiKey function for provider: ${provider}. Assuming valid.`); return { success: true }; - } - default: - return { success: false, error: 'Unknown provider.' }; } try { - const response = await fetch(validationUrl, { headers, body }); - if (response.ok) { + const result = await ProviderClass.validateApiKey(key); + if (result.success) { console.log(`[ModelStateService] API key for ${provider} is valid.`); - this.setApiKey(provider, key); - return { success: true }; } else { - const errorData = await response.json().catch(() => ({})); - const message = errorData.error?.message || `Validation failed with status: ${response.status}`; - console.log(`[ModelStateService] API key for ${provider} is invalid: ${message}`); - return { success: false, error: message }; + console.log(`[ModelStateService] API key for ${provider} is invalid: ${result.error}`); } + return result; } catch (error) { - console.error(`[ModelStateService] Network error during ${provider} key validation:`, error); - return { success: false, error: 'A network error occurred during validation.' }; + console.error(`[ModelStateService] Error during ${provider} key validation:`, error); + return { success: false, error: 'An unexpected error occurred during validation.' }; } } @@ -239,33 +329,14 @@ class ModelStateService { setApiKey(provider, key) { if (provider in this.state.apiKeys) { this.state.apiKeys[provider] = key; - - const llmModels = PROVIDERS[provider]?.llmModels; - const sttModels = PROVIDERS[provider]?.sttModels; - - // Prioritize newly set API key provider over existing selections - // Only for non-local providers or if no model is currently selected - if (llmModels?.length > 0) { - if (!this.state.selectedModels.llm || provider !== 'ollama') { - this.state.selectedModels.llm = llmModels[0].id; - console.log(`[ModelStateService] Selected LLM model from newly configured provider ${provider}: ${llmModels[0].id}`); - } - } - if (sttModels?.length > 0) { - if (!this.state.selectedModels.stt || provider !== 'whisper') { - this.state.selectedModels.stt = sttModels[0].id; - console.log(`[ModelStateService] Selected STT model from newly configured provider ${provider}: ${sttModels[0].id}`); - } - } this._saveState(); - this._logCurrentSelection(); return true; } return false; } getApiKey(provider) { - return this.state.apiKeys[provider] || null; + return this.state.apiKeys[provider]; } getAllApiKeys() { @@ -351,6 +422,18 @@ class ModelStateService { return result; } + hasValidApiKey() { + if (this.isLoggedInWithFirebase()) return true; + + // Check if any provider has a valid API key + return Object.entries(this.state.apiKeys).some(([provider, key]) => { + if (provider === 'ollama' || provider === 'whisper') { + return key === 'local'; + } + return key && key.trim().length > 0; + }); + } + getAvailableModels(type) { const available = []; @@ -445,10 +528,28 @@ class ModelStateService { } setupIpcHandlers() { - ipcMain.handle('model:validate-key', (e, { provider, key }) => this.validateApiKey(provider, key)); + ipcMain.handle('model:validate-key', async (e, { provider, key }) => { + const result = await this.validateApiKey(provider, key); + if (result.success) { + // Use 'local' as placeholder for local services + const finalKey = (provider === 'ollama' || provider === 'whisper') ? 'local' : key; + this.setApiKey(provider, finalKey); + // After setting the key, auto-select models + this._autoSelectAvailableModels(); + this._saveState(); // Ensure state is saved after model selection + } + return result; + }); ipcMain.handle('model:get-all-keys', () => this.getAllApiKeys()); - ipcMain.handle('model:set-api-key', (e, { provider, key }) => this.setApiKey(provider, key)); - ipcMain.handle('model:remove-api-key', (e, { provider }) => { + ipcMain.handle('model:set-api-key', async (e, { provider, key }) => { + const success = this.setApiKey(provider, key); + if (success) { + this._autoSelectAvailableModels(); + await this._saveState(); + } + return success; + }); + ipcMain.handle('model:remove-api-key', async (e, { provider }) => { const success = this.removeApiKey(provider); if (success) { const selectedModels = this.getSelectedModels(); @@ -461,7 +562,7 @@ class ModelStateService { return success; }); ipcMain.handle('model:get-selected-models', () => this.getSelectedModels()); - ipcMain.handle('model:set-selected-model', (e, { type, modelId }) => this.setSelectedModel(type, modelId)); + ipcMain.handle('model:set-selected-model', async (e, { type, modelId }) => this.setSelectedModel(type, modelId)); ipcMain.handle('model:get-available-models', (e, { type }) => this.getAvailableModels(type)); ipcMain.handle('model:are-providers-configured', () => this.areProvidersConfigured()); ipcMain.handle('model:get-current-model-info', (e, { type }) => this.getCurrentModelInfo(type)); diff --git a/src/common/services/sqliteClient.js b/src/common/services/sqliteClient.js index afd64f3..c49d0b2 100644 --- a/src/common/services/sqliteClient.js +++ b/src/common/services/sqliteClient.js @@ -33,6 +33,13 @@ class SQLiteClient { return this.db; } + _validateAndQuoteIdentifier(identifier) { + if (!/^[a-zA-Z0-9_]+$/.test(identifier)) { + throw new Error(`Invalid database identifier used: ${identifier}. Only alphanumeric characters and underscores are allowed.`); + } + return `"${identifier}"`; + } + synchronizeSchema() { console.log('[DB Sync] Starting schema synchronization...'); const tablesInDb = this.getTablesFromDb(); @@ -57,25 +64,42 @@ class SQLiteClient { } createTable(tableName, tableSchema) { - const columnDefs = tableSchema.columns.map(col => `"${col.name}" ${col.type}`).join(', '); - const query = `CREATE TABLE IF NOT EXISTS "${tableName}" (${columnDefs})`; - + const safeTableName = this._validateAndQuoteIdentifier(tableName); + const columnDefs = tableSchema.columns + .map(col => `${this._validateAndQuoteIdentifier(col.name)} ${col.type}`) + .join(', '); + + const constraints = tableSchema.constraints || []; + const constraintsDef = constraints.length > 0 ? ', ' + constraints.join(', ') : ''; + + const query = `CREATE TABLE IF NOT EXISTS ${safeTableName} (${columnDefs}${constraintsDef})`; console.log(`[DB Sync] Creating table: ${tableName}`); - this.db.prepare(query).run(); + this.db.exec(query); } updateTable(tableName, tableSchema) { - const existingColumns = this.db.prepare(`PRAGMA table_info("${tableName}")`).all(); - const existingColumnNames = existingColumns.map(c => c.name); - const columnsToAdd = tableSchema.columns.filter(col => !existingColumnNames.includes(col.name)); + const safeTableName = this._validateAndQuoteIdentifier(tableName); + + // Get current columns + const currentColumns = this.db.prepare(`PRAGMA table_info(${safeTableName})`).all(); + const currentColumnNames = currentColumns.map(col => col.name); - if (columnsToAdd.length > 0) { - console.log(`[DB Sync] Updating table: ${tableName}. Adding columns: ${columnsToAdd.map(c=>c.name).join(', ')}`); - for (const column of columnsToAdd) { - const addColumnQuery = `ALTER TABLE "${tableName}" ADD COLUMN "${column.name}" ${column.type}`; - this.db.prepare(addColumnQuery).run(); + // Check for new columns to add + const newColumns = tableSchema.columns.filter(col => !currentColumnNames.includes(col.name)); + + if (newColumns.length > 0) { + console.log(`[DB Sync] Adding ${newColumns.length} new column(s) to ${tableName}`); + for (const col of newColumns) { + const safeColName = this._validateAndQuoteIdentifier(col.name); + const addColumnQuery = `ALTER TABLE ${safeTableName} ADD COLUMN ${safeColName} ${col.type}`; + this.db.exec(addColumnQuery); + console.log(`[DB Sync] Added column ${col.name} to ${tableName}`); } } + + if (tableSchema.constraints && tableSchema.constraints.length > 0) { + console.log(`[DB Sync] Note: Constraints for ${tableName} can only be set during table creation`); + } } runQuery(query, params = []) { diff --git a/src/features/settings/settingsService.js b/src/features/settings/settingsService.js index 801138a..069c41d 100644 --- a/src/features/settings/settingsService.js +++ b/src/features/settings/settingsService.js @@ -1,7 +1,6 @@ const { ipcMain, BrowserWindow } = require('electron'); const Store = require('electron-store'); const authService = require('../../common/services/authService'); -const userRepository = require('../../common/repositories/user'); const settingsRepository = require('./repositories'); const { getStoredApiKey, getStoredProvider, windowPool } = require('../../electron/windowManager'); @@ -282,26 +281,13 @@ async function deletePreset(id) { async function saveApiKey(apiKey, provider = 'openai') { try { - const user = authService.getCurrentUser(); - if (!user.isLoggedIn) { - // For non-logged-in users, save to local storage - const Store = require('electron-store'); - const store = new Store(); - store.set('apiKey', apiKey); - store.set('provider', provider); - - // Notify windows - BrowserWindow.getAllWindows().forEach(win => { - if (!win.isDestroyed()) { - win.webContents.send('api-key-validated', apiKey); - } - }); - - return { success: true }; + // Use ModelStateService as the single source of truth for API key management + const modelStateService = global.modelStateService; + if (!modelStateService) { + throw new Error('ModelStateService not initialized'); } - // For logged-in users, use the repository adapter which injects the UID. - await userRepository.saveApiKey(apiKey, provider); + await modelStateService.setApiKey(provider, apiKey); // Notify windows BrowserWindow.getAllWindows().forEach(win => { @@ -319,16 +305,16 @@ async function saveApiKey(apiKey, provider = 'openai') { async function removeApiKey() { try { - const user = authService.getCurrentUser(); - if (!user.isLoggedIn) { - // For non-logged-in users, remove from local storage - const Store = require('electron-store'); - const store = new Store(); - store.delete('apiKey'); - store.delete('provider'); - } else { - // For logged-in users, use the repository adapter. - await userRepository.saveApiKey(null, null); + // Use ModelStateService as the single source of truth for API key management + const modelStateService = global.modelStateService; + if (!modelStateService) { + throw new Error('ModelStateService not initialized'); + } + + // Remove all API keys for all providers + const providers = ['openai', 'anthropic', 'gemini', 'ollama', 'whisper']; + for (const provider of providers) { + await modelStateService.removeApiKey(provider); } // Notify windows diff --git a/src/index.js b/src/index.js index a63d8cb..c0e1a6c 100644 --- a/src/index.js +++ b/src/index.js @@ -680,13 +680,13 @@ function setupWebDataHandlers() { result = await userRepository.findOrCreate(payload); break; case 'save-api-key': - // Assuming payload is { apiKey, provider } - result = await userRepository.saveApiKey(payload.apiKey, payload.provider); + // Use ModelStateService as the single source of truth for API key management + result = await modelStateService.setApiKey(payload.provider, payload.apiKey); break; case 'check-api-key-status': - // Adapter injects UID - const user = await userRepository.getById(); - result = { hasApiKey: !!user?.api_key && user.api_key.length > 0 }; + // Use ModelStateService to check API key status + const hasApiKey = await modelStateService.hasValidApiKey(); + result = { hasApiKey }; break; case 'delete-account': // Adapter injects UID