OILS / ysh / func_proc.py View on Github | oilshell.org

571 lines, 378 significant
1#!/usr/bin/env python2
2"""
3User-defined funcs and procs
4"""
5from __future__ import print_function
6
7from _devbuild.gen.id_kind_asdl import Id
8from _devbuild.gen.runtime_asdl import cmd_value
9from _devbuild.gen.syntax_asdl import (proc_sig, proc_sig_e, Param, ParamGroup,
10 NamedArg, Func, loc, ArgList, expr,
11 expr_e, expr_t)
12from _devbuild.gen.value_asdl import (value, value_e, value_t, ProcDefaults,
13 LeftName)
14
15from core import error
16from core.error import e_die
17from core import state
18from core import vm
19from frontend import lexer
20from frontend import typed_args
21from mycpp import mylib
22from mycpp.mylib import log, NewDict
23
24from typing import List, Tuple, Dict, Optional, cast, TYPE_CHECKING
25if TYPE_CHECKING:
26 from _devbuild.gen.syntax_asdl import command, loc_t
27 from osh import cmd_eval
28 from ysh import expr_eval
29
# Reference the imported log() at module level, presumably so lint tools don't
# flag the import as unused while it's kept around for debugging.
_ = log

# TODO:
# - use _EvalExpr more?
# - a single with state.ctx_YshExpr -- I guess that's faster
#   - although EvalExpr() can take param.blame_tok
36
37
def _DisallowMutableDefault(val, blame_loc):
    # type: (value_t, loc_t) -> None
    """Reject a List or Dict value used as a parameter default.

    Defaults are evaluated once, when the func/proc is defined, and the same
    value is then bound on every call, so a mutable default is disallowed.

    Raises:
      error.TypeErr, blaming blame_loc, if val is a List or Dict.
    """
    tag = val.tag()
    if tag == value_e.List or tag == value_e.Dict:
        raise error.TypeErr(val, "Default values can't be mutable", blame_loc)
42
43
def _EvalPosDefaults(expr_ev, pos_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> List[value_t]
    """Evaluate the defaults of positional params (shared by func and proc).

    Returns a list parallel to pos_params.  A param without a default gets a
    None placeholder in its slot.

    Raises:
      error.TypeErr if a default evaluates to a List or Dict.
    """
    placeholder = None  # type: value_t
    defaults = [placeholder] * len(pos_params)

    for idx, param in enumerate(pos_params):
        default_expr = param.default_val
        if not default_expr:
            continue  # no default; leave the None placeholder

        default_val = expr_ev.EvalExpr(default_expr, param.blame_tok)
        _DisallowMutableDefault(default_val, param.blame_tok)
        defaults[idx] = default_val

    return defaults
56
57
def _EvalNamedDefaults(expr_ev, named_params):
    # type: (expr_eval.ExprEvaluator, List[Param]) -> Dict[str, value_t]
    """Evaluate the defaults of named params (shared by func and proc).

    Returns a dict mapping param name -> evaluated default.  Params without a
    default have no entry in the dict.

    Raises:
      error.TypeErr if a default evaluates to a List or Dict.
    """
    named_defaults = NewDict()  # type: Dict[str, value_t]
    for p in named_params:
        if p.default_val:
            val = expr_ev.EvalExpr(p.default_val, p.blame_tok)
            _DisallowMutableDefault(val, p.blame_tok)
            named_defaults[p.name] = val
    return named_defaults
69
70
def EvalFuncDefaults(
        expr_ev,  # type: expr_eval.ExprEvaluator
        func,  # type: Func
):
    # type: (...) -> Tuple[List[value_t], Dict[str, value_t]]
    """Evaluate default args for funcs, at time of DEFINITION, not call.

    Returns:
      (pos_defaults, named_defaults).  Each element is None when the func has
      no param group of that kind.
    """
    pos_defaults = None  # type: List[value_t]
    if func.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, func.positional.params)

    named_defaults = None  # type: Dict[str, value_t]
    if func.named:
        named_defaults = _EvalNamedDefaults(expr_ev, func.named.params)

    return pos_defaults, named_defaults
89
90
def EvalProcDefaults(expr_ev, sig):
    # type: (expr_eval.ExprEvaluator, proc_sig.Closed) -> ProcDefaults
    """Evaluate default args for procs, at time of DEFINITION, not call.

    Each of the four groups (word, typed positional, named, block) is None
    when the signature lacks that group.  Within the word and positional
    groups, a param without a default gets a None slot.

    Raises:
      error.TypeErr if a word default isn't a Str, a typed default is a List
      or Dict, or a block default is neither a Command nor Null.
    """
    word_defaults = None  # type: List[value_t]
    word_group = sig.word
    if word_group:
        unset = None  # type: value_t
        word_defaults = [unset] * len(word_group.params)
        for idx, param in enumerate(word_group.params):
            if not param.default_val:
                continue

            word_val = expr_ev.EvalExpr(param.default_val, param.blame_tok)
            if word_val.tag() != value_e.Str:
                raise error.TypeErr(
                    word_val, 'Default val for word param must be Str',
                    param.blame_tok)
            word_defaults[idx] = word_val

    # Stays None when there's no typed group (e.g. only a block param)
    pos_defaults = None  # type: List[value_t]
    if sig.positional:
        pos_defaults = _EvalPosDefaults(expr_ev, sig.positional.params)

    named_defaults = None  # type: Dict[str, value_t]
    if sig.named:
        named_defaults = _EvalNamedDefaults(expr_ev, sig.named.params)

    # cd /tmp (; ; myblock)
    # None here means "no default", which is different than value.Null
    block_default = None  # type: value_t
    block_param = sig.block_param
    if block_param and block_param.default_val:
        block_default = expr_ev.EvalExpr(block_param.default_val,
                                         block_param.blame_tok)
        # Only ^() or null are accepted
        if block_default.tag() not in (value_e.Null, value_e.Command):
            raise error.TypeErr(
                block_default,
                "Default value for block should be Command or Null",
                block_param.blame_tok)

    return ProcDefaults(word_defaults, pos_defaults, named_defaults,
                        block_default)
139
140
def _EvalPosArgs(expr_ev, exprs, pos_args):
    # type: (expr_eval.ExprEvaluator, List[expr_t], List[value_t]) -> None
    """Evaluate positional arg exprs, appending the results to pos_args.

    Shared between func and proc.  A spread arg (...xs) must evaluate to a
    List, and its items are appended one by one.

    This uses expr_ev._EvalExpr(); callers in this file wrap it in
    state.ctx_YshExpr, the way EvalExpr() would.

    Raises:
      error.TypeErr if a spread arg isn't a List.
    """
    for arg in exprs:
        if arg.tag() != expr_e.Spread:
            pos_args.append(expr_ev._EvalExpr(arg))
            continue

        spread = cast(expr.Spread, arg)
        items_val = expr_ev._EvalExpr(spread.child)
        if items_val.tag() != value_e.List:
            raise error.TypeErr(items_val, 'Spread expected a List',
                                spread.left)
        pos_args.extend(cast(value.List, items_val).items)
155
156
def _EvalNamedArgs(expr_ev, named_exprs):
    # type: (expr_eval.ExprEvaluator, List[NamedArg]) -> Dict[str, value_t]
    """Evaluate named arg exprs into a new dict keyed by arg name.

    Shared between func and proc.  A spread arg (...d) must evaluate to a
    Dict, whose entries are merged in; later entries overwrite earlier ones.

    Raises:
      error.TypeErr if a spread arg isn't a Dict.
    """
    result = NewDict()  # type: Dict[str, value_t]
    for named in named_exprs:
        rhs = named.value
        if rhs.tag() == expr_e.Spread:
            spread = cast(expr.Spread, rhs)
            dict_val = expr_ev._EvalExpr(spread.child)
            if dict_val.tag() != value_e.Dict:
                raise error.TypeErr(dict_val, 'Spread expected a Dict',
                                    spread.left)
            result.update(cast(value.Dict, dict_val).d)
        else:
            # The arg's name token is the fallback error location
            arg_val = expr_ev.EvalExpr(rhs, named.name)
            result[lexer.TokenVal(named.name)] = arg_val
    return result
178
179
def _EvalArgList(
        expr_ev,  # type: expr_eval.ExprEvaluator
        args,  # type: ArgList
        me=None  # type: Optional[value_t]
):
    # type: (...) -> Tuple[List[value_t], Optional[Dict[str, value_t]]]
    """Evaluate arg list for funcs.

    This is a PRIVATE METHOD on ExprEvaluator, but it lives in THIS FILE so
    it's next to EvalTypedArgsToProc, which is similar.

    It must only be called from inside the EvalExpr() wrapper:

        with state.ctx_YshExpr(...)  # required to call this
            ...

    Returns:
      (pos_args, named_args).  If `me` is given, it's the first positional
      arg.  named_args is None when args.named_args is None.
    """
    pos_args = []  # type: List[value_t]
    if me:  # the self/this receiver comes first
        pos_args.append(me)
    _EvalPosArgs(expr_ev, args.pos_args, pos_args)

    named_args = None  # type: Optional[Dict[str, value_t]]
    if args.named_args is not None:
        named_args = _EvalNamedArgs(expr_ev, args.named_args)

    return pos_args, named_args
208
209
def EvalTypedArgsToProc(
        expr_ev,  # type: expr_eval.ExprEvaluator
        mutable_opts,  # type: state.MutableOpts
        node,  # type: command.Simple
        cmd_val,  # type: cmd_value.Argv
):
    # type: (...) -> None
    """Evaluate the typed, named, and block args of a proc call into cmd_val.

    Sets cmd_val.typed_args and cmd_val.pos_args, and, when present,
    cmd_val.named_args and cmd_val.block_arg.  Word args (cmd_val.argv) are
    not touched here.

      p (x; n=1)      args are evaluated now
      p [x; n=1]      args are wrapped, unevaluated, in value.Expr
      p (; ; blk)     the block expression is evaluated
      p { echo hi }   the literal block is wrapped in value.Block

    Dies (e_die) if a (...) call has both a block expression and a block
    literal.
    """
    # Start from the AST's arg list; replaced below only when there's a block
    # literal and no typed args.
    cmd_val.typed_args = node.typed_args

    # We only got here if the call looks like
    #    p (x)
    #    p { echo hi }
    #    p () { echo hi }
    # So allocate this unconditionally
    cmd_val.pos_args = []

    ty = node.typed_args
    if ty:
        if ty.left.id == Id.Op_LBracket:  # assert [42 === x]
            # Defer evaluation by wrapping in value.Expr

            for exp in ty.pos_args:
                cmd_val.pos_args.append(value.Expr(exp))
                # TODO: ...spread is illegal

            n1 = ty.named_args
            if n1 is not None:
                cmd_val.named_args = NewDict()
                for named_arg in n1:
                    name = lexer.TokenVal(named_arg.name)
                    cmd_val.named_args[name] = value.Expr(named_arg.value)
                    # TODO: ...spread is illegal

        else:  # json write (x)
            # _EvalPosArgs/_EvalNamedArgs use _EvalExpr(), so set up the
            # same context that EvalExpr() would
            with state.ctx_YshExpr(mutable_opts):  # What EvalExpr() does
                _EvalPosArgs(expr_ev, ty.pos_args, cmd_val.pos_args)

                if ty.named_args is not None:
                    cmd_val.named_args = _EvalNamedArgs(expr_ev, ty.named_args)

            if ty.block_expr and node.block:
                e_die("Can't accept both block expression and block literal",
                      node.block.brace_group.left)

            # p ( ; ; block) is an expression to be evaluated
            if ty.block_expr:
                # fallback location is (
                cmd_val.block_arg = expr_ev.EvalExpr(ty.block_expr, ty.left)

    # p { echo hi } is an unevaluated block
    if node.block:
        # TODO: consolidate value.Block (holds LiteralBlock) and value.Command
        cmd_val.block_arg = value.Block(node.block)

        # Add location info so the cmd_val looks the same for both:
        #   cd /tmp (; ; ^(echo hi))
        #   cd /tmp { echo hi }
        if not cmd_val.typed_args:
            cmd_val.typed_args = ArgList.CreateNull()

            # Also add locations for error message: ls { echo invalid }
            cmd_val.typed_args.left = node.block.brace_group.left
            cmd_val.typed_args.right = node.block.brace_group.right
274
275
def _BindWords(
        proc_name,  # type: str
        group,  # type: ParamGroup
        defaults,  # type: List[value_t]
        cmd_val,  # type: cmd_value.Argv
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind a proc's word args (argv after the proc name) to its word params.

    Missing trailing args fall back to the defaults evaluated at definition
    time (None means "no default").  Leftover args go to the ...rest param,
    if there is one.

    Raises:
      error.Expr if a param has neither an arg nor a default, or if there are
      more args than params and no rest param.
    """
    argv = cmd_val.argv[1:]  # argv[0] is the proc name
    num_args = len(argv)
    for i, p in enumerate(group.params):
        if i < num_args:
            val = value.Str(argv[i])  # type: value_t
        else:  # default args were evaluated on definition
            val = defaults[i]
            if val is None:
                raise error.Expr(
                    "proc %r wasn't passed word param %r" %
                    (proc_name, p.name), blame_loc)

        mem.SetLocalName(LeftName(p.name, p.blame_tok), val)

    # ...rest

    num_params = len(group.params)
    rest = group.rest_of
    if rest:
        lval = LeftName(rest.name, rest.blame_tok)

        items = [value.Str(s)
                 for s in argv[num_params:]]  # type: List[value_t]
        rest_val = value.List(items)
        mem.SetLocalName(lval, rest_val)
    else:
        if num_args > num_params:
            # arg_locs appears to parallel cmd_val.argv (arg_locs[0] is the
            # proc name), so the first extra word argv[num_params] is at
            # arg_locs[num_params + 1].  arg_locs may be empty or shorter
            # than argv, so bounds-check instead of only testing emptiness.
            extra_index = num_params + 1
            if extra_index < len(cmd_val.arg_locs):
                # point to the first extra one
                extra_loc = cmd_val.arg_locs[extra_index]  # type: loc_t
            else:
                extra_loc = loc.Missing

            # Too many arguments.
            raise error.Expr(
                "proc %r takes %d words, but got %d" %
                (proc_name, num_params, num_args), extra_loc)
323
324
def _BindTyped(
        code_name,  # type: str
        group,  # type: Optional[ParamGroup]
        defaults,  # type: List[value_t]
        pos_args,  # type: Optional[List[value_t]]
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind typed positional args to the params in group (func or proc).

    Missing trailing args fall back to defaults, which were evaluated at
    definition time (None means "no default").  Leftover args go to the
    ...rest param, if there is one.  With no group at all, any arg is extra.

    Raises:
      error.Expr if a param has neither an arg nor a default, or if there are
      more args than params and no rest param to absorb them.
    """
    if pos_args is None:
        pos_args = []

    num_args = len(pos_args)
    num_params = 0

    if group:
        for i, p in enumerate(group.params):
            if i < num_args:
                val = pos_args[i]
            else:
                val = defaults[i]
                if val is None:
                    raise error.Expr(
                        "%r wasn't passed typed param %r" %
                        (code_name, p.name), blame_loc)

            mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
        num_params = len(group.params)

        # ...rest
        rest = group.rest_of
        if rest:
            lval = LeftName(rest.name, rest.blame_tok)
            rest_val = value.List(pos_args[num_params:])
            mem.SetLocalName(lval, rest_val)
            return

    # No rest param (or no param group at all): extra args are an error.
    if num_args > num_params:
        # Too many arguments.
        raise error.Expr(
            "%r takes %d typed args, but got %d" %
            (code_name, num_params, num_args), blame_loc)
373
374
def _BindNamed(
        code_name,  # type: str
        group,  # type: ParamGroup
        defaults,  # type: Dict[str, value_t]
        named_args,  # type: Optional[Dict[str, value_t]]
        mem,  # type: state.Mem
        blame_loc,  # type: loc_t
):
    # type: (...) -> None
    """Bind named args to the named params in group (func or proc).

    A param not passed by the caller falls back to its default, which was
    evaluated at definition time.  Each bound entry is erased from
    named_args, which mutates the caller's dict.  Whatever remains goes to
    the ...rest param if there is one; otherwise it's an error.

    Raises:
      error.Expr if a param has neither an arg nor a default, or if an arg
      matches no param and there is no rest param.
    """
    if named_args is None:
        named_args = NewDict()

    for p in group.params:
        val = named_args.get(p.name)
        if val is None:
            val = defaults.get(p.name)
            if val is None:
                raise error.Expr(
                    "%r wasn't passed named param %r" % (code_name, p.name),
                    blame_loc)

        mem.SetLocalName(LeftName(p.name, p.blame_tok), val)
        # Remove bound args, so only unmatched ones are left afterward
        mylib.dict_erase(named_args, p.name)

    # ...rest
    rest = group.rest_of
    if rest:
        lval = LeftName(rest.name, rest.blame_tok)
        mem.SetLocalName(lval, value.Dict(named_args))
        return

    # Every matched arg was erased above, so ANY remaining entry is a named
    # arg that corresponds to no param.  Comparing the leftover count with
    # the number of params would let such args through.
    if len(named_args) != 0:
        raise error.Expr(
            "%r got unexpected named args: %s" %
            (code_name, ', '.join(named_args.keys())), blame_loc)
414
415
416def _BindFuncArgs(func, rd, mem):
417 # type: (value.Func, typed_args.Reader, state.Mem) -> None
418
419 node = func.parsed
420 blame_loc = rd.LeftParenToken()
421
422 ### Handle positional args
423
424 if node.positional:
425 _BindTyped(func.name, node.positional, func.pos_defaults, rd.pos_args,
426 mem, blame_loc)
427 else:
428 if rd.pos_args is not None:
429 num_pos = len(rd.pos_args)
430 if num_pos != 0:
431 raise error.Expr(
432 "Func %r takes no positional args, but got %d" %
433 (func.name, num_pos), blame_loc)
434
435 semi = rd.arg_list.semi_tok
436 if semi is not None:
437 blame_loc = semi
438
439 ### Handle named args
440
441 if node.named:
442 _BindNamed(func.name, node.named, func.named_defaults, rd.named_args,
443 mem, blame_loc)
444 else:
445 if rd.named_args is not None:
446 num_named = len(rd.named_args)
447 if num_named != 0:
448 raise error.Expr(
449 "Func %r takes no named args, but got %d" %
450 (func.name, num_named), blame_loc)
451
452
def BindProcArgs(proc, cmd_val, mem):
    # type: (value.Proc, cmd_value.Argv, state.Mem) -> None
    """Bind a proc call's word, typed, named, and block args to its params.

    Only proc_sig.Closed signatures bind anything; any other signature kind
    returns immediately.

    Raises:
      error.Expr on missing or extra args of any kind.
    """
    UP_sig = proc.sig
    if UP_sig.tag() != proc_sig_e.Closed:  # proc is-closed ()
        return
    sig = cast(proc_sig.Closed, UP_sig)

    # The _BindX() helpers are only called when the matching param group
    # exists.  That saves a few allocations, since most procs don't use all
    # kinds of params.
    ty = cmd_val.typed_args
    defaults = proc.defaults

    # Word args: blame the first word's location, if there is one
    blame_loc = loc.Missing  # type: loc_t
    if len(cmd_val.arg_locs) > 0:
        blame_loc = cmd_val.arg_locs[0]

    if sig.word:
        _BindWords(proc.name, sig.word, defaults.for_word, cmd_val, mem,
                   blame_loc)
    elif len(cmd_val.argv) != 1:
        raise error.Expr(
            "Proc %r takes no word args, but got %d" %
            (proc.name, len(cmd_val.argv) - 1), blame_loc)

    # Typed positional args: blame the ( of the call site
    if ty:
        blame_loc = ty.left

    if sig.positional:
        _BindTyped(proc.name, sig.positional, defaults.for_typed,
                   cmd_val.pos_args, mem, blame_loc)
    elif cmd_val.pos_args is not None and len(cmd_val.pos_args) != 0:
        raise error.Expr(
            "Proc %r takes no typed args, but got %d" %
            (proc.name, len(cmd_val.pos_args)), blame_loc)

    # Typed named args: blame the first ; of the call site, if present
    if ty and ty.semi_tok is not None:
        blame_loc = ty.semi_tok

    if sig.named:
        _BindNamed(proc.name, sig.named, defaults.for_named,
                   cmd_val.named_args, mem, blame_loc)
    elif cmd_val.named_args is not None and len(cmd_val.named_args) != 0:
        raise error.Expr(
            "Proc %r takes no named args, but got %d" %
            (proc.name, len(cmd_val.named_args)), blame_loc)

    # Block arg: blame the second ; of the call site, if present, because
    # value_t generally lacks location info (unlike expr_t)
    if ty and ty.semi_tok2 is not None:
        blame_loc = ty.semi_tok2

    block_param = sig.block_param
    block_arg = cmd_val.block_arg

    if not block_param:
        if block_arg is not None:
            raise error.Expr(
                "Proc %r doesn't accept a block argument" % proc.name,
                blame_loc)
        return

    if block_arg is None:
        block_arg = defaults.for_block
    if block_arg is None:
        raise error.Expr(
            "%r wasn't passed block param %r" %
            (proc.name, block_param.name), blame_loc)

    mem.SetLocalName(LeftName(block_param.name, block_param.blame_tok),
                     block_arg)
545
546
def CallUserFunc(
        func,  # type: value.Func
        rd,  # type: typed_args.Reader
        mem,  # type: state.Mem
        cmd_ev,  # type: cmd_eval.CommandEvaluator
):
    # type: (...) -> value_t
    """Call a user-defined func with the args in rd and return its result.

    Args are bound and the body runs inside a new stack frame
    (state.ctx_FuncCall).  A body that finishes without raising
    ValueControlFlow yields value.Null.

    Raises:
      AssertionError if the body raises IntControlFlow.
    """
    with state.ctx_FuncCall(mem, func):
        _BindFuncArgs(func, rd, mem)

        try:
            cmd_ev._Execute(func.parsed.body)
            return value.Null  # fell off the end: implicit return
        except vm.ValueControlFlow as e:
            # Presumably raised by an explicit return; carries the result
            return e.value
        except vm.IntControlFlow:
            raise AssertionError('IntControlFlow in func')

    raise AssertionError('unreachable')
569
570
571# vim: sw=4