Skip to content

Commit 462de01

Browse files
cdeustclaude
andcommitted
feat: Add context-aware section generation to prevent hallucination
Implement intelligent context extraction for each PRD section to stay within token limits while maintaining accuracy. The ContextManager extracts only relevant portions of the full input based on section type, preventing hallucination caused by exceeding model context windows. Key changes: - Added ContextManager component for section-specific context extraction - Integrated context management into PhaseGenerator workflow - Set generation context in PRDOrchestrator before section generation - Each section now receives targeted context instead of full input Benefits: - Prevents hallucination by staying within token limits - Maintains accuracy by providing relevant context per section - Improves generation quality for large projects - Reduces API costs through optimized context usage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 0926c36 commit 462de01

File tree

3 files changed

+244
-11
lines changed

3 files changed

+244
-11
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import Foundation
2+
import CommonModels
3+
import DomainCore
4+
5+
/// Manages context extraction and token budget for section generation
6+
/// Ensures each section stays within provider-specific token limits
7+
public final class ContextManager {
8+
// Token limits per provider type
9+
private static let appleIntelligenceTokenLimit = 3500 // Leave buffer for response
10+
private static let defaultTokenLimit = 8000
11+
12+
// Estimated tokens per character (rough approximation)
13+
private static let tokensPerChar: Double = 0.25
14+
15+
/// Extract minimal context needed for a specific section
16+
/// This prevents context overflow by providing only relevant information
17+
public func extractContextForSection(
18+
sectionName: String,
19+
fullInput: String,
20+
enrichedRequirements: EnrichedRequirements?,
21+
stackContext: StackContext?,
22+
providerName: String = "default"
23+
) -> String {
24+
let tokenLimit = getTokenLimit(for: providerName)
25+
let maxChars = Int(Double(tokenLimit) / Self.tokensPerChar)
26+
27+
var contextParts: [String] = []
28+
29+
// 1. Always include core request (truncated if needed)
30+
let coreRequest = truncateIfNeeded(fullInput, maxLength: maxChars / 3)
31+
contextParts.append("### Core Request\n\(coreRequest)")
32+
33+
// 2. Include relevant clarifications (not all)
34+
if let enriched = enrichedRequirements, !enriched.clarifications.isEmpty {
35+
let relevantClarifications = selectRelevantClarifications(
36+
for: sectionName,
37+
from: enriched,
38+
maxLength: maxChars / 4
39+
)
40+
if !relevantClarifications.isEmpty {
41+
contextParts.append("### Clarifications\n\(relevantClarifications)")
42+
}
43+
}
44+
45+
// 3. Include minimal stack info if relevant for this section
46+
if let stack = stackContext, isStackRelevantForSection(sectionName) {
47+
let stackSummary = summarizeStack(stack, maxLength: maxChars / 6)
48+
contextParts.append("### Tech Stack\n\(stackSummary)")
49+
}
50+
51+
// 4. Join and ensure we're under limit
52+
let combined = contextParts.joined(separator: "\n\n")
53+
return truncateIfNeeded(combined, maxLength: maxChars)
54+
}
55+
56+
/// Get token limit based on provider
57+
private func getTokenLimit(for providerName: String) -> Int {
58+
if providerName.lowercased().contains("apple") {
59+
return Self.appleIntelligenceTokenLimit
60+
}
61+
return Self.defaultTokenLimit
62+
}
63+
64+
/// Determine if tech stack context is relevant for this section
65+
private func isStackRelevantForSection(_ sectionName: String) -> Bool {
66+
let stackRelevantSections = [
67+
"API Specification",
68+
"Test Requirements",
69+
"Performance & Security Constraints",
70+
"Technical Stack Context",
71+
"Data Model"
72+
]
73+
return stackRelevantSections.contains { sectionName.contains($0) }
74+
}
75+
76+
/// Select only clarifications relevant to the section being generated
77+
private func selectRelevantClarifications(
78+
for sectionName: String,
79+
from enriched: EnrichedRequirements,
80+
maxLength: Int
81+
) -> String {
82+
// Map sections to relevant clarification keywords
83+
let sectionKeywords = getSectionKeywords(sectionName)
84+
85+
var selectedClarifications: [(String, String)] = []
86+
var totalLength = 0
87+
88+
for (question, answer) in enriched.clarifications {
89+
// Check if clarification is relevant to this section
90+
let isRelevant = sectionKeywords.contains { keyword in
91+
question.lowercased().contains(keyword.lowercased()) ||
92+
answer.lowercased().contains(keyword.lowercased())
93+
}
94+
95+
if isRelevant {
96+
let entry = "Q: \(question)\nA: \(answer)"
97+
if totalLength + entry.count <= maxLength {
98+
selectedClarifications.append((question, answer))
99+
totalLength += entry.count
100+
} else {
101+
break // Stop if we exceed max length
102+
}
103+
}
104+
}
105+
106+
if selectedClarifications.isEmpty {
107+
// If no specific matches, include first few clarifications
108+
let firstFew = Array(enriched.clarifications.prefix(2))
109+
return firstFew.map { "Q: \($0.key)\nA: \($0.value)" }.joined(separator: "\n")
110+
}
111+
112+
return selectedClarifications.map { "Q: \($0.0)\nA: \($0.1)" }.joined(separator: "\n")
113+
}
114+
115+
/// Get keywords relevant to a section
116+
private func getSectionKeywords(_ sectionName: String) -> [String] {
117+
switch sectionName {
118+
case let s where s.contains("Overview"):
119+
return ["purpose", "goal", "problem", "overview"]
120+
case let s where s.contains("User Stories"):
121+
return ["user", "persona", "role", "actor"]
122+
case let s where s.contains("Feature"):
123+
return ["feature", "functionality", "capability"]
124+
case let s where s.contains("Data Model"):
125+
return ["data", "model", "schema", "entity", "table"]
126+
case let s where s.contains("API"):
127+
return ["api", "endpoint", "request", "response", "rest"]
128+
case let s where s.contains("Test"):
129+
return ["test", "testing", "validation", "quality"]
130+
case let s where s.contains("Constraint"):
131+
return ["performance", "security", "constraint", "limit"]
132+
case let s where s.contains("Validation"):
133+
return ["success", "criteria", "acceptance", "validation"]
134+
default:
135+
return ["feature", "requirement"]
136+
}
137+
}
138+
139+
/// Summarize stack context to minimal essentials
140+
private func summarizeStack(_ stack: StackContext, maxLength: Int) -> String {
141+
var summary = "- Language: \(stack.language)"
142+
143+
if let db = stack.database {
144+
summary += "\n- Database: \(db)"
145+
}
146+
147+
if let test = stack.testFramework {
148+
summary += "\n- Testing: \(test)"
149+
}
150+
151+
if let deploy = stack.deployment {
152+
summary += "\n- Deployment: \(deploy)"
153+
}
154+
155+
return truncateIfNeeded(summary, maxLength: maxLength)
156+
}
157+
158+
/// Truncate text if it exceeds max length
159+
private func truncateIfNeeded(_ text: String, maxLength: Int) -> String {
160+
if text.count <= maxLength {
161+
return text
162+
}
163+
164+
let truncated = String(text.prefix(maxLength - 20))
165+
return truncated + "\n...(truncated)"
166+
}
167+
168+
/// Estimate token count for text
169+
public func estimateTokenCount(_ text: String) -> Int {
170+
return Int(Double(text.count) * Self.tokensPerChar)
171+
}
172+
173+
/// Check if text is within token limit for provider
174+
public func isWithinLimit(_ text: String, providerName: String) -> Bool {
175+
let limit = getTokenLimit(for: providerName)
176+
let estimated = estimateTokenCount(text)
177+
return estimated <= limit
178+
}
179+
}

swift/Sources/PRDGenerator/PRDOrchestrator.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,17 @@ public final class PRDOrchestrator {
4646
self.documentAssembler = DocumentAssembler(interactionHandler: self.interactionHandler)
4747
self.sectionGenerator = SectionGenerator(provider: provider, configuration: configuration)
4848
self.taskContextDetector = TaskContextDetector()
49+
50+
// Extract provider name for context management
51+
let providerName = provider.name
52+
4953
self.phaseGenerator = PhaseGenerator(
5054
provider: provider,
5155
configuration: configuration,
5256
assumptionTracker: assumptionTracker,
5357
interactionHandler: self.interactionHandler,
54-
sectionGenerator: sectionGenerator
58+
sectionGenerator: sectionGenerator,
59+
providerName: providerName
5560
)
5661
}
5762

@@ -83,6 +88,14 @@ public final class PRDOrchestrator {
8388
let discoveredStack = try await stackDiscovery.discoverTechnicalStack(input: workingInput)
8489
self.stackContext = discoveredStack
8590

91+
// Set generation context for PhaseGenerator
92+
// This enables context-aware section generation that respects token limits
93+
phaseGenerator.setGenerationContext(
94+
fullInput: workingInput,
95+
enrichedRequirements: enrichedReqs,
96+
stackContext: discoveredStack
97+
)
98+
8699
// Add requirements analysis summary if clarifications were provided
87100
if enrichedReqs.wasClarified {
88101
sections.append(PRDSection(

swift/Sources/PRDGenerator/PhaseGenerator.swift

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,27 @@ public final class PhaseGenerator {
99
private let validationHandler: ValidationHandler
1010
private let reportFormatter: ReportFormatter
1111
private let interactionHandler: UserInteractionHandler
12+
private let contextManager: ContextManager
13+
private let providerName: String
14+
15+
// Context state for section generation
16+
private var fullInput: String = ""
17+
private var enrichedRequirements: EnrichedRequirements?
18+
private var stackContext: StackContext?
1219

1320
public init(
1421
provider: AIProvider,
1522
configuration: Configuration,
1623
assumptionTracker: AssumptionTracker,
1724
interactionHandler: UserInteractionHandler,
18-
sectionGenerator: SectionGenerator
25+
sectionGenerator: SectionGenerator,
26+
providerName: String = "default"
1927
) {
2028
self.sectionGenerator = sectionGenerator
2129
self.reportFormatter = ReportFormatter()
2230
self.interactionHandler = interactionHandler
31+
self.contextManager = ContextManager()
32+
self.providerName = providerName
2333
self.validationHandler = ValidationHandler(
2434
provider: provider,
2535
assumptionTracker: assumptionTracker,
@@ -29,12 +39,35 @@ public final class PhaseGenerator {
2939
)
3040
}
3141

42+
/// Set generation context - called before section generation begins
43+
public func setGenerationContext(
44+
fullInput: String,
45+
enrichedRequirements: EnrichedRequirements?,
46+
stackContext: StackContext?
47+
) {
48+
self.fullInput = fullInput
49+
self.enrichedRequirements = enrichedRequirements
50+
self.stackContext = stackContext
51+
}
52+
53+
/// Extract section-specific context to stay within token limits
54+
private func getSectionContext(for sectionName: String) -> String {
55+
return contextManager.extractContextForSection(
56+
sectionName: sectionName,
57+
fullInput: fullInput,
58+
enrichedRequirements: enrichedRequirements,
59+
stackContext: stackContext,
60+
providerName: providerName
61+
)
62+
}
63+
3264
// MARK: - Phase 1: Product Overview
3365

3466
public func generateProductOverview(input: String) async throws -> PRDSection {
3567
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.generatingPRD)
68+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.taskOverview)
3669
let overview = try await validationHandler.generateWithValidation(
37-
input: input,
70+
input: sectionContext,
3871
prompt: PRDPrompts.overviewPrompt,
3972
sectionName: PRDDisplayConstants.SectionNames.taskOverview
4073
)
@@ -52,8 +85,9 @@ public final class PhaseGenerator {
5285
public func generateUserStories(input: String) async throws -> PRDSection {
5386
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.userStories)
5487
do {
88+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.userStories)
5589
let stories = try await validationHandler.generateWithValidation(
56-
input: input,
90+
input: sectionContext,
5791
prompt: PRDPrompts.userStoriesPrompt,
5892
sectionName: PRDDisplayConstants.SectionNames.userStories
5993
)
@@ -75,8 +109,9 @@ public final class PhaseGenerator {
75109

76110
public func generateFeatures(input: String) async throws -> PRDSection {
77111
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.features)
112+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.featureChanges)
78113
let features = try await validationHandler.generateWithValidation(
79-
input: input,
114+
input: sectionContext,
80115
prompt: PRDPrompts.featuresPrompt,
81116
sectionName: PRDDisplayConstants.SectionNames.featureChanges
82117
)
@@ -93,9 +128,10 @@ public final class PhaseGenerator {
93128
// MARK: - Phase 4: Data Model
94129

95130
public func generateDataModel(input: String) async throws -> PRDSection {
131+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.dataModel)
96132
let dataModelPrompt = PRDPrompts.dataModelPrompt
97133
let dataModel = try await validationHandler.generateWithValidation(
98-
input: input,
134+
input: sectionContext,
99135
prompt: dataModelPrompt,
100136
sectionName: PRDDisplayConstants.SectionNames.dataModel
101137
)
@@ -112,12 +148,13 @@ public final class PhaseGenerator {
112148

113149
public func generateAPIOperations(input: String, stack: StackContext) async throws -> PRDSection {
114150
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.apiOperations)
151+
let sectionContext = getSectionContext(for: PRDDisplayConstants.ExtendedSectionNames.apiSpecification)
115152
let apiPrompt = reportFormatter.enhancePromptWithStack(
116153
PRDPrompts.apiSpecPrompt,
117154
stack: stack
118155
)
119156
let apiSpec = try await validationHandler.generateWithValidation(
120-
input: input,
157+
input: sectionContext,
121158
prompt: apiPrompt,
122159
sectionName: PRDDisplayConstants.ExtendedSectionNames.apiSpecification
123160
)
@@ -135,12 +172,13 @@ public final class PhaseGenerator {
135172

136173
public func generateTestSpecifications(input: String, stack: StackContext) async throws -> PRDSection {
137174
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.testSpecs)
175+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.testRequirements)
138176
let testPrompt = reportFormatter.enhanceTestPromptWithStack(
139177
PRDPrompts.testSpecPrompt,
140178
stack: stack
141179
)
142180
let testSpec = try await validationHandler.generateWithValidation(
143-
input: input,
181+
input: sectionContext,
144182
prompt: testPrompt,
145183
sectionName: PRDDisplayConstants.SectionNames.testRequirements
146184
)
@@ -159,6 +197,7 @@ public final class PhaseGenerator {
159197

160198
public func generateConstraints(input: String, stack: StackContext) async throws -> PRDSection {
161199
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.constraints)
200+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.additionalConstraints)
162201
let stackDescription = String(
163202
format: PRDAnalysisConstants.StackFormatting.stackDescription,
164203
stack.language,
@@ -170,7 +209,7 @@ public final class PhaseGenerator {
170209
stack: stack
171210
)
172211
let constraints = try await validationHandler.generateWithValidation(
173-
input: input,
212+
input: sectionContext,
174213
prompt: constraintsPrompt,
175214
sectionName: PRDDisplayConstants.SectionNames.additionalConstraints
176215
)
@@ -188,8 +227,9 @@ public final class PhaseGenerator {
188227

189228
public func generateValidationCriteria(input: String) async throws -> PRDSection {
190229
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.validation)
230+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.successCriteria)
191231
let validation = try await validationHandler.generateWithValidation(
192-
input: input,
232+
input: sectionContext,
193233
prompt: PRDPrompts.validationPrompt,
194234
sectionName: PRDDisplayConstants.SectionNames.successCriteria
195235
)
@@ -208,12 +248,13 @@ public final class PhaseGenerator {
208248

209249
public func generateRoadmap(input: String, stack: StackContext) async throws -> PRDSection {
210250
interactionHandler.showProgress(PRDDisplayConstants.PhaseMessages.roadmap)
251+
let sectionContext = getSectionContext(for: PRDDisplayConstants.SectionNames.implementationSteps)
211252
let roadmapPrompt = reportFormatter.enhanceRoadmapPromptWithStack(
212253
PRDPrompts.roadmapPrompt,
213254
stack: stack
214255
)
215256
let roadmap = try await validationHandler.generateWithValidation(
216-
input: input,
257+
input: sectionContext,
217258
prompt: roadmapPrompt,
218259
sectionName: PRDDisplayConstants.SectionNames.implementationSteps
219260
)

0 commit comments

Comments
 (0)