Skip to content

Commit bd1118b

Browse files
committed
Extend parameter type inference
1 parent 5093117 commit bd1118b

4 files changed

Lines changed: 132 additions & 75 deletions

File tree

ShapeScript/StandardLibrary.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public let stdlibSymbols: Set<String> = {
2828
return keys
2929
}()
3030

31-
extension [String: Symbol] {
31+
extension Symbols {
3232
static func + (lhs: Symbols, rhs: Symbols) -> Symbols {
3333
lhs.merging(rhs) { $1 }
3434
}

ShapeScript/Symbols.swift

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88

99
import Euclid
1010

11-
typealias Getter = (EvaluationContext) throws -> Value
12-
typealias Setter = (Value, EvaluationContext) throws -> Void
13-
typealias FunctionType = (parameterType: ValueType, returnType: ValueType)
11+
typealias Symbols = [String: Symbol]
1412

1513
enum Symbol {
16-
case function(FunctionType, (Value, EvaluationContext) throws -> Value)
14+
case function(FunctionType, Function)
1715
case property(ValueType, Setter, Getter)
1816
case block(BlockType, Getter)
1917
case constant(Value)
@@ -25,15 +23,12 @@ extension Symbol {
2523
static func function(
2624
_ parameterType: ValueType,
2725
_ returnType: ValueType,
28-
_ fn: @escaping (Value, EvaluationContext) throws -> Value
26+
_ fn: @escaping Function
2927
) -> Symbol {
3028
.function((parameterType, returnType), fn)
3129
}
3230

33-
static func command(
34-
_ parameterType: ValueType,
35-
_ fn: @escaping Setter
36-
) -> Symbol {
31+
static func command(_ parameterType: ValueType, _ fn: @escaping Setter) -> Symbol {
3732
.function(parameterType, .void) {
3833
try fn($0, $1)
3934
return .void
@@ -55,5 +50,3 @@ extension Symbol {
5550
}
5651
}
5752
}
58-
59-
typealias Symbols = [String: Symbol]

ShapeScript/Types.swift

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ extension [ValueType] {
239239
}
240240
}
241241

242+
// MARK: Function types
243+
244+
typealias Getter = (EvaluationContext) throws -> Value
245+
typealias Setter = (Value, EvaluationContext) throws -> Void
246+
typealias Function = (Value, EvaluationContext) throws -> Value
247+
typealias FunctionType = (parameterType: ValueType, returnType: ValueType)
248+
typealias Parameters = [String: ValueType]
249+
242250
// MARK: Block Types
243251

244252
typealias Options = [String: ValueType]
@@ -538,11 +546,14 @@ extension Value {
538546
}
539547

540548
extension Definition {
541-
func inferTypes(
542-
for _: inout [String: ValueType],
543-
in _: EvaluationContext,
544-
with _: ValueType
545-
) {}
549+
func inferTypes(for params: inout Parameters, in context: EvaluationContext) {
550+
switch type {
551+
case let .expression(expression):
552+
expression.inferTypes(for: &params, in: context, with: .any)
553+
case let .function(_, block), let .block(block):
554+
block.inferTypes(for: &params, in: context)
555+
}
556+
}
546557

547558
func staticSymbol(in context: EvaluationContext) throws -> Symbol {
548559
switch type {
@@ -560,7 +571,7 @@ extension Definition {
560571

561572
extension Expression {
562573
func inferTypes(
563-
for params: inout [String: ValueType],
574+
for params: inout Parameters,
564575
in context: EvaluationContext,
565576
with type: ValueType
566577
) {
@@ -587,14 +598,24 @@ extension Expression {
587598
for expression in expressions {
588599
expression.inferTypes(for: &params, in: context, with: type)
589600
}
601+
case .vector, .size, .color:
602+
for expression in expressions {
603+
expression.inferTypes(for: &params, in: context, with: .number)
604+
}
590605
default:
591-
// TODO: other cases
592-
return
606+
// TODO: figure out types for other cases
607+
for expression in expressions {
608+
expression.inferTypes(for: &params, in: context, with: .any)
609+
}
593610
}
594611
}
595-
case let .subscript(_, rhs):
596-
// TODO: lhs
612+
case let .subscript(lhs, rhs):
613+
// TODO: can we use rhs type to infer something about the lhs?
614+
lhs.inferTypes(for: &params, in: context, with: .any)
597615
rhs.inferTypes(for: &params, in: context, with: .union([.number, .string]))
616+
case let .member(expression, _):
617+
// TODO: can we use member name to infer something about the type?
618+
expression.inferTypes(for: &params, in: context, with: .any)
598619
case let .import(expression):
599620
expression.inferTypes(for: &params, in: context, with: .string)
600621
case let .infix(lhs, .step, rhs):
@@ -622,9 +643,12 @@ extension Expression {
622643
case let .infix(lhs, .and, rhs), let .infix(lhs, .or, rhs):
623644
lhs.inferTypes(for: &params, in: context, with: .boolean)
624645
rhs.inferTypes(for: &params, in: context, with: .boolean)
625-
case .infix(_, .equal, _), .infix(_, .unequal, _),
626-
.number, .string, .color, .member:
627-
return
646+
case let .infix(lhs, .equal, rhs), let .infix(lhs, .unequal, rhs):
647+
// TODO: if lhs/rhs type is known, can we infer the rhs/lhs matches?
648+
lhs.inferTypes(for: &params, in: context, with: .any)
649+
rhs.inferTypes(for: &params, in: context, with: .any)
650+
case .number, .string, .color:
651+
break
628652
}
629653
}
630654

@@ -766,10 +790,7 @@ extension Expression {
766790
}
767791

768792
extension Block {
769-
func inferTypes(
770-
for params: inout [String: ValueType],
771-
in context: EvaluationContext
772-
) {
793+
func inferTypes(for params: inout Parameters, in context: EvaluationContext) {
773794
context.pushScope { context in
774795
statements.gatherDefinitions(in: context)
775796
statements.forEach { $0.inferTypes(for: &params, in: context) }
@@ -781,10 +802,7 @@ extension Block {
781802
return try staticType(in: context, options: &options)
782803
}
783804

784-
func staticType(
785-
in context: EvaluationContext,
786-
options: inout Options?
787-
) throws -> ValueType {
805+
func staticType(in context: EvaluationContext, options: inout Options?) throws -> ValueType {
788806
var types = [ValueType]()
789807
statements.gatherDefinitions(in: context)
790808
for statement in statements {
@@ -810,7 +828,7 @@ extension Block {
810828

811829
extension CaseStatement {
812830
func inferTypes(
813-
for params: inout [String: ValueType],
831+
for params: inout Parameters,
814832
in context: EvaluationContext,
815833
with type: ValueType
816834
) {
@@ -822,38 +840,28 @@ extension CaseStatement {
822840
}
823841

824842
extension Statement {
825-
func inferTypes(
826-
for params: inout [String: ValueType],
827-
in context: EvaluationContext
828-
) {
843+
func inferTypes(for params: inout Parameters, in context: EvaluationContext) {
829844
switch type {
830845
case let .command(identifier, expression):
831-
guard let expression,
832-
let symbol = context.symbol(for: identifier.name)
833-
else {
846+
guard let expression, let symbol = context.symbol(for: identifier.name) else {
847+
// Probably an error, but gather types anyway in case it helps
848+
expression?.inferTypes(for: &params, in: context, with: .any)
834849
return
835850
}
836851
switch symbol {
837852
case let .function(type, _):
838-
expression.inferTypes(
839-
for: &params,
840-
in: context,
841-
with: type.parameterType
842-
)
853+
expression.inferTypes(for: &params, in: context, with: type.parameterType)
843854
case let .property(type, _, _):
844855
expression.inferTypes(for: &params, in: context, with: type)
845856
case let .block(type, _):
846-
expression.inferTypes(
847-
for: &params,
848-
in: context,
849-
with: type.childTypes
850-
)
857+
expression.inferTypes(for: &params, in: context, with: .list(type.childTypes))
851858
case .constant, .option, .placeholder:
852859
return
853860
}
854861
case let .expression(type):
855862
Expression(type: type, range: range).inferTypes(for: &params, in: context, with: .any)
856863
case let .define(identifier, definition):
864+
definition.inferTypes(for: &params, in: context)
857865
if let symbol = try? definition.staticSymbol(in: context) {
858866
context.define(identifier.name, as: symbol)
859867
}
@@ -871,8 +879,8 @@ extension Statement {
871879
caseStatement.inferTypes(for: &params, in: context, with: type)
872880
}
873881
elseBody?.inferTypes(for: &params, in: context)
874-
case .option:
875-
return
882+
case let .option(_, expression):
883+
expression.inferTypes(for: &params, in: context, with: .any)
876884
}
877885
}
878886

0 commit comments

Comments
 (0)